cui-llama.rn 1.2.6 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -2
- package/android/src/main/CMakeLists.txt +26 -6
- package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
- package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
- package/android/src/main/jni.cpp +228 -40
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/cpp/amx/amx.cpp +196 -0
- package/cpp/amx/amx.h +20 -0
- package/cpp/amx/common.h +101 -0
- package/cpp/amx/mmq.cpp +2524 -0
- package/cpp/amx/mmq.h +16 -0
- package/cpp/common.cpp +118 -251
- package/cpp/common.h +53 -30
- package/cpp/ggml-aarch64.c +46 -3395
- package/cpp/ggml-aarch64.h +0 -20
- package/cpp/ggml-alloc.c +6 -8
- package/cpp/ggml-backend-impl.h +33 -11
- package/cpp/ggml-backend-reg.cpp +423 -0
- package/cpp/ggml-backend.cpp +14 -676
- package/cpp/ggml-backend.h +46 -9
- package/cpp/ggml-common.h +6 -0
- package/cpp/ggml-cpu-aarch64.c +3823 -0
- package/cpp/ggml-cpu-aarch64.h +32 -0
- package/cpp/ggml-cpu-impl.h +14 -242
- package/cpp/ggml-cpu-quants.c +10835 -0
- package/cpp/ggml-cpu-quants.h +63 -0
- package/cpp/ggml-cpu.c +13971 -13720
- package/cpp/ggml-cpu.cpp +715 -0
- package/cpp/ggml-cpu.h +65 -63
- package/cpp/ggml-impl.h +285 -25
- package/cpp/ggml-metal.h +8 -8
- package/cpp/ggml-metal.m +1221 -728
- package/cpp/ggml-quants.c +189 -10681
- package/cpp/ggml-quants.h +78 -125
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +12 -0
- package/cpp/ggml.c +688 -1460
- package/cpp/ggml.h +58 -244
- package/cpp/json-schema-to-grammar.cpp +1045 -1045
- package/cpp/json.hpp +24766 -24766
- package/cpp/llama-sampling.cpp +5 -2
- package/cpp/llama.cpp +409 -123
- package/cpp/llama.h +8 -4
- package/cpp/rn-llama.hpp +89 -25
- package/cpp/sampling.cpp +42 -3
- package/cpp/sampling.h +22 -1
- package/cpp/sgemm.cpp +608 -0
- package/cpp/speculative.cpp +270 -0
- package/cpp/speculative.h +28 -0
- package/cpp/unicode.cpp +11 -0
- package/ios/RNLlama.mm +43 -20
- package/ios/RNLlamaContext.h +9 -3
- package/ios/RNLlamaContext.mm +146 -33
- package/jest/mock.js +0 -1
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/grammar.js +4 -2
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +52 -15
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/grammar.js +2 -1
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +51 -15
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +122 -8
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/grammar.d.ts +5 -6
- package/lib/typescript/grammar.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +15 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +2 -1
- package/src/NativeRNLlama.ts +135 -13
- package/src/grammar.ts +10 -8
- package/src/index.ts +104 -28
package/cpp/ggml-metal.m
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
#import "ggml-impl.h"
|
4
4
|
#import "ggml-backend-impl.h"
|
5
|
+
#import "ggml-metal-impl.h"
|
5
6
|
|
6
7
|
#import <Foundation/Foundation.h>
|
7
8
|
|
@@ -36,16 +37,20 @@ static struct lm_ggml_backend_metal_device_context {
|
|
36
37
|
id<MTLDevice> mtl_device;
|
37
38
|
int mtl_device_ref_count;
|
38
39
|
|
39
|
-
bool
|
40
|
-
bool
|
40
|
+
bool has_simdgroup_reduction;
|
41
|
+
bool has_simdgroup_mm;
|
42
|
+
bool has_bfloat;
|
43
|
+
bool use_bfloat;
|
41
44
|
|
42
45
|
char name[128];
|
43
46
|
} g_lm_ggml_ctx_dev_main = {
|
44
|
-
/*.mtl_device
|
45
|
-
/*.mtl_device_ref_count
|
46
|
-
/*.
|
47
|
-
/*.
|
48
|
-
/*.
|
47
|
+
/*.mtl_device =*/ nil,
|
48
|
+
/*.mtl_device_ref_count =*/ 0,
|
49
|
+
/*.has_simdgroup_reduction =*/ false,
|
50
|
+
/*.has_simdgroup_mm =*/ false,
|
51
|
+
/*.has_bfloat =*/ false,
|
52
|
+
/*.use_bfloat =*/ false,
|
53
|
+
/*.name =*/ "",
|
49
54
|
};
|
50
55
|
|
51
56
|
// acquire
|
@@ -55,10 +60,19 @@ static id<MTLDevice> lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_met
|
|
55
60
|
if (ctx->mtl_device == nil) {
|
56
61
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
57
62
|
|
58
|
-
ctx->
|
59
|
-
ctx->
|
63
|
+
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
64
|
+
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
60
65
|
|
61
|
-
ctx->
|
66
|
+
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
67
|
+
|
68
|
+
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
69
|
+
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
70
|
+
|
71
|
+
#if defined(LM_GGML_METAL_USE_BF16)
|
72
|
+
ctx->use_bfloat = ctx->has_bfloat;
|
73
|
+
#else
|
74
|
+
ctx->use_bfloat = false;
|
75
|
+
#endif
|
62
76
|
|
63
77
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
64
78
|
}
|
@@ -112,6 +126,7 @@ enum lm_ggml_metal_kernel_type {
|
|
112
126
|
LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
113
127
|
LM_GGML_METAL_KERNEL_TYPE_SILU,
|
114
128
|
LM_GGML_METAL_KERNEL_TYPE_SILU_4,
|
129
|
+
LM_GGML_METAL_KERNEL_TYPE_ELU,
|
115
130
|
LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
116
131
|
LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
117
132
|
LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
@@ -120,6 +135,7 @@ enum lm_ggml_metal_kernel_type {
|
|
120
135
|
LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
121
136
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
122
137
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
138
|
+
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
|
123
139
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
124
140
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
125
141
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
@@ -146,10 +162,14 @@ enum lm_ggml_metal_kernel_type {
|
|
146
162
|
LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
147
163
|
LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
148
164
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
149
|
-
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
150
165
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
151
166
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
152
167
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
168
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
169
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
170
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
171
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
172
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
153
173
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
154
174
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
155
175
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
@@ -170,10 +190,11 @@ enum lm_ggml_metal_kernel_type {
|
|
170
190
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
171
191
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
172
192
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
173
|
-
//LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
174
193
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
175
194
|
//LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
176
195
|
//LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
196
|
+
//LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
197
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
|
177
198
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
178
199
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
179
200
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
@@ -195,6 +216,7 @@ enum lm_ggml_metal_kernel_type {
|
|
195
216
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
196
217
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
197
218
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
219
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
|
198
220
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
199
221
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
200
222
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
@@ -216,6 +238,7 @@ enum lm_ggml_metal_kernel_type {
|
|
216
238
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
217
239
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
218
240
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
241
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
219
242
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
220
243
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
221
244
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
@@ -255,13 +278,64 @@ enum lm_ggml_metal_kernel_type {
|
|
255
278
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
256
279
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
257
280
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
258
|
-
|
281
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
282
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
283
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
284
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
285
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
|
286
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
|
287
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
288
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
289
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
290
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
291
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
292
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
293
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
294
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
295
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
296
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
297
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
298
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
299
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
300
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
301
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
302
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
303
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
304
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
305
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
306
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
307
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
308
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
309
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
310
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
311
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
312
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
313
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
314
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
315
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
316
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
317
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
259
318
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
260
|
-
|
319
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
320
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
321
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
322
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
323
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
324
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
325
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
326
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
|
327
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
328
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
329
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
330
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
331
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
261
332
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
262
333
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
334
|
+
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
|
263
335
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
264
336
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
337
|
+
LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
|
338
|
+
LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
|
265
339
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
266
340
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
267
341
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
@@ -440,7 +514,15 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
440
514
|
// dictionary of preprocessor macros
|
441
515
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
442
516
|
|
443
|
-
|
517
|
+
if (ctx_dev->use_bfloat) {
|
518
|
+
[prep setObject:@"1" forKey:@"LM_GGML_METAL_USE_BF16"];
|
519
|
+
}
|
520
|
+
|
521
|
+
#if LM_GGML_METAL_EMBED_LIBRARY
|
522
|
+
[prep setObject:@"1" forKey:@"LM_GGML_METAL_EMBED_LIBRARY"];
|
523
|
+
#endif
|
524
|
+
|
525
|
+
MTLCompileOptions * options = [MTLCompileOptions new];
|
444
526
|
options.preprocessorMacros = prep;
|
445
527
|
|
446
528
|
//[options setFastMathEnabled:false];
|
@@ -490,9 +572,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
490
572
|
}
|
491
573
|
}
|
492
574
|
|
493
|
-
LM_GGML_LOG_INFO("%s: simdgroup reduction
|
494
|
-
LM_GGML_LOG_INFO("%s: simdgroup matrix mul.
|
495
|
-
LM_GGML_LOG_INFO("%s:
|
575
|
+
LM_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
576
|
+
LM_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
577
|
+
LM_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
578
|
+
LM_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
|
579
|
+
LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
496
580
|
|
497
581
|
ctx->capture_next_compute = false;
|
498
582
|
ctx->capture_started = false;
|
@@ -518,16 +602,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
518
602
|
ctx->kernels[i].pipeline = nil;
|
519
603
|
}
|
520
604
|
|
521
|
-
/*
|
522
|
-
LM_GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
523
|
-
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
524
|
-
(int) kernel->pipeline.threadExecutionWidth); \
|
525
|
-
*/
|
526
605
|
#define LM_GGML_METAL_ADD_KERNEL(e, name, supported) \
|
527
606
|
if (supported) { \
|
528
607
|
struct lm_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
529
608
|
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
530
609
|
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
610
|
+
LM_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
611
|
+
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
612
|
+
(int) kernel->pipeline.threadExecutionWidth); \
|
531
613
|
[metal_function release]; \
|
532
614
|
if (error) { \
|
533
615
|
LM_GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
@@ -538,8 +620,9 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
538
620
|
LM_GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
539
621
|
}
|
540
622
|
|
541
|
-
const bool
|
542
|
-
const bool
|
623
|
+
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
624
|
+
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
625
|
+
const bool use_bfloat = ctx_dev->use_bfloat;
|
543
626
|
|
544
627
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
545
628
|
|
@@ -567,14 +650,16 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
567
650
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
568
651
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
569
652
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
570
|
-
LM_GGML_METAL_ADD_KERNEL(
|
571
|
-
LM_GGML_METAL_ADD_KERNEL(
|
572
|
-
LM_GGML_METAL_ADD_KERNEL(
|
573
|
-
LM_GGML_METAL_ADD_KERNEL(
|
653
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
654
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
655
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
656
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
657
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
|
574
658
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
575
659
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
576
660
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
577
661
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
662
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
|
578
663
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
579
664
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
580
665
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
@@ -595,101 +680,108 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
595
680
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
596
681
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
597
682
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
598
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm,
|
599
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm,
|
683
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
684
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
600
685
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
601
686
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
602
687
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
603
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32,
|
604
|
-
LM_GGML_METAL_ADD_KERNEL(
|
605
|
-
LM_GGML_METAL_ADD_KERNEL(
|
606
|
-
LM_GGML_METAL_ADD_KERNEL(
|
607
|
-
LM_GGML_METAL_ADD_KERNEL(
|
608
|
-
LM_GGML_METAL_ADD_KERNEL(
|
609
|
-
LM_GGML_METAL_ADD_KERNEL(
|
610
|
-
LM_GGML_METAL_ADD_KERNEL(
|
611
|
-
LM_GGML_METAL_ADD_KERNEL(
|
612
|
-
LM_GGML_METAL_ADD_KERNEL(
|
613
|
-
LM_GGML_METAL_ADD_KERNEL(
|
614
|
-
LM_GGML_METAL_ADD_KERNEL(
|
615
|
-
LM_GGML_METAL_ADD_KERNEL(
|
616
|
-
LM_GGML_METAL_ADD_KERNEL(
|
617
|
-
LM_GGML_METAL_ADD_KERNEL(
|
618
|
-
LM_GGML_METAL_ADD_KERNEL(
|
619
|
-
LM_GGML_METAL_ADD_KERNEL(
|
620
|
-
LM_GGML_METAL_ADD_KERNEL(
|
621
|
-
LM_GGML_METAL_ADD_KERNEL(
|
622
|
-
LM_GGML_METAL_ADD_KERNEL(
|
623
|
-
LM_GGML_METAL_ADD_KERNEL(
|
624
|
-
LM_GGML_METAL_ADD_KERNEL(
|
625
|
-
LM_GGML_METAL_ADD_KERNEL(
|
626
|
-
LM_GGML_METAL_ADD_KERNEL(
|
627
|
-
LM_GGML_METAL_ADD_KERNEL(
|
628
|
-
|
629
|
-
LM_GGML_METAL_ADD_KERNEL(
|
630
|
-
|
631
|
-
|
632
|
-
LM_GGML_METAL_ADD_KERNEL(
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
LM_GGML_METAL_ADD_KERNEL(
|
637
|
-
LM_GGML_METAL_ADD_KERNEL(
|
638
|
-
LM_GGML_METAL_ADD_KERNEL(
|
639
|
-
LM_GGML_METAL_ADD_KERNEL(
|
640
|
-
LM_GGML_METAL_ADD_KERNEL(
|
641
|
-
LM_GGML_METAL_ADD_KERNEL(
|
642
|
-
LM_GGML_METAL_ADD_KERNEL(
|
643
|
-
LM_GGML_METAL_ADD_KERNEL(
|
644
|
-
LM_GGML_METAL_ADD_KERNEL(
|
645
|
-
LM_GGML_METAL_ADD_KERNEL(
|
646
|
-
LM_GGML_METAL_ADD_KERNEL(
|
647
|
-
LM_GGML_METAL_ADD_KERNEL(
|
648
|
-
LM_GGML_METAL_ADD_KERNEL(
|
649
|
-
LM_GGML_METAL_ADD_KERNEL(
|
650
|
-
LM_GGML_METAL_ADD_KERNEL(
|
651
|
-
LM_GGML_METAL_ADD_KERNEL(
|
652
|
-
LM_GGML_METAL_ADD_KERNEL(
|
653
|
-
LM_GGML_METAL_ADD_KERNEL(
|
654
|
-
LM_GGML_METAL_ADD_KERNEL(
|
655
|
-
LM_GGML_METAL_ADD_KERNEL(
|
656
|
-
LM_GGML_METAL_ADD_KERNEL(
|
657
|
-
LM_GGML_METAL_ADD_KERNEL(
|
658
|
-
LM_GGML_METAL_ADD_KERNEL(
|
659
|
-
LM_GGML_METAL_ADD_KERNEL(
|
660
|
-
LM_GGML_METAL_ADD_KERNEL(
|
661
|
-
LM_GGML_METAL_ADD_KERNEL(
|
662
|
-
LM_GGML_METAL_ADD_KERNEL(
|
663
|
-
LM_GGML_METAL_ADD_KERNEL(
|
664
|
-
LM_GGML_METAL_ADD_KERNEL(
|
665
|
-
LM_GGML_METAL_ADD_KERNEL(
|
666
|
-
LM_GGML_METAL_ADD_KERNEL(
|
667
|
-
LM_GGML_METAL_ADD_KERNEL(
|
668
|
-
LM_GGML_METAL_ADD_KERNEL(
|
669
|
-
LM_GGML_METAL_ADD_KERNEL(
|
670
|
-
LM_GGML_METAL_ADD_KERNEL(
|
671
|
-
LM_GGML_METAL_ADD_KERNEL(
|
672
|
-
LM_GGML_METAL_ADD_KERNEL(
|
673
|
-
LM_GGML_METAL_ADD_KERNEL(
|
674
|
-
LM_GGML_METAL_ADD_KERNEL(
|
675
|
-
LM_GGML_METAL_ADD_KERNEL(
|
676
|
-
LM_GGML_METAL_ADD_KERNEL(
|
677
|
-
LM_GGML_METAL_ADD_KERNEL(
|
678
|
-
LM_GGML_METAL_ADD_KERNEL(
|
679
|
-
LM_GGML_METAL_ADD_KERNEL(
|
680
|
-
LM_GGML_METAL_ADD_KERNEL(
|
681
|
-
LM_GGML_METAL_ADD_KERNEL(
|
682
|
-
LM_GGML_METAL_ADD_KERNEL(
|
683
|
-
LM_GGML_METAL_ADD_KERNEL(
|
684
|
-
LM_GGML_METAL_ADD_KERNEL(
|
685
|
-
LM_GGML_METAL_ADD_KERNEL(
|
686
|
-
LM_GGML_METAL_ADD_KERNEL(
|
687
|
-
LM_GGML_METAL_ADD_KERNEL(
|
688
|
-
LM_GGML_METAL_ADD_KERNEL(
|
689
|
-
LM_GGML_METAL_ADD_KERNEL(
|
690
|
-
LM_GGML_METAL_ADD_KERNEL(
|
691
|
-
LM_GGML_METAL_ADD_KERNEL(
|
692
|
-
LM_GGML_METAL_ADD_KERNEL(
|
688
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
689
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
690
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
691
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
692
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
693
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
694
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
695
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
696
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
697
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
|
698
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
|
699
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
700
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
701
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
702
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
703
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
704
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
705
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
|
706
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
|
707
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
|
708
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
|
709
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
|
710
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
|
711
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
|
712
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
|
713
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
|
714
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
|
715
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
|
716
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
|
717
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
|
718
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
|
719
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
|
720
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
|
721
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
722
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
|
723
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
|
724
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
725
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
726
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
727
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
728
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
729
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
730
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
|
731
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
|
732
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
|
733
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
|
734
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
|
735
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
|
736
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
|
737
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
|
738
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
|
739
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
|
740
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
741
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
742
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
743
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
|
744
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
|
745
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
|
746
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
747
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
748
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
749
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
750
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
751
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
752
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
|
753
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
|
754
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
|
755
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
|
756
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
|
757
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
|
758
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
|
759
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
|
760
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
761
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
762
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
763
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
764
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
765
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
|
766
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
767
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
768
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
769
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
|
770
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
|
771
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
|
772
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
|
773
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
|
774
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
|
775
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
|
776
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
|
777
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
|
778
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
|
779
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
|
780
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
|
781
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
|
782
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
|
783
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
|
784
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
|
693
785
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
694
786
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
695
787
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
@@ -705,18 +797,69 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
705
797
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
706
798
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
707
799
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
708
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
|
709
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
|
710
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
|
711
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
|
712
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
|
713
|
-
|
714
|
-
LM_GGML_METAL_ADD_KERNEL(
|
715
|
-
|
716
|
-
LM_GGML_METAL_ADD_KERNEL(
|
800
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
801
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
802
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
803
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
804
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
805
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
806
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
807
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
808
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
809
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
|
810
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
|
811
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
812
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
813
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
814
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
815
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
|
816
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
|
817
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
818
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
819
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
820
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
821
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
|
822
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
|
823
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
824
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
825
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
826
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
827
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
|
828
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
|
829
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
830
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
831
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
832
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
833
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
|
834
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
|
835
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
836
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
837
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
838
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
839
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
|
840
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
841
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
842
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
843
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
844
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
845
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
846
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
847
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
848
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
849
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
850
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
|
851
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
852
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
853
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
854
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
855
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
717
856
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
718
|
-
LM_GGML_METAL_ADD_KERNEL(
|
857
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
858
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
719
859
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
860
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
861
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
862
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
720
863
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
721
864
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
722
865
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
@@ -806,15 +949,18 @@ static id<MTLBuffer> lm_ggml_metal_get_buffer(struct lm_ggml_tensor * t, size_t
|
|
806
949
|
}
|
807
950
|
|
808
951
|
static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_context * ctx_dev, const struct lm_ggml_tensor * op) {
|
809
|
-
|
810
|
-
|
811
|
-
|
952
|
+
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
953
|
+
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
954
|
+
const bool use_bfloat = ctx_dev->use_bfloat;
|
955
|
+
|
956
|
+
if (!use_bfloat) {
|
957
|
+
for (size_t i = 0, n = 3; i < n; ++i) {
|
958
|
+
if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) {
|
959
|
+
return false;
|
960
|
+
}
|
812
961
|
}
|
813
962
|
}
|
814
963
|
|
815
|
-
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
816
|
-
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
817
|
-
|
818
964
|
switch (op->op) {
|
819
965
|
case LM_GGML_OP_UNARY:
|
820
966
|
switch (lm_ggml_get_unary_op(op)) {
|
@@ -824,6 +970,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
824
970
|
case LM_GGML_UNARY_OP_GELU:
|
825
971
|
case LM_GGML_UNARY_OP_GELU_QUICK:
|
826
972
|
case LM_GGML_UNARY_OP_SILU:
|
973
|
+
case LM_GGML_UNARY_OP_ELU:
|
827
974
|
return lm_ggml_is_contiguous(op->src[0]);
|
828
975
|
default:
|
829
976
|
return false;
|
@@ -850,9 +997,10 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
850
997
|
return lm_ggml_is_contiguous(op->src[0]);
|
851
998
|
case LM_GGML_OP_SUM_ROWS:
|
852
999
|
case LM_GGML_OP_SOFT_MAX:
|
853
|
-
case LM_GGML_OP_RMS_NORM:
|
854
1000
|
case LM_GGML_OP_GROUP_NORM:
|
855
|
-
return
|
1001
|
+
return has_simdgroup_reduction;
|
1002
|
+
case LM_GGML_OP_RMS_NORM:
|
1003
|
+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
|
856
1004
|
case LM_GGML_OP_NORM:
|
857
1005
|
case LM_GGML_OP_ROPE:
|
858
1006
|
return true;
|
@@ -869,22 +1017,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
869
1017
|
case LM_GGML_OP_LEAKY_RELU:
|
870
1018
|
return true;
|
871
1019
|
case LM_GGML_OP_FLASH_ATTN_EXT:
|
872
|
-
if (op->src[1]->type !=
|
873
|
-
return false;
|
874
|
-
}
|
875
|
-
if (op->src[2]->type != LM_GGML_TYPE_F16) {
|
876
|
-
return false;
|
877
|
-
}
|
878
|
-
if (op->src[0]->ne[0] == 256) {
|
1020
|
+
if (op->src[1]->type != op->src[2]->type) {
|
879
1021
|
return false;
|
880
1022
|
}
|
881
|
-
return
|
1023
|
+
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
882
1024
|
case LM_GGML_OP_SSM_CONV:
|
883
1025
|
case LM_GGML_OP_SSM_SCAN:
|
884
1026
|
return true;
|
885
1027
|
case LM_GGML_OP_MUL_MAT:
|
886
1028
|
case LM_GGML_OP_MUL_MAT_ID:
|
887
|
-
return
|
1029
|
+
return has_simdgroup_reduction &&
|
888
1030
|
(op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type == LM_GGML_TYPE_F32);
|
889
1031
|
case LM_GGML_OP_CPY:
|
890
1032
|
case LM_GGML_OP_DUP:
|
@@ -895,6 +1037,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
895
1037
|
switch (op->type) {
|
896
1038
|
case LM_GGML_TYPE_F32:
|
897
1039
|
case LM_GGML_TYPE_F16:
|
1040
|
+
case LM_GGML_TYPE_BF16:
|
898
1041
|
case LM_GGML_TYPE_Q8_0:
|
899
1042
|
case LM_GGML_TYPE_Q4_0:
|
900
1043
|
case LM_GGML_TYPE_Q4_1:
|
@@ -907,10 +1050,18 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
907
1050
|
}
|
908
1051
|
case LM_GGML_TYPE_F16:
|
909
1052
|
switch (op->type) {
|
910
|
-
|
911
|
-
|
1053
|
+
case LM_GGML_TYPE_F32:
|
1054
|
+
case LM_GGML_TYPE_F16:
|
912
1055
|
return true;
|
913
|
-
|
1056
|
+
default:
|
1057
|
+
return false;
|
1058
|
+
}
|
1059
|
+
case LM_GGML_TYPE_BF16:
|
1060
|
+
switch (op->type) {
|
1061
|
+
case LM_GGML_TYPE_F32:
|
1062
|
+
case LM_GGML_TYPE_BF16:
|
1063
|
+
return true;
|
1064
|
+
default:
|
914
1065
|
return false;
|
915
1066
|
}
|
916
1067
|
default:
|
@@ -996,7 +1147,7 @@ static void lm_ggml_metal_encode_node(
|
|
996
1147
|
const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20);
|
997
1148
|
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
998
1149
|
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
999
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
1150
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; LM_GGML_UNUSED(nb23);
|
1000
1151
|
|
1001
1152
|
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
1002
1153
|
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
@@ -1047,35 +1198,39 @@ static void lm_ggml_metal_encode_node(
|
|
1047
1198
|
|
1048
1199
|
const int32_t dim = ((const int32_t *) dst->op_params)[0];
|
1049
1200
|
|
1201
|
+
lm_ggml_metal_kargs_concat args = {
|
1202
|
+
/*.ne00 =*/ ne00,
|
1203
|
+
/*.ne01 =*/ ne01,
|
1204
|
+
/*.ne02 =*/ ne02,
|
1205
|
+
/*.ne03 =*/ ne03,
|
1206
|
+
/*.nb00 =*/ nb00,
|
1207
|
+
/*.nb01 =*/ nb01,
|
1208
|
+
/*.nb02 =*/ nb02,
|
1209
|
+
/*.nb03 =*/ nb03,
|
1210
|
+
/*.ne10 =*/ ne10,
|
1211
|
+
/*.ne11 =*/ ne11,
|
1212
|
+
/*.ne12 =*/ ne12,
|
1213
|
+
/*.ne13 =*/ ne13,
|
1214
|
+
/*.nb10 =*/ nb10,
|
1215
|
+
/*.nb11 =*/ nb11,
|
1216
|
+
/*.nb12 =*/ nb12,
|
1217
|
+
/*.nb13 =*/ nb13,
|
1218
|
+
/*.ne0 =*/ ne0,
|
1219
|
+
/*.ne1 =*/ ne1,
|
1220
|
+
/*.ne2 =*/ ne2,
|
1221
|
+
/*.ne3 =*/ ne3,
|
1222
|
+
/*.nb0 =*/ nb0,
|
1223
|
+
/*.nb1 =*/ nb1,
|
1224
|
+
/*.nb2 =*/ nb2,
|
1225
|
+
/*.nb3 =*/ nb3,
|
1226
|
+
/*.dim =*/ dim,
|
1227
|
+
};
|
1228
|
+
|
1050
1229
|
[encoder setComputePipelineState:pipeline];
|
1051
|
-
[encoder
|
1052
|
-
[encoder setBuffer:
|
1053
|
-
[encoder setBuffer:
|
1054
|
-
[encoder
|
1055
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1056
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1057
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1058
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1059
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
1060
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
1061
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
1062
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1063
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1064
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1065
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1066
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1067
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1068
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1069
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1070
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1071
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1072
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1073
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1074
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1075
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1076
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1077
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1078
|
-
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
1230
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
1231
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1232
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
1233
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
1079
1234
|
|
1080
1235
|
const int nth = MIN(1024, ne0);
|
1081
1236
|
|
@@ -1093,8 +1248,6 @@ static void lm_ggml_metal_encode_node(
|
|
1093
1248
|
|
1094
1249
|
bool bcast_row = false;
|
1095
1250
|
|
1096
|
-
int64_t nb = ne00; // used by the "row" kernels
|
1097
|
-
|
1098
1251
|
id<MTLComputePipelineState> pipeline = nil;
|
1099
1252
|
|
1100
1253
|
if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
@@ -1103,7 +1256,6 @@ static void lm_ggml_metal_encode_node(
|
|
1103
1256
|
// src1 is a row
|
1104
1257
|
LM_GGML_ASSERT(ne11 == 1);
|
1105
1258
|
|
1106
|
-
nb = ne00 / 4;
|
1107
1259
|
switch (dst->op) {
|
1108
1260
|
case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
1109
1261
|
case LM_GGML_OP_SUB: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
@@ -1123,36 +1275,39 @@ static void lm_ggml_metal_encode_node(
|
|
1123
1275
|
}
|
1124
1276
|
}
|
1125
1277
|
|
1278
|
+
lm_ggml_metal_kargs_bin args = {
|
1279
|
+
/*.ne00 =*/ ne00,
|
1280
|
+
/*.ne01 =*/ ne01,
|
1281
|
+
/*.ne02 =*/ ne02,
|
1282
|
+
/*.ne03 =*/ ne03,
|
1283
|
+
/*.nb00 =*/ nb00,
|
1284
|
+
/*.nb01 =*/ nb01,
|
1285
|
+
/*.nb02 =*/ nb02,
|
1286
|
+
/*.nb03 =*/ nb03,
|
1287
|
+
/*.ne10 =*/ ne10,
|
1288
|
+
/*.ne11 =*/ ne11,
|
1289
|
+
/*.ne12 =*/ ne12,
|
1290
|
+
/*.ne13 =*/ ne13,
|
1291
|
+
/*.nb10 =*/ nb10,
|
1292
|
+
/*.nb11 =*/ nb11,
|
1293
|
+
/*.nb12 =*/ nb12,
|
1294
|
+
/*.nb13 =*/ nb13,
|
1295
|
+
/*.ne0 =*/ ne0,
|
1296
|
+
/*.ne1 =*/ ne1,
|
1297
|
+
/*.ne2 =*/ ne2,
|
1298
|
+
/*.ne3 =*/ ne3,
|
1299
|
+
/*.nb0 =*/ nb0,
|
1300
|
+
/*.nb1 =*/ nb1,
|
1301
|
+
/*.nb2 =*/ nb2,
|
1302
|
+
/*.nb3 =*/ nb3,
|
1303
|
+
/*.offs =*/ offs,
|
1304
|
+
};
|
1305
|
+
|
1126
1306
|
[encoder setComputePipelineState:pipeline];
|
1127
|
-
[encoder
|
1128
|
-
[encoder setBuffer:
|
1129
|
-
[encoder setBuffer:
|
1130
|
-
[encoder
|
1131
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1132
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1133
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1134
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1135
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
1136
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
1137
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
1138
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1139
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1140
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1141
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1142
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1143
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1144
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1145
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1146
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1147
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1148
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1149
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1150
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1151
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1152
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1153
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1154
|
-
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1155
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
1307
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
1308
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1309
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
1310
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
1156
1311
|
|
1157
1312
|
if (bcast_row) {
|
1158
1313
|
const int64_t n = lm_ggml_nelements(dst)/4;
|
@@ -1176,25 +1331,29 @@ static void lm_ggml_metal_encode_node(
|
|
1176
1331
|
default: LM_GGML_ABORT("fatal error");
|
1177
1332
|
}
|
1178
1333
|
|
1334
|
+
lm_ggml_metal_kargs_repeat args = {
|
1335
|
+
/*.ne00 =*/ ne00,
|
1336
|
+
/*.ne01 =*/ ne01,
|
1337
|
+
/*.ne02 =*/ ne02,
|
1338
|
+
/*.ne03 =*/ ne03,
|
1339
|
+
/*.nb00 =*/ nb00,
|
1340
|
+
/*.nb01 =*/ nb01,
|
1341
|
+
/*.nb02 =*/ nb02,
|
1342
|
+
/*.nb03 =*/ nb03,
|
1343
|
+
/*.ne0 =*/ ne0,
|
1344
|
+
/*.ne1 =*/ ne1,
|
1345
|
+
/*.ne2 =*/ ne2,
|
1346
|
+
/*.ne3 =*/ ne3,
|
1347
|
+
/*.nb0 =*/ nb0,
|
1348
|
+
/*.nb1 =*/ nb1,
|
1349
|
+
/*.nb2 =*/ nb2,
|
1350
|
+
/*.nb3 =*/ nb3,
|
1351
|
+
};
|
1352
|
+
|
1179
1353
|
[encoder setComputePipelineState:pipeline];
|
1180
|
-
[encoder
|
1181
|
-
[encoder setBuffer:
|
1182
|
-
[encoder
|
1183
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1184
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1185
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1186
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1187
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1188
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1189
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1190
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
1191
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
1192
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
1193
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
1194
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
1195
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1196
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
1197
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
1354
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
1355
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1356
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1198
1357
|
|
1199
1358
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1200
1359
|
|
@@ -1223,25 +1382,29 @@ static void lm_ggml_metal_encode_node(
|
|
1223
1382
|
|
1224
1383
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
1225
1384
|
|
1385
|
+
lm_ggml_metal_kargs_cpy args = {
|
1386
|
+
/*.ne00 =*/ ne00,
|
1387
|
+
/*.ne01 =*/ ne01,
|
1388
|
+
/*.ne02 =*/ ne02,
|
1389
|
+
/*.ne03 =*/ ne03,
|
1390
|
+
/*.nb00 =*/ nb00,
|
1391
|
+
/*.nb01 =*/ nb01,
|
1392
|
+
/*.nb02 =*/ nb02,
|
1393
|
+
/*.nb03 =*/ nb03,
|
1394
|
+
/*.ne0 =*/ ne0,
|
1395
|
+
/*.ne1 =*/ ne1,
|
1396
|
+
/*.ne2 =*/ ne2,
|
1397
|
+
/*.ne3 =*/ ne3,
|
1398
|
+
/*.nb0 =*/ nb0,
|
1399
|
+
/*.nb1 =*/ nb1,
|
1400
|
+
/*.nb2 =*/ nb2,
|
1401
|
+
/*.nb3 =*/ nb3,
|
1402
|
+
};
|
1403
|
+
|
1226
1404
|
[encoder setComputePipelineState:pipeline];
|
1227
|
-
[encoder
|
1228
|
-
[encoder setBuffer:
|
1229
|
-
[encoder
|
1230
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1231
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1232
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1233
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1234
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1235
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1236
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1237
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1238
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1239
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1240
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1241
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1242
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1243
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1244
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1405
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
1406
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1407
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1245
1408
|
|
1246
1409
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
1247
1410
|
|
@@ -1250,35 +1413,39 @@ static void lm_ggml_metal_encode_node(
|
|
1250
1413
|
|
1251
1414
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
1252
1415
|
|
1416
|
+
lm_ggml_metal_kargs_bin args = {
|
1417
|
+
/*.ne00 =*/ ne00,
|
1418
|
+
/*.ne01 =*/ ne01,
|
1419
|
+
/*.ne02 =*/ ne02,
|
1420
|
+
/*.ne03 =*/ ne03,
|
1421
|
+
/*.nb00 =*/ nb00,
|
1422
|
+
/*.nb01 =*/ pnb1,
|
1423
|
+
/*.nb02 =*/ pnb2,
|
1424
|
+
/*.nb03 =*/ pnb3,
|
1425
|
+
/*.ne10 =*/ ne10,
|
1426
|
+
/*.ne11 =*/ ne11,
|
1427
|
+
/*.ne12 =*/ ne12,
|
1428
|
+
/*.ne13 =*/ ne13,
|
1429
|
+
/*.nb10 =*/ nb10,
|
1430
|
+
/*.nb11 =*/ nb11,
|
1431
|
+
/*.nb12 =*/ nb12,
|
1432
|
+
/*.nb13 =*/ nb13,
|
1433
|
+
/*.ne0 =*/ ne0,
|
1434
|
+
/*.ne1 =*/ ne1,
|
1435
|
+
/*.ne2 =*/ ne2,
|
1436
|
+
/*.ne3 =*/ ne3,
|
1437
|
+
/*.nb0 =*/ nb0,
|
1438
|
+
/*.nb1 =*/ pnb1,
|
1439
|
+
/*.nb2 =*/ pnb2,
|
1440
|
+
/*.nb3 =*/ pnb3,
|
1441
|
+
/*.offs =*/ offs,
|
1442
|
+
};
|
1443
|
+
|
1253
1444
|
[encoder setComputePipelineState:pipeline];
|
1254
|
-
[encoder
|
1255
|
-
[encoder setBuffer:
|
1256
|
-
[encoder setBuffer:
|
1257
|
-
[encoder
|
1258
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1259
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1260
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1261
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1262
|
-
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
1263
|
-
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
1264
|
-
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
1265
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1266
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1267
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1268
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1269
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1270
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1271
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1272
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1273
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1274
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1275
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1276
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1277
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1278
|
-
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
1279
|
-
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
1280
|
-
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
1281
|
-
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1445
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
1446
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1447
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
1448
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
1282
1449
|
|
1283
1450
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
1284
1451
|
|
@@ -1319,10 +1486,10 @@ static void lm_ggml_metal_encode_node(
|
|
1319
1486
|
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
1320
1487
|
|
1321
1488
|
[encoder setComputePipelineState:pipeline];
|
1322
|
-
[encoder setBuffer:id_src0
|
1323
|
-
[encoder setBuffer:id_dst
|
1324
|
-
[encoder setBytes:&min
|
1325
|
-
[encoder setBytes:&max
|
1489
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1490
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1491
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1492
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1326
1493
|
|
1327
1494
|
const int64_t n = lm_ggml_nelements(dst);
|
1328
1495
|
|
@@ -1426,6 +1593,18 @@ static void lm_ggml_metal_encode_node(
|
|
1426
1593
|
|
1427
1594
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1428
1595
|
} break;
|
1596
|
+
case LM_GGML_UNARY_OP_ELU:
|
1597
|
+
{
|
1598
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ELU].pipeline;
|
1599
|
+
|
1600
|
+
[encoder setComputePipelineState:pipeline];
|
1601
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1602
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1603
|
+
|
1604
|
+
const int64_t n = lm_ggml_nelements(dst);
|
1605
|
+
|
1606
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1607
|
+
} break;
|
1429
1608
|
default:
|
1430
1609
|
{
|
1431
1610
|
LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
|
@@ -1494,6 +1673,7 @@ static void lm_ggml_metal_encode_node(
|
|
1494
1673
|
|
1495
1674
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
1496
1675
|
|
1676
|
+
// TODO: add lm_ggml_metal_kargs struct
|
1497
1677
|
[encoder setComputePipelineState:pipeline];
|
1498
1678
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1499
1679
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -1569,6 +1749,8 @@ static void lm_ggml_metal_encode_node(
|
|
1569
1749
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
1570
1750
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
1571
1751
|
|
1752
|
+
// TODO: add lm_ggml_metal_kargs struct
|
1753
|
+
// TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
|
1572
1754
|
[encoder setComputePipelineState:pipeline];
|
1573
1755
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1574
1756
|
if (id_src1) {
|
@@ -1585,6 +1767,7 @@ static void lm_ggml_metal_encode_node(
|
|
1585
1767
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
1586
1768
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
1587
1769
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
1770
|
+
|
1588
1771
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1589
1772
|
|
1590
1773
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -1601,6 +1784,7 @@ static void lm_ggml_metal_encode_node(
|
|
1601
1784
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
1602
1785
|
}
|
1603
1786
|
|
1787
|
+
// TODO: add lm_ggml_metal_kargs struct
|
1604
1788
|
[encoder setComputePipelineState:pipeline];
|
1605
1789
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1606
1790
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -1625,6 +1809,7 @@ static void lm_ggml_metal_encode_node(
|
|
1625
1809
|
|
1626
1810
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
1627
1811
|
|
1812
|
+
// TODO: add lm_ggml_metal_kargs struct
|
1628
1813
|
[encoder setComputePipelineState:pipeline];
|
1629
1814
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1630
1815
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -1695,6 +1880,7 @@ static void lm_ggml_metal_encode_node(
|
|
1695
1880
|
|
1696
1881
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
1697
1882
|
|
1883
|
+
// TODO: add lm_ggml_metal_kargs struct
|
1698
1884
|
[encoder setComputePipelineState:pipeline];
|
1699
1885
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1700
1886
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -1742,7 +1928,7 @@ static void lm_ggml_metal_encode_node(
|
|
1742
1928
|
|
1743
1929
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1744
1930
|
// to the matrix-vector kernel
|
1745
|
-
int ne11_mm_min =
|
1931
|
+
int ne11_mm_min = 4;
|
1746
1932
|
|
1747
1933
|
#if 0
|
1748
1934
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
@@ -1766,286 +1952,316 @@ static void lm_ggml_metal_encode_node(
|
|
1766
1952
|
}
|
1767
1953
|
#endif
|
1768
1954
|
|
1769
|
-
|
1770
|
-
|
1771
|
-
|
1772
|
-
|
1773
|
-
|
1774
|
-
|
1775
|
-
|
1776
|
-
|
1777
|
-
|
1778
|
-
|
1779
|
-
// some Metal matrix data types require aligned pointers
|
1780
|
-
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
1781
|
-
switch (src0->type) {
|
1782
|
-
case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
|
1783
|
-
case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
|
1784
|
-
default: break;
|
1785
|
-
}
|
1955
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1956
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1957
|
+
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
1958
|
+
!lm_ggml_is_transposed(src0) &&
|
1959
|
+
!lm_ggml_is_transposed(src1) &&
|
1960
|
+
src1t == LM_GGML_TYPE_F32 &&
|
1961
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
1962
|
+
(ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) {
|
1963
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1786
1964
|
|
1787
|
-
|
1788
|
-
|
1789
|
-
|
1790
|
-
|
1791
|
-
|
1792
|
-
|
1793
|
-
|
1794
|
-
|
1795
|
-
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
1796
|
-
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
1797
|
-
case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
1798
|
-
case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
1799
|
-
case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
1800
|
-
case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
1801
|
-
case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
1802
|
-
case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1803
|
-
case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1804
|
-
case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1805
|
-
case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
1806
|
-
case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
1807
|
-
case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
1808
|
-
case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
1809
|
-
case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
1810
|
-
case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
1811
|
-
default: LM_GGML_ABORT("MUL MAT-MAT not implemented");
|
1812
|
-
}
|
1965
|
+
// some Metal matrix data types require aligned pointers
|
1966
|
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
1967
|
+
switch (src0->type) {
|
1968
|
+
case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
|
1969
|
+
case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
|
1970
|
+
case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
|
1971
|
+
default: break;
|
1972
|
+
}
|
1813
1973
|
|
1814
|
-
|
1815
|
-
|
1816
|
-
|
1817
|
-
|
1818
|
-
|
1819
|
-
|
1820
|
-
|
1821
|
-
|
1822
|
-
|
1823
|
-
|
1824
|
-
|
1825
|
-
|
1826
|
-
|
1827
|
-
|
1828
|
-
|
1829
|
-
|
1830
|
-
|
1831
|
-
|
1832
|
-
|
1833
|
-
|
1834
|
-
|
1835
|
-
|
1836
|
-
|
1837
|
-
|
1838
|
-
|
1839
|
-
|
1840
|
-
|
1841
|
-
|
1842
|
-
|
1843
|
-
|
1844
|
-
|
1845
|
-
|
1846
|
-
|
1847
|
-
|
1974
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1975
|
+
|
1976
|
+
switch (src0->type) {
|
1977
|
+
case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
1978
|
+
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
1979
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
1980
|
+
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
1981
|
+
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
1982
|
+
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
1983
|
+
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
1984
|
+
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
1985
|
+
case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
1986
|
+
case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
1987
|
+
case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
1988
|
+
case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
1989
|
+
case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
1990
|
+
case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1991
|
+
case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1992
|
+
case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1993
|
+
case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
1994
|
+
case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
1995
|
+
case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
1996
|
+
case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
1997
|
+
case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
1998
|
+
case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
1999
|
+
default: LM_GGML_ABORT("MUL MAT-MAT not implemented");
|
2000
|
+
}
|
2001
|
+
|
2002
|
+
lm_ggml_metal_kargs_mul_mm args = {
|
2003
|
+
/*.ne00 =*/ ne00,
|
2004
|
+
/*.ne02 =*/ ne02,
|
2005
|
+
/*.nb01 =*/ nb01,
|
2006
|
+
/*.nb02 =*/ nb02,
|
2007
|
+
/*.nb03 =*/ nb03,
|
2008
|
+
/*.ne12 =*/ ne12,
|
2009
|
+
/*.nb10 =*/ nb10,
|
2010
|
+
/*.nb11 =*/ nb11,
|
2011
|
+
/*.nb12 =*/ nb12,
|
2012
|
+
/*.nb13 =*/ nb13,
|
2013
|
+
/*.ne0 =*/ ne0,
|
2014
|
+
/*.ne1 =*/ ne1,
|
2015
|
+
/*.r2 =*/ r2,
|
2016
|
+
/*.r3 =*/ r3,
|
2017
|
+
};
|
2018
|
+
|
2019
|
+
[encoder setComputePipelineState:pipeline];
|
2020
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2021
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2022
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2023
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2024
|
+
|
2025
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
2026
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
2027
|
+
} else {
|
2028
|
+
int nth0 = 32;
|
2029
|
+
int nth1 = 1;
|
2030
|
+
int nrows = 1;
|
2031
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2032
|
+
|
2033
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2034
|
+
|
2035
|
+
// use custom matrix x vector kernel
|
2036
|
+
switch (src0t) {
|
2037
|
+
case LM_GGML_TYPE_F32:
|
2038
|
+
{
|
2039
|
+
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2040
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
2041
|
+
nrows = 4;
|
2042
|
+
} break;
|
2043
|
+
case LM_GGML_TYPE_F16:
|
2044
|
+
{
|
2045
|
+
nth0 = 32;
|
2046
|
+
nth1 = 1;
|
2047
|
+
if (src1t == LM_GGML_TYPE_F32) {
|
2048
|
+
if (ne11 * ne12 < 4) {
|
2049
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
2050
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2051
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
2052
|
+
nrows = ne11;
|
2053
|
+
} else {
|
2054
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
1848
2055
|
nrows = 4;
|
1849
|
-
} break;
|
1850
|
-
case LM_GGML_TYPE_F16:
|
1851
|
-
{
|
1852
|
-
nth0 = 32;
|
1853
|
-
nth1 = 1;
|
1854
|
-
if (src1t == LM_GGML_TYPE_F32) {
|
1855
|
-
if (ne11 * ne12 < 4) {
|
1856
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
1857
|
-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
1858
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
1859
|
-
nrows = ne11;
|
1860
|
-
} else {
|
1861
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
1862
|
-
nrows = 4;
|
1863
|
-
}
|
1864
|
-
} else {
|
1865
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
1866
|
-
nrows = 4;
|
1867
|
-
}
|
1868
|
-
} break;
|
1869
|
-
case LM_GGML_TYPE_Q4_0:
|
1870
|
-
{
|
1871
|
-
nth0 = 8;
|
1872
|
-
nth1 = 8;
|
1873
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
1874
|
-
} break;
|
1875
|
-
case LM_GGML_TYPE_Q4_1:
|
1876
|
-
{
|
1877
|
-
nth0 = 8;
|
1878
|
-
nth1 = 8;
|
1879
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
1880
|
-
} break;
|
1881
|
-
case LM_GGML_TYPE_Q5_0:
|
1882
|
-
{
|
1883
|
-
nth0 = 8;
|
1884
|
-
nth1 = 8;
|
1885
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
1886
|
-
} break;
|
1887
|
-
case LM_GGML_TYPE_Q5_1:
|
1888
|
-
{
|
1889
|
-
nth0 = 8;
|
1890
|
-
nth1 = 8;
|
1891
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
1892
|
-
} break;
|
1893
|
-
case LM_GGML_TYPE_Q8_0:
|
1894
|
-
{
|
1895
|
-
nth0 = 8;
|
1896
|
-
nth1 = 8;
|
1897
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
1898
|
-
} break;
|
1899
|
-
case LM_GGML_TYPE_Q2_K:
|
1900
|
-
{
|
1901
|
-
nth0 = 2;
|
1902
|
-
nth1 = 32;
|
1903
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
1904
|
-
} break;
|
1905
|
-
case LM_GGML_TYPE_Q3_K:
|
1906
|
-
{
|
1907
|
-
nth0 = 2;
|
1908
|
-
nth1 = 32;
|
1909
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
1910
|
-
} break;
|
1911
|
-
case LM_GGML_TYPE_Q4_K:
|
1912
|
-
{
|
1913
|
-
nth0 = 4; //1;
|
1914
|
-
nth1 = 8; //32;
|
1915
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
1916
|
-
} break;
|
1917
|
-
case LM_GGML_TYPE_Q5_K:
|
1918
|
-
{
|
1919
|
-
nth0 = 2;
|
1920
|
-
nth1 = 32;
|
1921
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
1922
|
-
} break;
|
1923
|
-
case LM_GGML_TYPE_Q6_K:
|
1924
|
-
{
|
1925
|
-
nth0 = 2;
|
1926
|
-
nth1 = 32;
|
1927
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
1928
|
-
} break;
|
1929
|
-
case LM_GGML_TYPE_IQ2_XXS:
|
1930
|
-
{
|
1931
|
-
nth0 = 4;
|
1932
|
-
nth1 = 16;
|
1933
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
1934
|
-
} break;
|
1935
|
-
case LM_GGML_TYPE_IQ2_XS:
|
1936
|
-
{
|
1937
|
-
nth0 = 4;
|
1938
|
-
nth1 = 16;
|
1939
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
1940
|
-
} break;
|
1941
|
-
case LM_GGML_TYPE_IQ3_XXS:
|
1942
|
-
{
|
1943
|
-
nth0 = 4;
|
1944
|
-
nth1 = 16;
|
1945
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
1946
|
-
} break;
|
1947
|
-
case LM_GGML_TYPE_IQ3_S:
|
1948
|
-
{
|
1949
|
-
nth0 = 4;
|
1950
|
-
nth1 = 16;
|
1951
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
1952
|
-
} break;
|
1953
|
-
case LM_GGML_TYPE_IQ2_S:
|
1954
|
-
{
|
1955
|
-
nth0 = 4;
|
1956
|
-
nth1 = 16;
|
1957
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
1958
|
-
} break;
|
1959
|
-
case LM_GGML_TYPE_IQ1_S:
|
1960
|
-
{
|
1961
|
-
nth0 = 4;
|
1962
|
-
nth1 = 16;
|
1963
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
1964
|
-
} break;
|
1965
|
-
case LM_GGML_TYPE_IQ1_M:
|
1966
|
-
{
|
1967
|
-
nth0 = 4;
|
1968
|
-
nth1 = 16;
|
1969
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
1970
|
-
} break;
|
1971
|
-
case LM_GGML_TYPE_IQ4_NL:
|
1972
|
-
{
|
1973
|
-
nth0 = 4;
|
1974
|
-
nth1 = 16;
|
1975
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
1976
|
-
} break;
|
1977
|
-
case LM_GGML_TYPE_IQ4_XS:
|
1978
|
-
{
|
1979
|
-
nth0 = 4;
|
1980
|
-
nth1 = 16;
|
1981
|
-
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
1982
|
-
} break;
|
1983
|
-
default:
|
1984
|
-
{
|
1985
|
-
LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1986
|
-
LM_GGML_ABORT("not implemented");
|
1987
2056
|
}
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
|
1995
|
-
|
1996
|
-
|
1997
|
-
|
1998
|
-
|
1999
|
-
|
2000
|
-
|
2001
|
-
|
2002
|
-
|
2003
|
-
|
2004
|
-
|
2005
|
-
|
2006
|
-
|
2007
|
-
|
2008
|
-
|
2009
|
-
|
2010
|
-
|
2011
|
-
|
2012
|
-
|
2013
|
-
|
2014
|
-
|
2015
|
-
|
2016
|
-
|
2017
|
-
}
|
2018
|
-
|
2019
|
-
|
2020
|
-
|
2021
|
-
|
2022
|
-
|
2023
|
-
|
2024
|
-
|
2025
|
-
|
2026
|
-
|
2027
|
-
|
2028
|
-
|
2029
|
-
|
2030
|
-
|
2031
|
-
|
2032
|
-
|
2033
|
-
|
2034
|
-
|
2035
|
-
}
|
2036
|
-
|
2037
|
-
|
2038
|
-
|
2039
|
-
|
2040
|
-
|
2041
|
-
}
|
2042
|
-
|
2043
|
-
|
2044
|
-
|
2045
|
-
|
2046
|
-
|
2057
|
+
} else {
|
2058
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
2059
|
+
nrows = 4;
|
2060
|
+
}
|
2061
|
+
} break;
|
2062
|
+
case LM_GGML_TYPE_BF16:
|
2063
|
+
{
|
2064
|
+
nth0 = 32;
|
2065
|
+
nth1 = 1;
|
2066
|
+
if (src1t == LM_GGML_TYPE_F32) {
|
2067
|
+
if (ne11 * ne12 < 4) {
|
2068
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
2069
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2070
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
2071
|
+
nrows = ne11;
|
2072
|
+
} else {
|
2073
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
2074
|
+
nrows = 4;
|
2075
|
+
}
|
2076
|
+
} else {
|
2077
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
2078
|
+
nrows = 4;
|
2079
|
+
}
|
2080
|
+
} break;
|
2081
|
+
case LM_GGML_TYPE_Q4_0:
|
2082
|
+
{
|
2083
|
+
nth0 = 8;
|
2084
|
+
nth1 = 8;
|
2085
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
2086
|
+
} break;
|
2087
|
+
case LM_GGML_TYPE_Q4_1:
|
2088
|
+
{
|
2089
|
+
nth0 = 8;
|
2090
|
+
nth1 = 8;
|
2091
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
2092
|
+
} break;
|
2093
|
+
case LM_GGML_TYPE_Q5_0:
|
2094
|
+
{
|
2095
|
+
nth0 = 8;
|
2096
|
+
nth1 = 8;
|
2097
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
2098
|
+
} break;
|
2099
|
+
case LM_GGML_TYPE_Q5_1:
|
2100
|
+
{
|
2101
|
+
nth0 = 8;
|
2102
|
+
nth1 = 8;
|
2103
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
2104
|
+
} break;
|
2105
|
+
case LM_GGML_TYPE_Q8_0:
|
2106
|
+
{
|
2107
|
+
nth0 = 8;
|
2108
|
+
nth1 = 8;
|
2109
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
2110
|
+
} break;
|
2111
|
+
case LM_GGML_TYPE_Q2_K:
|
2112
|
+
{
|
2113
|
+
nth0 = 2;
|
2114
|
+
nth1 = 32;
|
2115
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
2116
|
+
} break;
|
2117
|
+
case LM_GGML_TYPE_Q3_K:
|
2118
|
+
{
|
2119
|
+
nth0 = 2;
|
2120
|
+
nth1 = 32;
|
2121
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
2122
|
+
} break;
|
2123
|
+
case LM_GGML_TYPE_Q4_K:
|
2124
|
+
{
|
2125
|
+
nth0 = 4; //1;
|
2126
|
+
nth1 = 8; //32;
|
2127
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
2128
|
+
} break;
|
2129
|
+
case LM_GGML_TYPE_Q5_K:
|
2130
|
+
{
|
2131
|
+
nth0 = 2;
|
2132
|
+
nth1 = 32;
|
2133
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
2134
|
+
} break;
|
2135
|
+
case LM_GGML_TYPE_Q6_K:
|
2136
|
+
{
|
2137
|
+
nth0 = 2;
|
2138
|
+
nth1 = 32;
|
2139
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
2140
|
+
} break;
|
2141
|
+
case LM_GGML_TYPE_IQ2_XXS:
|
2142
|
+
{
|
2143
|
+
nth0 = 4;
|
2144
|
+
nth1 = 16;
|
2145
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
2146
|
+
} break;
|
2147
|
+
case LM_GGML_TYPE_IQ2_XS:
|
2148
|
+
{
|
2149
|
+
nth0 = 4;
|
2150
|
+
nth1 = 16;
|
2151
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
2152
|
+
} break;
|
2153
|
+
case LM_GGML_TYPE_IQ3_XXS:
|
2154
|
+
{
|
2155
|
+
nth0 = 4;
|
2156
|
+
nth1 = 16;
|
2157
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
2158
|
+
} break;
|
2159
|
+
case LM_GGML_TYPE_IQ3_S:
|
2160
|
+
{
|
2161
|
+
nth0 = 4;
|
2162
|
+
nth1 = 16;
|
2163
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
2164
|
+
} break;
|
2165
|
+
case LM_GGML_TYPE_IQ2_S:
|
2166
|
+
{
|
2167
|
+
nth0 = 4;
|
2168
|
+
nth1 = 16;
|
2169
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
2170
|
+
} break;
|
2171
|
+
case LM_GGML_TYPE_IQ1_S:
|
2172
|
+
{
|
2173
|
+
nth0 = 4;
|
2174
|
+
nth1 = 16;
|
2175
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
2176
|
+
} break;
|
2177
|
+
case LM_GGML_TYPE_IQ1_M:
|
2178
|
+
{
|
2179
|
+
nth0 = 4;
|
2180
|
+
nth1 = 16;
|
2181
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
2182
|
+
} break;
|
2183
|
+
case LM_GGML_TYPE_IQ4_NL:
|
2184
|
+
{
|
2185
|
+
nth0 = 4;
|
2186
|
+
nth1 = 16;
|
2187
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
2188
|
+
} break;
|
2189
|
+
case LM_GGML_TYPE_IQ4_XS:
|
2190
|
+
{
|
2191
|
+
nth0 = 4;
|
2192
|
+
nth1 = 16;
|
2193
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
2194
|
+
} break;
|
2195
|
+
default:
|
2196
|
+
{
|
2197
|
+
LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
2198
|
+
LM_GGML_ABORT("not implemented");
|
2047
2199
|
}
|
2048
|
-
|
2200
|
+
};
|
2201
|
+
|
2202
|
+
lm_ggml_metal_kargs_mul_mv args = {
|
2203
|
+
/*.ne00 =*/ ne00,
|
2204
|
+
/*.ne01 =*/ ne01,
|
2205
|
+
/*.ne02 =*/ ne02,
|
2206
|
+
/*.nb00 =*/ nb00,
|
2207
|
+
/*.nb01 =*/ nb01,
|
2208
|
+
/*.nb02 =*/ nb02,
|
2209
|
+
/*.nb03 =*/ nb03,
|
2210
|
+
/*.ne10 =*/ ne10,
|
2211
|
+
/*.ne11 =*/ ne11,
|
2212
|
+
/*.ne12 =*/ ne12,
|
2213
|
+
/*.nb10 =*/ nb10,
|
2214
|
+
/*.nb11 =*/ nb11,
|
2215
|
+
/*.nb12 =*/ nb12,
|
2216
|
+
/*.nb13 =*/ nb13,
|
2217
|
+
/*.ne0 =*/ ne0,
|
2218
|
+
/*.ne1 =*/ ne1,
|
2219
|
+
/*.r2 =*/ r2,
|
2220
|
+
/*.r3 =*/ r3,
|
2221
|
+
};
|
2222
|
+
|
2223
|
+
[encoder setComputePipelineState:pipeline];
|
2224
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2225
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2226
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2227
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2228
|
+
|
2229
|
+
if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
|
2230
|
+
src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
|
2231
|
+
src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
|
2232
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2233
|
+
}
|
2234
|
+
else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
|
2235
|
+
const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
2236
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2237
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2238
|
+
}
|
2239
|
+
else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
|
2240
|
+
const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
2241
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2242
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2243
|
+
}
|
2244
|
+
else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
|
2245
|
+
const int mem_size = 32*sizeof(float);
|
2246
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2247
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2248
|
+
}
|
2249
|
+
else if (src0t == LM_GGML_TYPE_Q4_K) {
|
2250
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2251
|
+
}
|
2252
|
+
else if (src0t == LM_GGML_TYPE_Q3_K) {
|
2253
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2254
|
+
}
|
2255
|
+
else if (src0t == LM_GGML_TYPE_Q5_K) {
|
2256
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2257
|
+
}
|
2258
|
+
else if (src0t == LM_GGML_TYPE_Q6_K) {
|
2259
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2260
|
+
} else {
|
2261
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
2262
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2263
|
+
}
|
2264
|
+
}
|
2049
2265
|
} break;
|
2050
2266
|
case LM_GGML_OP_MUL_MAT_ID:
|
2051
2267
|
{
|
@@ -2084,12 +2300,12 @@ static void lm_ggml_metal_encode_node(
|
|
2084
2300
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
2085
2301
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
2086
2302
|
dst_rows > dst_rows_min) {
|
2087
|
-
|
2088
2303
|
// some Metal matrix data types require aligned pointers
|
2089
2304
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
2090
2305
|
switch (src0->type) {
|
2091
|
-
case LM_GGML_TYPE_F32:
|
2092
|
-
case LM_GGML_TYPE_F16:
|
2306
|
+
case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
|
2307
|
+
case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
|
2308
|
+
case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
|
2093
2309
|
default: break;
|
2094
2310
|
}
|
2095
2311
|
|
@@ -2098,6 +2314,7 @@ static void lm_ggml_metal_encode_node(
|
|
2098
2314
|
switch (src0->type) {
|
2099
2315
|
case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
2100
2316
|
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
2317
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
2101
2318
|
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
2102
2319
|
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
2103
2320
|
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
@@ -2120,27 +2337,30 @@ static void lm_ggml_metal_encode_node(
|
|
2120
2337
|
default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
|
2121
2338
|
}
|
2122
2339
|
|
2340
|
+
lm_ggml_metal_kargs_mul_mm_id args = {
|
2341
|
+
/*.nei0 =*/ ne20,
|
2342
|
+
/*.nei1 =*/ ne21,
|
2343
|
+
/*.nbi1 =*/ nb21,
|
2344
|
+
/*.ne00 =*/ ne00,
|
2345
|
+
/*.ne02 =*/ ne02,
|
2346
|
+
/*.nb01 =*/ nb01,
|
2347
|
+
/*.nb02 =*/ nb02,
|
2348
|
+
/*.ne11 =*/ ne11,
|
2349
|
+
/*.ne12 =*/ ne12,
|
2350
|
+
/*.ne13 =*/ ne13,
|
2351
|
+
/*.nb10 =*/ nb10,
|
2352
|
+
/*.nb11 =*/ nb11,
|
2353
|
+
/*.nb12 =*/ nb12,
|
2354
|
+
/*.ne0 =*/ ne0,
|
2355
|
+
/*.ne1 =*/ ne1,
|
2356
|
+
};
|
2357
|
+
|
2123
2358
|
[encoder setComputePipelineState:pipeline];
|
2124
|
-
[encoder
|
2125
|
-
[encoder setBuffer:
|
2126
|
-
[encoder setBuffer:
|
2127
|
-
[encoder setBuffer:
|
2128
|
-
[encoder
|
2129
|
-
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
2130
|
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
2131
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
2132
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
2133
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
2134
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
2135
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
2136
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
2137
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
2138
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
2139
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
2140
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
2141
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
2142
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
2143
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
2359
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2360
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2361
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2362
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2363
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
2144
2364
|
|
2145
2365
|
[encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
2146
2366
|
|
@@ -2167,6 +2387,13 @@ static void lm_ggml_metal_encode_node(
|
|
2167
2387
|
nth1 = 1;
|
2168
2388
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
2169
2389
|
} break;
|
2390
|
+
case LM_GGML_TYPE_BF16:
|
2391
|
+
{
|
2392
|
+
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2393
|
+
nth0 = 32;
|
2394
|
+
nth1 = 1;
|
2395
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
2396
|
+
} break;
|
2170
2397
|
case LM_GGML_TYPE_Q4_0:
|
2171
2398
|
{
|
2172
2399
|
nth0 = 8;
|
@@ -2292,30 +2519,34 @@ static void lm_ggml_metal_encode_node(
|
|
2292
2519
|
LM_GGML_ASSERT(ne00 >= nth0*nth1);
|
2293
2520
|
}
|
2294
2521
|
|
2522
|
+
lm_ggml_metal_kargs_mul_mv_id args = {
|
2523
|
+
/*.nei0 =*/ ne20,
|
2524
|
+
/*.nei1 =*/ ne21,
|
2525
|
+
/*.nbi1 =*/ nb21,
|
2526
|
+
/*.ne00 =*/ ne00,
|
2527
|
+
/*.ne01 =*/ ne01,
|
2528
|
+
/*.ne02 =*/ ne02,
|
2529
|
+
/*.nb00 =*/ nb00,
|
2530
|
+
/*.nb01 =*/ nb01,
|
2531
|
+
/*.nb02 =*/ nb02,
|
2532
|
+
/*.ne10 =*/ ne10,
|
2533
|
+
/*.ne11 =*/ ne11,
|
2534
|
+
/*.ne12 =*/ ne12,
|
2535
|
+
/*.ne13 =*/ ne13,
|
2536
|
+
/*.nb10 =*/ nb10,
|
2537
|
+
/*.nb11 =*/ nb11,
|
2538
|
+
/*.nb12 =*/ nb12,
|
2539
|
+
/*.ne0 =*/ ne0,
|
2540
|
+
/*.ne1 =*/ ne1,
|
2541
|
+
/*.nb1 =*/ nb1,
|
2542
|
+
};
|
2543
|
+
|
2295
2544
|
[encoder setComputePipelineState:pipeline];
|
2296
|
-
[encoder
|
2297
|
-
[encoder setBuffer:
|
2298
|
-
[encoder setBuffer:
|
2299
|
-
[encoder setBuffer:
|
2300
|
-
[encoder
|
2301
|
-
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
2302
|
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
2303
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
2304
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
|
2305
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
|
2306
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
|
2307
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
|
2308
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
|
2309
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
|
2310
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
|
2311
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
|
2312
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
|
2313
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
|
2314
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
2315
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
2316
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
|
2317
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
|
2318
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
|
2545
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2546
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2547
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2548
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2549
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
2319
2550
|
|
2320
2551
|
const int64_t _ne1 = 1;
|
2321
2552
|
const int tgz = dst_rows;
|
@@ -2364,6 +2595,7 @@ static void lm_ggml_metal_encode_node(
|
|
2364
2595
|
switch (src0->type) {
|
2365
2596
|
case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
2366
2597
|
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
2598
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
|
2367
2599
|
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
2368
2600
|
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
2369
2601
|
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
@@ -2387,6 +2619,7 @@ static void lm_ggml_metal_encode_node(
|
|
2387
2619
|
default: LM_GGML_ABORT("not implemented");
|
2388
2620
|
}
|
2389
2621
|
|
2622
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2390
2623
|
[encoder setComputePipelineState:pipeline];
|
2391
2624
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2392
2625
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -2410,20 +2643,28 @@ static void lm_ggml_metal_encode_node(
|
|
2410
2643
|
float eps;
|
2411
2644
|
memcpy(&eps, dst->op_params, sizeof(float));
|
2412
2645
|
|
2646
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
|
2647
|
+
|
2413
2648
|
int nth = 32; // SIMD width
|
2414
2649
|
|
2415
|
-
while (nth < ne00/4 && nth <
|
2650
|
+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
2416
2651
|
nth *= 2;
|
2417
2652
|
}
|
2418
2653
|
|
2419
|
-
|
2654
|
+
nth = MIN(nth, ne00/4);
|
2655
|
+
|
2656
|
+
lm_ggml_metal_kargs_rms_norm args = {
|
2657
|
+
/*.ne00 =*/ ne00,
|
2658
|
+
/*.ne00_4 =*/ ne00/4,
|
2659
|
+
/*.nb01 =*/ nb01,
|
2660
|
+
/*.eps =*/ eps,
|
2661
|
+
};
|
2420
2662
|
|
2421
2663
|
[encoder setComputePipelineState:pipeline];
|
2422
|
-
[encoder
|
2423
|
-
[encoder setBuffer:
|
2424
|
-
[encoder
|
2425
|
-
|
2426
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
2664
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2665
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2666
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2667
|
+
|
2427
2668
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2428
2669
|
|
2429
2670
|
const int64_t nrows = lm_ggml_nrows(src0);
|
@@ -2432,7 +2673,6 @@ static void lm_ggml_metal_encode_node(
|
|
2432
2673
|
} break;
|
2433
2674
|
case LM_GGML_OP_GROUP_NORM:
|
2434
2675
|
{
|
2435
|
-
LM_GGML_ASSERT(ne00 % 4 == 0);
|
2436
2676
|
LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
|
2437
2677
|
|
2438
2678
|
float eps;
|
@@ -2448,6 +2688,7 @@ static void lm_ggml_metal_encode_node(
|
|
2448
2688
|
|
2449
2689
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
2450
2690
|
|
2691
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2451
2692
|
[encoder setComputePipelineState:pipeline];
|
2452
2693
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2453
2694
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2465,22 +2706,35 @@ static void lm_ggml_metal_encode_node(
|
|
2465
2706
|
} break;
|
2466
2707
|
case LM_GGML_OP_NORM:
|
2467
2708
|
{
|
2709
|
+
LM_GGML_ASSERT(ne00 % 4 == 0);
|
2468
2710
|
LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
|
2469
2711
|
|
2470
2712
|
float eps;
|
2471
2713
|
memcpy(&eps, dst->op_params, sizeof(float));
|
2472
2714
|
|
2473
|
-
const int nth = MIN(256, ne00);
|
2474
|
-
|
2475
2715
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
2476
2716
|
|
2717
|
+
int nth = 32; // SIMD width
|
2718
|
+
|
2719
|
+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
2720
|
+
nth *= 2;
|
2721
|
+
}
|
2722
|
+
|
2723
|
+
nth = MIN(nth, ne00/4);
|
2724
|
+
|
2725
|
+
lm_ggml_metal_kargs_norm args = {
|
2726
|
+
/*.ne00 =*/ ne00,
|
2727
|
+
/*.ne00_4 =*/ ne00/4,
|
2728
|
+
/*.nb01 =*/ nb01,
|
2729
|
+
/*.eps =*/ eps,
|
2730
|
+
};
|
2731
|
+
|
2477
2732
|
[encoder setComputePipelineState:pipeline];
|
2478
|
-
[encoder
|
2479
|
-
[encoder setBuffer:
|
2480
|
-
[encoder
|
2481
|
-
|
2482
|
-
[encoder
|
2483
|
-
[encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
2733
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2734
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2735
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2736
|
+
|
2737
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2484
2738
|
|
2485
2739
|
const int64_t nrows = lm_ggml_nrows(src0);
|
2486
2740
|
|
@@ -2530,40 +2784,44 @@ static void lm_ggml_metal_encode_node(
|
|
2530
2784
|
};
|
2531
2785
|
}
|
2532
2786
|
|
2787
|
+
lm_ggml_metal_kargs_rope args = {
|
2788
|
+
/*.ne00 =*/ ne00,
|
2789
|
+
/*.ne01 =*/ ne01,
|
2790
|
+
/*.ne02 =*/ ne02,
|
2791
|
+
/*.ne03 =*/ ne03,
|
2792
|
+
/*.nb00 =*/ nb00,
|
2793
|
+
/*.nb01 =*/ nb01,
|
2794
|
+
/*.nb02 =*/ nb02,
|
2795
|
+
/*.nb03 =*/ nb03,
|
2796
|
+
/*.ne0 =*/ ne0,
|
2797
|
+
/*.ne1 =*/ ne1,
|
2798
|
+
/*.ne2 =*/ ne2,
|
2799
|
+
/*.ne3 =*/ ne3,
|
2800
|
+
/*.nb0 =*/ nb0,
|
2801
|
+
/*.nb1 =*/ nb1,
|
2802
|
+
/*.nb2 =*/ nb2,
|
2803
|
+
/*.nb3 =*/ nb3,
|
2804
|
+
/*.n_past =*/ n_past,
|
2805
|
+
/*.n_dims =*/ n_dims,
|
2806
|
+
/*.n_ctx_orig =*/ n_ctx_orig,
|
2807
|
+
/*.freq_base =*/ freq_base,
|
2808
|
+
/*.freq_scale =*/ freq_scale,
|
2809
|
+
/*.ext_factor =*/ ext_factor,
|
2810
|
+
/*.attn_factor =*/ attn_factor,
|
2811
|
+
/*.beta_fast =*/ beta_fast,
|
2812
|
+
/*.beta_slow =*/ beta_slow,
|
2813
|
+
};
|
2814
|
+
|
2533
2815
|
[encoder setComputePipelineState:pipeline];
|
2534
|
-
[encoder
|
2535
|
-
[encoder setBuffer:
|
2816
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2817
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2818
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2536
2819
|
if (id_src2 != nil) {
|
2537
|
-
[encoder setBuffer:id_src2 offset:offs_src2
|
2820
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
2538
2821
|
} else {
|
2539
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
2822
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
2540
2823
|
}
|
2541
|
-
[encoder setBuffer:id_dst
|
2542
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
2543
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2544
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2545
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2546
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
2547
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
2548
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
2549
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
2550
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
2551
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
2552
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
2553
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
2554
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
2555
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
2556
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
2557
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2558
|
-
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2559
|
-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2560
|
-
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
2561
|
-
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
2562
|
-
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
2563
|
-
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
2564
|
-
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
2565
|
-
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
2566
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
2824
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2567
2825
|
|
2568
2826
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2569
2827
|
} break;
|
@@ -2620,6 +2878,7 @@ static void lm_ggml_metal_encode_node(
|
|
2620
2878
|
default: LM_GGML_ABORT("fatal error");
|
2621
2879
|
};
|
2622
2880
|
|
2881
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2623
2882
|
[encoder setComputePipelineState:pipeline];
|
2624
2883
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
2625
2884
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2660,6 +2919,7 @@ static void lm_ggml_metal_encode_node(
|
|
2660
2919
|
|
2661
2920
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
2662
2921
|
|
2922
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2663
2923
|
[encoder setComputePipelineState:pipeline];
|
2664
2924
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2665
2925
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2694,6 +2954,7 @@ static void lm_ggml_metal_encode_node(
|
|
2694
2954
|
|
2695
2955
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
2696
2956
|
|
2957
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2697
2958
|
[encoder setComputePipelineState:pipeline];
|
2698
2959
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2699
2960
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2730,6 +2991,7 @@ static void lm_ggml_metal_encode_node(
|
|
2730
2991
|
|
2731
2992
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
2732
2993
|
|
2994
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2733
2995
|
[encoder setComputePipelineState:pipeline];
|
2734
2996
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
2735
2997
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
@@ -2751,6 +3013,7 @@ static void lm_ggml_metal_encode_node(
|
|
2751
3013
|
|
2752
3014
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
2753
3015
|
|
3016
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2754
3017
|
[encoder setComputePipelineState:pipeline];
|
2755
3018
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2756
3019
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2789,6 +3052,7 @@ static void lm_ggml_metal_encode_node(
|
|
2789
3052
|
default: LM_GGML_ABORT("fatal error");
|
2790
3053
|
};
|
2791
3054
|
|
3055
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2792
3056
|
[encoder setComputePipelineState:pipeline];
|
2793
3057
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2794
3058
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2807,6 +3071,7 @@ static void lm_ggml_metal_encode_node(
|
|
2807
3071
|
|
2808
3072
|
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
2809
3073
|
|
3074
|
+
// TODO: add lm_ggml_metal_kargs struct
|
2810
3075
|
[encoder setComputePipelineState:pipeline];
|
2811
3076
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2812
3077
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -2822,6 +3087,7 @@ static void lm_ggml_metal_encode_node(
|
|
2822
3087
|
LM_GGML_ASSERT(ne11 % 32 == 0);
|
2823
3088
|
|
2824
3089
|
LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
|
3090
|
+
LM_GGML_ASSERT(src1->type == src2->type);
|
2825
3091
|
|
2826
3092
|
LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
|
2827
3093
|
|
@@ -2868,27 +3134,176 @@ static void lm_ggml_metal_encode_node(
|
|
2868
3134
|
|
2869
3135
|
bool use_vec_kernel = false;
|
2870
3136
|
|
3137
|
+
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
3138
|
+
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
2871
3139
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
2872
|
-
switch (
|
2873
|
-
case
|
2874
|
-
|
2875
|
-
|
2876
|
-
|
2877
|
-
|
2878
|
-
|
3140
|
+
switch (src1->type) {
|
3141
|
+
case LM_GGML_TYPE_F16:
|
3142
|
+
{
|
3143
|
+
switch (ne00) {
|
3144
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
3145
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
3146
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
3147
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
3148
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
3149
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
3150
|
+
default:
|
3151
|
+
{
|
3152
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3153
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3154
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3155
|
+
}
|
3156
|
+
}
|
3157
|
+
} break;
|
3158
|
+
case LM_GGML_TYPE_BF16:
|
3159
|
+
{
|
3160
|
+
switch (ne00) {
|
3161
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
3162
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
3163
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
3164
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
3165
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
3166
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
3167
|
+
default:
|
3168
|
+
{
|
3169
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3170
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3171
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3172
|
+
}
|
3173
|
+
}
|
3174
|
+
} break;
|
3175
|
+
case LM_GGML_TYPE_Q4_0:
|
3176
|
+
{
|
3177
|
+
switch (ne00) {
|
3178
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
3179
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
3180
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
3181
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
3182
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
3183
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
3184
|
+
default:
|
3185
|
+
{
|
3186
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3187
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3188
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3189
|
+
}
|
3190
|
+
}
|
3191
|
+
} break;
|
3192
|
+
case LM_GGML_TYPE_Q4_1:
|
3193
|
+
{
|
3194
|
+
switch (ne00) {
|
3195
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
3196
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
3197
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
3198
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
3199
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
3200
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
3201
|
+
default:
|
3202
|
+
{
|
3203
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3204
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3205
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3206
|
+
}
|
3207
|
+
}
|
3208
|
+
} break;
|
3209
|
+
case LM_GGML_TYPE_Q5_0:
|
3210
|
+
{
|
3211
|
+
switch (ne00) {
|
3212
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
3213
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
3214
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
3215
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
3216
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
3217
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
3218
|
+
default:
|
3219
|
+
{
|
3220
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3221
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3222
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3223
|
+
}
|
3224
|
+
}
|
3225
|
+
} break;
|
3226
|
+
case LM_GGML_TYPE_Q5_1:
|
3227
|
+
{
|
3228
|
+
switch (ne00) {
|
3229
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
3230
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
3231
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
3232
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
3233
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
3234
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
3235
|
+
default:
|
3236
|
+
{
|
3237
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3238
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3239
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3240
|
+
}
|
3241
|
+
}
|
3242
|
+
} break;
|
3243
|
+
case LM_GGML_TYPE_Q8_0:
|
3244
|
+
{
|
3245
|
+
switch (ne00) {
|
3246
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
3247
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
3248
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
3249
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
3250
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
3251
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
3252
|
+
default:
|
3253
|
+
{
|
3254
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3255
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3256
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3257
|
+
}
|
3258
|
+
}
|
3259
|
+
} break;
|
2879
3260
|
default:
|
2880
|
-
|
2881
|
-
|
2882
|
-
|
2883
|
-
|
2884
|
-
|
3261
|
+
{
|
3262
|
+
LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
3263
|
+
LM_GGML_LOG_ERROR("add template specialization for this type\n");
|
3264
|
+
LM_GGML_ABORT("add template specialization for this type");
|
3265
|
+
}
|
2885
3266
|
}
|
2886
3267
|
} else {
|
2887
3268
|
use_vec_kernel = true;
|
2888
3269
|
|
2889
3270
|
switch (ne00) {
|
2890
|
-
case 128:
|
2891
|
-
|
3271
|
+
case 128:
|
3272
|
+
{
|
3273
|
+
switch (src1->type) {
|
3274
|
+
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
3275
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
|
3276
|
+
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
3277
|
+
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
3278
|
+
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
3279
|
+
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
|
3280
|
+
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
|
3281
|
+
default:
|
3282
|
+
{
|
3283
|
+
LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
3284
|
+
LM_GGML_LOG_ERROR("add template specialization for this type\n");
|
3285
|
+
LM_GGML_ABORT("add template specialization for this type");
|
3286
|
+
}
|
3287
|
+
}
|
3288
|
+
} break;
|
3289
|
+
case 256:
|
3290
|
+
{
|
3291
|
+
switch (src1->type) {
|
3292
|
+
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
3293
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
|
3294
|
+
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
3295
|
+
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
3296
|
+
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
3297
|
+
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
|
3298
|
+
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
|
3299
|
+
default:
|
3300
|
+
{
|
3301
|
+
LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
3302
|
+
LM_GGML_LOG_ERROR("add template specialization for this type\n");
|
3303
|
+
LM_GGML_ABORT("add template specialization for this type");
|
3304
|
+
}
|
3305
|
+
}
|
3306
|
+
} break;
|
2892
3307
|
default:
|
2893
3308
|
{
|
2894
3309
|
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
@@ -2898,40 +3313,41 @@ static void lm_ggml_metal_encode_node(
|
|
2898
3313
|
}
|
2899
3314
|
}
|
2900
3315
|
|
3316
|
+
lm_ggml_metal_kargs_flash_attn_ext args = {
|
3317
|
+
/*.ne01 =*/ ne01,
|
3318
|
+
/*.ne02 =*/ ne02,
|
3319
|
+
/*.ne03 =*/ ne03,
|
3320
|
+
/*.nb01 =*/ nb01,
|
3321
|
+
/*.nb02 =*/ nb02,
|
3322
|
+
/*.nb03 =*/ nb03,
|
3323
|
+
/*.ne11 =*/ ne11,
|
3324
|
+
/*.ne_12_2 =*/ ne12,
|
3325
|
+
/*.ne_12_3 =*/ ne13,
|
3326
|
+
/*.nb_12_1 =*/ nb11,
|
3327
|
+
/*.nb_12_2 =*/ nb12,
|
3328
|
+
/*.nb_12_3 =*/ nb13,
|
3329
|
+
/*.nb31 =*/ nb31,
|
3330
|
+
/*.ne1 =*/ ne1,
|
3331
|
+
/*.ne2 =*/ ne2,
|
3332
|
+
/*.scale =*/ scale,
|
3333
|
+
/*.max_bias =*/ max_bias,
|
3334
|
+
/*.m0 =*/ m0,
|
3335
|
+
/*.m1 =*/ m1,
|
3336
|
+
/*.n_head_log2 =*/ n_head_log2,
|
3337
|
+
/*.logit_softcap =*/ logit_softcap,
|
3338
|
+
};
|
3339
|
+
|
2901
3340
|
[encoder setComputePipelineState:pipeline];
|
2902
|
-
[encoder
|
2903
|
-
[encoder setBuffer:
|
2904
|
-
[encoder setBuffer:
|
3341
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3342
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3343
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
3344
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
2905
3345
|
if (id_src3) {
|
2906
|
-
[encoder setBuffer:id_src3
|
3346
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
|
2907
3347
|
} else {
|
2908
|
-
[encoder setBuffer:id_src0
|
3348
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
|
2909
3349
|
}
|
2910
|
-
[encoder setBuffer:id_dst
|
2911
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2912
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2913
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2914
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2915
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2916
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2917
|
-
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
2918
|
-
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
2919
|
-
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
2920
|
-
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
2921
|
-
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
2922
|
-
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2923
|
-
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
2924
|
-
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
2925
|
-
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
2926
|
-
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
2927
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
2928
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
2929
|
-
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
2930
|
-
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
2931
|
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
2932
|
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
2933
|
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
2934
|
-
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
3350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
|
2935
3351
|
|
2936
3352
|
if (!use_vec_kernel) {
|
2937
3353
|
// half8x8 kernel
|
@@ -2942,10 +3358,19 @@ static void lm_ggml_metal_encode_node(
|
|
2942
3358
|
LM_GGML_ASSERT(nqptg % 8 == 0);
|
2943
3359
|
LM_GGML_ASSERT(ncpsg % 32 == 0);
|
2944
3360
|
|
3361
|
+
// 2*(2*ncpsg + nqptg)*(nsg)
|
3362
|
+
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
3363
|
+
//
|
3364
|
+
// 16*32*(nsg)
|
3365
|
+
// the shared memory needed for the simdgroups to load the KV cache
|
3366
|
+
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
|
3367
|
+
//
|
3368
|
+
#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
3369
|
+
|
2945
3370
|
int64_t nsgmax = 2;
|
2946
3371
|
|
2947
3372
|
while (true) {
|
2948
|
-
const size_t smem =
|
3373
|
+
const size_t smem = FATTN_SMEM(nsgmax);
|
2949
3374
|
if (smem > device.maxThreadgroupMemoryLength) {
|
2950
3375
|
break;
|
2951
3376
|
}
|
@@ -2956,16 +3381,15 @@ static void lm_ggml_metal_encode_node(
|
|
2956
3381
|
// simdgroups per threadgroup (a.k.a. warps)
|
2957
3382
|
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
2958
3383
|
|
2959
|
-
const size_t smem =
|
3384
|
+
const size_t smem = FATTN_SMEM(nsg);
|
2960
3385
|
|
2961
|
-
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
3386
|
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
2962
3387
|
LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
2963
|
-
|
2964
|
-
|
2965
|
-
|
3388
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
3389
|
+
#undef FATTN_SMEM
|
2966
3390
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2967
3391
|
} else {
|
2968
|
-
//
|
3392
|
+
// half4x4 kernel
|
2969
3393
|
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
2970
3394
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
2971
3395
|
|
@@ -2973,8 +3397,28 @@ static void lm_ggml_metal_encode_node(
|
|
2973
3397
|
LM_GGML_ASSERT(nqptg % 1 == 0);
|
2974
3398
|
LM_GGML_ASSERT(ncpsg % 32 == 0);
|
2975
3399
|
|
3400
|
+
// ne00 + 2*ncpsg*(nsg)
|
3401
|
+
// for each query, we load it as f16 in shared memory (ne00)
|
3402
|
+
// and store the soft_max values and the mask
|
3403
|
+
//
|
3404
|
+
// ne00*(nsg)
|
3405
|
+
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
3406
|
+
//
|
3407
|
+
#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
3408
|
+
|
3409
|
+
int64_t nsgmax = 2;
|
3410
|
+
|
3411
|
+
while (true) {
|
3412
|
+
const size_t smem = FATTN_SMEM(nsgmax);
|
3413
|
+
if (smem > device.maxThreadgroupMemoryLength) {
|
3414
|
+
break;
|
3415
|
+
}
|
3416
|
+
nsgmax *= 2;
|
3417
|
+
}
|
3418
|
+
nsgmax /= 2;
|
3419
|
+
|
2976
3420
|
// simdgroups per threadgroup (a.k.a. warps)
|
2977
|
-
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
3421
|
+
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
2978
3422
|
|
2979
3423
|
int64_t nsg = 1;
|
2980
3424
|
while (nsg <= nsgt) {
|
@@ -2982,12 +3426,12 @@ static void lm_ggml_metal_encode_node(
|
|
2982
3426
|
}
|
2983
3427
|
nsg /= 2;
|
2984
3428
|
|
2985
|
-
const size_t smem = (
|
3429
|
+
const size_t smem = FATTN_SMEM(nsg);
|
2986
3430
|
|
2987
|
-
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
3431
|
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
2988
3432
|
LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
2989
|
-
[encoder setThreadgroupMemoryLength:
|
2990
|
-
|
3433
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
3434
|
+
#undef FATTN_SMEM
|
2991
3435
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2992
3436
|
}
|
2993
3437
|
} break;
|
@@ -3009,6 +3453,7 @@ static void lm_ggml_metal_encode_node(
|
|
3009
3453
|
switch (dstt) {
|
3010
3454
|
case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
3011
3455
|
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
3456
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
|
3012
3457
|
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
3013
3458
|
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
3014
3459
|
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
@@ -3026,28 +3471,40 @@ static void lm_ggml_metal_encode_node(
|
|
3026
3471
|
default: LM_GGML_ABORT("not implemented");
|
3027
3472
|
};
|
3028
3473
|
} break;
|
3474
|
+
case LM_GGML_TYPE_BF16:
|
3475
|
+
{
|
3476
|
+
switch (dstt) {
|
3477
|
+
case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
3478
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
|
3479
|
+
default: LM_GGML_ASSERT(false && "not implemented");
|
3480
|
+
};
|
3481
|
+
} break;
|
3029
3482
|
default: LM_GGML_ABORT("not implemented");
|
3030
3483
|
}
|
3031
3484
|
|
3485
|
+
lm_ggml_metal_kargs_cpy args = {
|
3486
|
+
/*.ne00 =*/ ne00,
|
3487
|
+
/*.ne01 =*/ ne01,
|
3488
|
+
/*.ne02 =*/ ne02,
|
3489
|
+
/*.ne03 =*/ ne03,
|
3490
|
+
/*.nb00 =*/ nb00,
|
3491
|
+
/*.nb01 =*/ nb01,
|
3492
|
+
/*.nb02 =*/ nb02,
|
3493
|
+
/*.nb03 =*/ nb03,
|
3494
|
+
/*.ne0 =*/ ne0,
|
3495
|
+
/*.ne1 =*/ ne1,
|
3496
|
+
/*.ne2 =*/ ne2,
|
3497
|
+
/*.ne3 =*/ ne3,
|
3498
|
+
/*.nb0 =*/ nb0,
|
3499
|
+
/*.nb1 =*/ nb1,
|
3500
|
+
/*.nb2 =*/ nb2,
|
3501
|
+
/*.nb3 =*/ nb3,
|
3502
|
+
};
|
3503
|
+
|
3032
3504
|
[encoder setComputePipelineState:pipeline];
|
3033
|
-
[encoder
|
3034
|
-
[encoder setBuffer:
|
3035
|
-
[encoder
|
3036
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
3037
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
3038
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
3039
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
3040
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
3041
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
3042
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
3043
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
3044
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
3045
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
3046
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
3047
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
3048
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
3049
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
3050
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
3505
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3506
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3507
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
3051
3508
|
|
3052
3509
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3053
3510
|
} break;
|
@@ -3092,6 +3549,7 @@ static void lm_ggml_metal_encode_node(
|
|
3092
3549
|
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
3093
3550
|
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
3094
3551
|
|
3552
|
+
// TODO: add lm_ggml_metal_kargs struct
|
3095
3553
|
[encoder setComputePipelineState:pipeline];
|
3096
3554
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3097
3555
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -3279,6 +3737,12 @@ static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buf
|
|
3279
3737
|
return ctx->all_data;
|
3280
3738
|
}
|
3281
3739
|
|
3740
|
+
static void lm_ggml_backend_metal_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
3741
|
+
memset((char *)tensor->data + offset, value, size);
|
3742
|
+
|
3743
|
+
UNUSED(buffer);
|
3744
|
+
}
|
3745
|
+
|
3282
3746
|
static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
3283
3747
|
memcpy((char *)tensor->data + offset, data, size);
|
3284
3748
|
|
@@ -3311,7 +3775,7 @@ static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
|
|
3311
3775
|
/* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer,
|
3312
3776
|
/* .get_base = */ lm_ggml_backend_metal_buffer_get_base,
|
3313
3777
|
/* .init_tensor = */ NULL,
|
3314
|
-
/* .memset_tensor = */
|
3778
|
+
/* .memset_tensor = */ lm_ggml_backend_metal_buffer_memset_tensor,
|
3315
3779
|
/* .set_tensor = */ lm_ggml_backend_metal_buffer_set_tensor,
|
3316
3780
|
/* .get_tensor = */ lm_ggml_backend_metal_buffer_get_tensor,
|
3317
3781
|
/* .cpy_tensor = */ lm_ggml_backend_metal_buffer_cpy_tensor,
|
@@ -3844,7 +4308,7 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_
|
|
3844
4308
|
}
|
3845
4309
|
}
|
3846
4310
|
|
3847
|
-
return lm_ggml_backend_buffer_init(
|
4311
|
+
return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
|
3848
4312
|
}
|
3849
4313
|
|
3850
4314
|
static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) {
|
@@ -3854,7 +4318,8 @@ static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev,
|
|
3854
4318
|
}
|
3855
4319
|
|
3856
4320
|
static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) {
|
3857
|
-
return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name
|
4321
|
+
return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name ||
|
4322
|
+
buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name;
|
3858
4323
|
|
3859
4324
|
UNUSED(dev);
|
3860
4325
|
}
|
@@ -3907,19 +4372,45 @@ static lm_ggml_backend_dev_t lm_ggml_backend_metal_reg_device_get(lm_ggml_backen
|
|
3907
4372
|
LM_GGML_UNUSED(index);
|
3908
4373
|
}
|
3909
4374
|
|
4375
|
+
static struct lm_ggml_backend_feature g_lm_ggml_backend_metal_features[] = {
|
4376
|
+
#if defined(LM_GGML_METAL_EMBED_LIBRARY)
|
4377
|
+
{ "EMBED_LIBRARY", "1" },
|
4378
|
+
#endif
|
4379
|
+
#if defined(LM_GGML_METAL_USE_BF16)
|
4380
|
+
{ "BF16", "1" },
|
4381
|
+
#endif
|
4382
|
+
{ nil, nil },
|
4383
|
+
};
|
4384
|
+
|
4385
|
+
static struct lm_ggml_backend_feature * lm_ggml_backend_metal_get_features(lm_ggml_backend_reg_t reg) {
|
4386
|
+
return g_lm_ggml_backend_metal_features;
|
4387
|
+
|
4388
|
+
LM_GGML_UNUSED(reg);
|
4389
|
+
}
|
4390
|
+
|
4391
|
+
static void * lm_ggml_backend_metal_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) {
|
4392
|
+
if (strcmp(name, "lm_ggml_backend_get_features") == 0) {
|
4393
|
+
return (void *)lm_ggml_backend_metal_get_features;
|
4394
|
+
}
|
4395
|
+
|
4396
|
+
return NULL;
|
4397
|
+
|
4398
|
+
LM_GGML_UNUSED(reg);
|
4399
|
+
}
|
3910
4400
|
static struct lm_ggml_backend_reg_i lm_ggml_backend_metal_reg_i = {
|
3911
4401
|
/* .get_name = */ lm_ggml_backend_metal_reg_get_name,
|
3912
4402
|
/* .device_count = */ lm_ggml_backend_metal_reg_device_count,
|
3913
4403
|
/* .device_get = */ lm_ggml_backend_metal_reg_device_get,
|
3914
|
-
/* .get_proc_address = */
|
4404
|
+
/* .get_proc_address = */ lm_ggml_backend_metal_get_proc_address,
|
3915
4405
|
};
|
3916
4406
|
|
3917
4407
|
lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) {
|
3918
4408
|
// TODO: make this thread-safe somehow?
|
3919
4409
|
{
|
3920
4410
|
g_lm_ggml_backend_metal_reg = (struct lm_ggml_backend_reg) {
|
3921
|
-
/* .
|
3922
|
-
/* .
|
4411
|
+
/* .api_version = */ LM_GGML_BACKEND_API_VERSION,
|
4412
|
+
/* .iface = */ lm_ggml_backend_metal_reg_i,
|
4413
|
+
/* .context = */ NULL,
|
3923
4414
|
};
|
3924
4415
|
|
3925
4416
|
g_lm_ggml_backend_metal_device = (struct lm_ggml_backend_device) {
|
@@ -3931,3 +4422,5 @@ lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) {
|
|
3931
4422
|
|
3932
4423
|
return &g_lm_ggml_backend_metal_reg;
|
3933
4424
|
}
|
4425
|
+
|
4426
|
+
LM_GGML_BACKEND_DL_IMPL(lm_ggml_backend_metal_reg)
|