cui-llama.rn 1.3.3 → 1.3.5
This diff shows the published contents of the two package versions as they appear in their respective public registries; it is provided for informational purposes only.
- package/android/src/main/CMakeLists.txt +5 -7
- package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
- package/android/src/main/jni.cpp +9 -9
- package/cpp/common.cpp +28 -44
- package/cpp/common.h +35 -14
- package/cpp/ggml-alloc.c +0 -1
- package/cpp/ggml-backend-impl.h +38 -20
- package/cpp/ggml-backend-reg.cpp +246 -92
- package/cpp/ggml-backend.h +1 -0
- package/cpp/ggml-common.h +42 -48
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +642 -223
- package/cpp/ggml-cpu-aarch64.h +2 -26
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +14122 -13971
- package/cpp/ggml-cpu.cpp +627 -715
- package/cpp/ggml-cpu.h +0 -17
- package/cpp/ggml-impl.h +22 -6
- package/cpp/ggml-metal.m +482 -24
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml.c +284 -178
- package/cpp/ggml.h +73 -25
- package/cpp/llama-grammar.cpp +15 -15
- package/cpp/llama-grammar.h +2 -5
- package/cpp/llama-sampling.cpp +35 -90
- package/cpp/llama-vocab.cpp +7 -2
- package/cpp/llama-vocab.h +1 -1
- package/cpp/llama.cpp +1782 -586
- package/cpp/llama.h +20 -19
- package/cpp/sampling.cpp +11 -16
- package/cpp/sgemm.cpp +265 -258
- package/cpp/sgemm.h +2 -2
- package/cpp/speculative.cpp +4 -0
- package/cpp/unicode.cpp +51 -51
- package/cpp/unicode.h +9 -10
- package/lib/commonjs/index.js +38 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +36 -0
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +2 -3
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +36 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +3 -3
- package/src/index.ts +46 -2
- package/cpp/amx/amx.cpp +0 -196
- package/cpp/amx/amx.h +0 -20
- package/cpp/amx/common.h +0 -101
- package/cpp/amx/mmq.cpp +0 -2524
- package/cpp/amx/mmq.h +0 -16
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
package/cpp/ggml-metal.m
CHANGED
@@ -175,6 +175,46 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -266,8 +306,11 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
     LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+    LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+    LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
     LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     LM_GGML_METAL_KERNEL_TYPE_PAD_F32,
+    LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
     LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
     LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -329,6 +372,8 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_SET_I32,
+    LM_GGML_METAL_KERNEL_TYPE_SET_F32,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -350,6 +395,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+    LM_GGML_METAL_KERNEL_TYPE_ARGMAX,
 
     LM_GGML_METAL_KERNEL_TYPE_COUNT
 };
@@ -464,6 +510,35 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 #endif
 
         NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
+        if (path_lib == nil) {
+            // Try to find the resource in the directory where the current binary located.
+            NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
+            NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
+            NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
+            if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                LM_GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
+                NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
+                if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
+                    // Optionally, if this is a symlink, try to resolve it.
+                    default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
+                    if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
+                        // It is a relative path, adding the binary directory as directory prefix.
+                        default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
+                    }
+                    if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                        // Link to the resource could not be resolved.
+                        default_metallib_path = nil;
+                    } else {
+                        LM_GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
+                    }
+                }
+            } else {
+                // The resource couldn't be found in the binary's directory.
+                default_metallib_path = nil;
+            }
+            path_lib = default_metallib_path;
+        }
+
         if (try_metallib && path_lib != nil) {
             // pre-compiled library found
             NSURL * libURL = [NSURL fileURLWithPath:path_lib];
@@ -699,6 +774,46 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -790,8 +905,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -853,6 +971,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -872,6 +992,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
@@ -989,6 +1110,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
         case LM_GGML_OP_REPEAT:
         case LM_GGML_OP_SCALE:
         case LM_GGML_OP_CLAMP:
+        case LM_GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case LM_GGML_OP_SQR:
         case LM_GGML_OP_SQRT:
@@ -1001,9 +1123,20 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
             return has_simdgroup_reduction;
         case LM_GGML_OP_RMS_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+        case LM_GGML_OP_ARGMAX:
         case LM_GGML_OP_NORM:
-        case LM_GGML_OP_ROPE:
             return true;
+        case LM_GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & LM_GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & LM_GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case LM_GGML_OP_IM2COL:
             return op->src[0]->type == LM_GGML_TYPE_F16;
         case LM_GGML_OP_POOL_1D:
@@ -1011,6 +1144,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
         case LM_GGML_OP_POOL_2D:
         case LM_GGML_OP_UPSCALE:
         case LM_GGML_OP_PAD:
+        case LM_GGML_OP_PAD_REFLECT_1D:
         case LM_GGML_OP_ARANGE:
         case LM_GGML_OP_TIMESTEP_EMBEDDING:
         case LM_GGML_OP_ARGSORT:
@@ -1068,6 +1202,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
                         return false;
                 };
             }
+        case LM_GGML_OP_SET:
+            {
+                switch (op->src[0]->type) {
+                    case LM_GGML_TYPE_F32:
+                    case LM_GGML_TYPE_I32:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         case LM_GGML_OP_DIAG_MASK_INF:
         case LM_GGML_OP_GET_ROWS:
             {
@@ -1928,30 +2072,180 @@ static void lm_ggml_metal_encode_node(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                int ne11_mm_min = 4;
+                const int ne11_mm_min = 4;
+
+                // first try to use small-batch mat-mv kernels
+                // these should be efficient for BS [2, ~8]
+                if (src1t == LM_GGML_TYPE_F32 && (ne00%256 == 0) &&
+                    (
+                      (
+                        (
+                          src0t == LM_GGML_TYPE_F16 || // TODO: helper function
+                          src0t == LM_GGML_TYPE_Q4_0 ||
+                          src0t == LM_GGML_TYPE_Q4_1 ||
+                          src0t == LM_GGML_TYPE_Q5_0 ||
+                          src0t == LM_GGML_TYPE_Q5_1 ||
+                          src0t == LM_GGML_TYPE_Q8_0 ||
+                          src0t == LM_GGML_TYPE_IQ4_NL ||
+                          false) && (ne11 >= 2 && ne11 <= 8)
+                      ) ||
+                      (
+                        (
+                          src0t == LM_GGML_TYPE_Q4_K ||
+                          src0t == LM_GGML_TYPE_Q5_K ||
+                          src0t == LM_GGML_TYPE_Q6_K ||
+                          false) && (ne11 >= 4 && ne11 <= 8)
+                      )
+                    )
+                   ) {
+                    // TODO: determine the optimal parameters based on grid utilization
+                    //       I still don't know why we should not always use the maximum available threads:
+                    //
+                    //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                    //
+                    //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+                    //       values and there can be some tail effects when nsg is high. need to confirm this
+                    //
+                    const int nsg   = 2;                 // num simdgroups per threadgroup
+                    const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+                    const int nypsg = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                    const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
+                    int       r1ptg = 4;                 // num src1 rows per threadgroup
+
+                    // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+                    switch (ne11) {
+                        case 2:
+                            r1ptg = 2; break;
+                        case 3:
+                        case 6:
+                            r1ptg = 3; break;
+                        case 4:
+                        case 7:
+                        case 8:
+                            r1ptg = 4; break;
+                        case 5:
+                            r1ptg = 5; break;
+                    };
 
-
-                // the numbers below are measured on M2 Ultra for 7B and 13B models
-                // these numbers do not translate to other devices or model sizes
-                // TODO: need to find a better approach
-                if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
-                    switch (src0t) {
-                        case LM_GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-                        case LM_GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-                        case LM_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-                        case LM_GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-                        case LM_GGML_TYPE_Q4_0:
-                        case LM_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-                        case LM_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-                        case LM_GGML_TYPE_Q5_0: // not tested yet
-                        case LM_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-                        case LM_GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-                        case LM_GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-                        default:                ne11_mm_min = 1;  break;
-                    }
-                }
-#endif
+                    id<MTLComputePipelineState> pipeline = nil;
 
+                    switch (src0->type) {
+                        case LM_GGML_TYPE_F16:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q4_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q4_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q5_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q5_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q8_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q4_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q5_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_Q6_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        case LM_GGML_TYPE_IQ4_NL:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+                                default: LM_GGML_ABORT("not implemented");
+                            } break;
+                        default: LM_GGML_ABORT("not implemented");
+                    }
+
+                    lm_ggml_metal_kargs_mul_mv_ext args = {
+                        /*.ne00  =*/ ne00,
+                        /*.ne01  =*/ ne01,
+                        /*.ne02  =*/ ne02,
+                        /*.nb00  =*/ nb00,
+                        /*.nb01  =*/ nb01,
+                        /*.nb02  =*/ nb02,
+                        /*.nb03  =*/ nb03,
+                        /*.ne10  =*/ ne10,
+                        /*.ne11  =*/ ne11,
+                        /*.ne12  =*/ ne12,
+                        /*.nb10  =*/ nb10,
+                        /*.nb11  =*/ nb11,
+                        /*.nb12  =*/ nb12,
+                        /*.nb13  =*/ nb13,
+                        /*.ne0   =*/ ne0,
+                        /*.ne1   =*/ ne1,
+                        /*.r2    =*/ r2,
+                        /*.r3    =*/ r3,
+                        /*.nsg   =*/ nsg,
+                        /*.nxpsg =*/ nxpsg,
+                        /*.r1ptg =*/ r1ptg,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                    //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
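
For reference, the new small-batch path above derives its dispatch geometry purely from ne11. The following standalone C sketch replays that arithmetic (the constants and the r1ptg table mirror the hunk above; the ne01 value is an arbitrary example, not taken from the diff):

    #include <stdio.h>

    // Replays the small-batch mat-vec dispatch math from the hunk above.
    // ne01 = number of src0 rows, ne11 = batch size in [2, 8].
    int main(void) {
        const int ne01 = 4096; // src0 rows (illustrative value)
        for (int ne11 = 2; ne11 <= 8; ne11++) {
            const int nsg   = 2;                  // simdgroups per threadgroup
            const int nxpsg = ne11 < 3 ? 16 : 8;  // threads along a row per simdgroup
            const int nypsg = 32 / nxpsg;         // src0 rows handled per simdgroup
            const int r0ptg = nypsg * nsg;        // src0 rows per threadgroup

            int r1ptg = 4;                        // src1 rows per threadgroup
            switch (ne11) {
                case 2:          r1ptg = 2; break;
                case 3: case 6:  r1ptg = 3; break;
                case 4: case 7:
                case 8:          r1ptg = 4; break;
                case 5:          r1ptg = 5; break;
            }

            const int tg_x = (ne01 + r0ptg - 1) / r0ptg; // threadgroups along src0 rows
            const int tg_y = (ne11 + r1ptg - 1) / r1ptg; // threadgroups along src1 rows
            printf("ne11=%d -> r0ptg=%d r1ptg=%d grid=%dx%d, 32x%d threads per threadgroup\n",
                   ne11, r0ptg, r1ptg, tg_x, tg_y, nsg);
        }
        return 0;
    }
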
@@ -2742,7 +3036,9 @@ static void lm_ggml_metal_encode_node(
             } break;
         case LM_GGML_OP_ROPE:
             {
-
+                // make sure we have one or more position id(ne10) per token(ne02)
+                LM_GGML_ASSERT(ne10 % ne02 == 0);
+                LM_GGML_ASSERT(ne10 >= ne02);
 
                 const int nth = MIN(1024, ne00);
 
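
For reference, a minimal C illustration of what the two new assertions accept (the token and position counts are arbitrary example values, not from the diff): one position id per token always passes, and any whole multiple of position ids per token passes as well.

    #include <assert.h>

    // Mirrors the ROPE checks above: ne02 = tokens, ne10 = position ids.
    int main(void) {
        const int ne02 = 7;  // tokens in the batch (example)

        int ne10 = 7;        // one position id per token
        assert(ne10 % ne02 == 0 && ne10 >= ne02);

        ne10 = 28;           // a whole multiple of ids per token also passes
        assert(ne10 % ne02 == 0 && ne10 >= ne02);

        return 0;
    }
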
@@ -2908,6 +3204,49 @@ static void lm_ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                 }
             } break;
+        case LM_GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
+                LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
+                LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16 || src0->type == LM_GGML_TYPE_F32);
+                LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+                LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+                const int32_t IC = src1->ne[1];
+                const int32_t IL = src1->ne[0];
+
+                const int32_t K  = src0->ne[0];
+
+                const int32_t OL = dst->ne[0];
+                const int32_t OC = dst->ne[1];
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0->type) {
+                    case LM_GGML_TYPE_F32: {
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+                    } break;
+                    case LM_GGML_TYPE_F16: {
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+                    } break;
+                    default: LM_GGML_ABORT("fatal error");
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
+                [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
+                [encoder setBytes:&K  length:sizeof( int32_t) atIndex:5];
+                [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
+                [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case LM_GGML_OP_UPSCALE:
            {
                LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
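
For context, the kernel is dispatched over OL x OC threadgroups, where OL is dst->ne[0]. A small C sketch of where that output length comes from, assuming the standard ggml transposed-convolution output-size formula with no padding and unit dilation (the sizes below are made-up examples, not from the diff):

    #include <stdio.h>

    // Output length of a 1-D transposed convolution with stride s0:
    // OL = (IL - 1)*s0 + K, assuming zero padding and unit dilation.
    int main(void) {
        const int IL = 10; // input length (src1->ne[0]), example value
        const int K  = 3;  // kernel width (src0->ne[0]), example value
        const int s0 = 2;  // stride (dst->op_params[0]),  example value

        const int OL = (IL - 1)*s0 + K; // -> 21 output positions
        printf("OL = %d\n", OL);
        return 0;
    }
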
@@ -2977,6 +3316,38 @@ static void lm_ggml_metal_encode_node(
 
                 const int nth = MIN(1024, ne0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case LM_GGML_OP_PAD_REFLECT_1D:
+            {
+                LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
+                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
+                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+
+                const int nth = MIN(1024, ne0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case LM_GGML_OP_ARANGE:
@@ -3508,6 +3879,68 @@ static void lm_ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case LM_GGML_OP_SET:
+            {
+                LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
+                LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0));
+
+                // src0 and dst as viewed during set
+                const size_t dst_nb0 = lm_ggml_element_size(src0);
+
+                const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
+                const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
+                const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
+                const size_t offset  = ((int32_t *) dst->op_params)[3];
+                const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    memcpy(((char *) dst->data), ((char *) src0->data), lm_ggml_nbytes(dst));
+                }
+
+                const int im0 = (ne10 == 0 ? 0 : ne10-1);
+                const int im1 = (ne11 == 0 ? 0 : ne11-1);
+                const int im2 = (ne12 == 0 ? 0 : ne12-1);
+                const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+                LM_GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= lm_ggml_nbytes(dst));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case LM_GGML_TYPE_F32:
+                        LM_GGML_ASSERT(nb10 == sizeof(float));
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
+                    case LM_GGML_TYPE_I32:
+                        LM_GGML_ASSERT(nb10 == sizeof(int32_t));
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
+                    default: LM_GGML_ABORT("fatal error");
+                }
+
+                lm_ggml_metal_kargs_set args = {
+                    /*.ne10    =*/ ne10,
+                    /*.ne11    =*/ ne11,
+                    /*.ne12    =*/ ne12,
+                    /*.nb10    =*/ nb10,
+                    /*.nb11    =*/ nb11,
+                    /*.nb12    =*/ nb12,
+                    /*.nb13    =*/ nb13,
+                    /*.nb1     =*/ dst_nb1,
+                    /*.nb2     =*/ dst_nb2,
+                    /*.nb3     =*/ dst_nb3,
+                    /*.offs    =*/ offset,
+                    /*.inplace =*/ inplace,
+                };
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case LM_GGML_OP_POOL_2D:
             {
                 LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
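
For reference, the bound checked by the LM_GGML_ASSERT in the SET case can be reproduced on the CPU; a small C sketch with made-up view strides and shapes (only the formula is taken from the hunk above):

    #include <assert.h>
    #include <stdio.h>

    // Recomputes the furthest byte offset the SET view addresses, as asserted above.
    // All shapes and strides are arbitrary example values for a 4x3 f32 view of a
    // 64-element f32 destination tensor.
    int main(void) {
        const size_t dst_nb0 = sizeof(float);    // element size (lm_ggml_element_size)
        const size_t dst_nb1 = 8*sizeof(float);  // view row stride   (op_params[0])
        const size_t dst_nb2 = 64*sizeof(float); // view plane stride (op_params[1])
        const size_t dst_nb3 = 64*sizeof(float); // view batch stride (op_params[2])
        const size_t offset  = 2*sizeof(float);  // view byte offset  (op_params[3])

        const long ne10 = 4, ne11 = 3, ne12 = 1, ne13 = 1;  // src1 (view) shape
        const size_t dst_nbytes = 64*sizeof(float);         // lm_ggml_nbytes(dst)

        const long im0 = ne10 ? ne10 - 1 : 0;
        const long im1 = ne11 ? ne11 - 1 : 0;
        const long im2 = ne12 ? ne12 - 1 : 0;
        const long im3 = ne13 ? ne13 - 1 : 0;

        const size_t last = offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3;
        assert(last <= dst_nbytes); // same condition the backend asserts before dispatch
        printf("furthest addressed offset: %zu of %zu bytes\n", last, dst_nbytes);
        return 0;
    }
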
@@ -3567,6 +4000,31 @@ static void lm_ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
+        case LM_GGML_OP_ARGMAX:
+            {
+                LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+                LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
+                LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
+
+                const int64_t nrows = lm_ggml_nrows(src0);
+
+                int nth = 32; // SIMD width
+                while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                    nth *= 2;
+                }
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
+                [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         default:
             {
                 LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));