cui-llama.rn 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48):
  1. package/android/src/main/CMakeLists.txt +5 -7
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
  3. package/android/src/main/jni.cpp +9 -9
  4. package/cpp/common.cpp +21 -40
  5. package/cpp/common.h +21 -12
  6. package/cpp/ggml-backend-impl.h +38 -20
  7. package/cpp/ggml-backend-reg.cpp +216 -87
  8. package/cpp/ggml-backend.h +1 -0
  9. package/cpp/ggml-common.h +42 -48
  10. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +591 -152
  11. package/cpp/ggml-cpu-aarch64.h +2 -26
  12. package/cpp/ggml-cpu-traits.cpp +36 -0
  13. package/cpp/ggml-cpu-traits.h +38 -0
  14. package/cpp/ggml-cpu.c +14122 -13971
  15. package/cpp/ggml-cpu.cpp +618 -715
  16. package/cpp/ggml-cpu.h +0 -17
  17. package/cpp/ggml-impl.h +6 -6
  18. package/cpp/ggml-metal.m +482 -24
  19. package/cpp/ggml-quants.c +0 -9
  20. package/cpp/ggml-threading.h +4 -2
  21. package/cpp/ggml.c +132 -43
  22. package/cpp/ggml.h +44 -13
  23. package/cpp/llama-sampling.cpp +35 -90
  24. package/cpp/llama-vocab.cpp +2 -1
  25. package/cpp/llama.cpp +737 -233
  26. package/cpp/llama.h +20 -16
  27. package/cpp/sampling.cpp +11 -16
  28. package/cpp/speculative.cpp +4 -0
  29. package/cpp/unicode.cpp +51 -51
  30. package/cpp/unicode.h +9 -10
  31. package/lib/commonjs/index.js +38 -1
  32. package/lib/commonjs/index.js.map +1 -1
  33. package/lib/module/index.js +36 -0
  34. package/lib/module/index.js.map +1 -1
  35. package/lib/typescript/NativeRNLlama.d.ts +2 -3
  36. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  37. package/lib/typescript/index.d.ts +36 -2
  38. package/lib/typescript/index.d.ts.map +1 -1
  39. package/package.json +1 -1
  40. package/src/NativeRNLlama.ts +3 -3
  41. package/src/index.ts +46 -2
  42. package/cpp/amx/amx.cpp +0 -196
  43. package/cpp/amx/amx.h +0 -20
  44. package/cpp/amx/common.h +0 -101
  45. package/cpp/amx/mmq.cpp +0 -2524
  46. package/cpp/amx/mmq.h +0 -16
  47. package/cpp/ggml-aarch64.c +0 -129
  48. package/cpp/ggml-aarch64.h +0 -19
package/cpp/ggml-metal.m CHANGED
@@ -175,6 +175,46 @@ enum lm_ggml_metal_kernel_type {
175
175
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
176
176
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
177
177
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
178
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
179
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
180
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
181
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
182
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
183
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
184
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
185
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
186
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
187
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
188
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
189
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
190
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
191
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
192
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
193
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
194
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
195
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
196
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
197
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
198
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
199
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
202
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
203
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
204
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
205
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
206
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
207
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
208
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
209
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
210
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
211
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
212
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
213
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
214
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
215
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
216
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
217
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
178
218
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
179
219
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
180
220
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -266,8 +306,11 @@ enum lm_ggml_metal_kernel_type {
266
306
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
267
307
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
268
308
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
309
+ LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
310
+ LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
269
311
  LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
270
312
  LM_GGML_METAL_KERNEL_TYPE_PAD_F32,
313
+ LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
271
314
  LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
272
315
  LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
273
316
  LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -329,6 +372,8 @@ enum lm_ggml_metal_kernel_type {
329
372
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
330
373
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
331
374
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
375
+ LM_GGML_METAL_KERNEL_TYPE_SET_I32,
376
+ LM_GGML_METAL_KERNEL_TYPE_SET_F32,
332
377
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
333
378
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
334
379
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -350,6 +395,7 @@ enum lm_ggml_metal_kernel_type {
350
395
  LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
351
396
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
352
397
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
398
+ LM_GGML_METAL_KERNEL_TYPE_ARGMAX,
353
399
 
354
400
  LM_GGML_METAL_KERNEL_TYPE_COUNT
355
401
  };
@@ -464,6 +510,35 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
464
510
  #endif
465
511
 
466
512
  NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
513
+ if (path_lib == nil) {
514
+ // Try to find the resource in the directory where the current binary located.
515
+ NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
516
+ NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
517
+ NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
518
+ if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
519
+ LM_GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
520
+ NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
521
+ if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
522
+ // Optionally, if this is a symlink, try to resolve it.
523
+ default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
524
+ if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
525
+ // It is a relative path, adding the binary directory as directory prefix.
526
+ default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
527
+ }
528
+ if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
529
+ // Link to the resource could not be resolved.
530
+ default_metallib_path = nil;
531
+ } else {
532
+ LM_GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
533
+ }
534
+ }
535
+ } else {
536
+ // The resource couldn't be found in the binary's directory.
537
+ default_metallib_path = nil;
538
+ }
539
+ path_lib = default_metallib_path;
540
+ }
541
+
467
542
  if (try_metallib && path_lib != nil) {
468
543
  // pre-compiled library found
469
544
  NSURL * libURL = [NSURL fileURLWithPath:path_lib];
@@ -699,6 +774,46 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
699
774
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
700
775
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
701
776
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
777
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
778
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
779
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
780
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
781
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
782
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
783
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
784
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
785
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
786
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
787
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
788
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
789
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
790
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
791
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
792
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
793
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
794
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
795
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
796
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
797
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
798
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
799
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
800
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
801
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
802
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
803
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
804
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
805
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
806
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
807
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
808
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
809
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
810
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
811
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
812
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
813
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
814
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
815
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
816
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
702
817
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
703
818
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
704
819
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -790,8 +905,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
790
905
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
791
906
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
792
907
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
908
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
909
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
793
910
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
794
911
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
912
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
795
913
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
796
914
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
797
915
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -853,6 +971,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
853
971
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
854
972
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
855
973
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
974
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
975
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
856
976
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
857
977
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
858
978
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -872,6 +992,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
872
992
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
873
993
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
874
994
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
995
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
875
996
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
876
997
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
877
998
  }
@@ -989,6 +1110,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
989
1110
  case LM_GGML_OP_REPEAT:
990
1111
  case LM_GGML_OP_SCALE:
991
1112
  case LM_GGML_OP_CLAMP:
1113
+ case LM_GGML_OP_CONV_TRANSPOSE_1D:
992
1114
  return true;
993
1115
  case LM_GGML_OP_SQR:
994
1116
  case LM_GGML_OP_SQRT:
@@ -1001,9 +1123,20 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1001
1123
  return has_simdgroup_reduction;
1002
1124
  case LM_GGML_OP_RMS_NORM:
1003
1125
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
1126
+ case LM_GGML_OP_ARGMAX:
1004
1127
  case LM_GGML_OP_NORM:
1005
- case LM_GGML_OP_ROPE:
1006
1128
  return true;
1129
+ case LM_GGML_OP_ROPE:
1130
+ {
1131
+ const int mode = ((const int32_t *) op->op_params)[2];
1132
+ if (mode & LM_GGML_ROPE_TYPE_MROPE) {
1133
+ return false;
1134
+ }
1135
+ if (mode & LM_GGML_ROPE_TYPE_VISION) {
1136
+ return false;
1137
+ }
1138
+ return true;
1139
+ }
1007
1140
  case LM_GGML_OP_IM2COL:
1008
1141
  return op->src[0]->type == LM_GGML_TYPE_F16;
1009
1142
  case LM_GGML_OP_POOL_1D:
@@ -1011,6 +1144,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1011
1144
  case LM_GGML_OP_POOL_2D:
1012
1145
  case LM_GGML_OP_UPSCALE:
1013
1146
  case LM_GGML_OP_PAD:
1147
+ case LM_GGML_OP_PAD_REFLECT_1D:
1014
1148
  case LM_GGML_OP_ARANGE:
1015
1149
  case LM_GGML_OP_TIMESTEP_EMBEDDING:
1016
1150
  case LM_GGML_OP_ARGSORT:
@@ -1068,6 +1202,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1068
1202
  return false;
1069
1203
  };
1070
1204
  }
1205
+ case LM_GGML_OP_SET:
1206
+ {
1207
+ switch (op->src[0]->type) {
1208
+ case LM_GGML_TYPE_F32:
1209
+ case LM_GGML_TYPE_I32:
1210
+ return true;
1211
+ default:
1212
+ return false;
1213
+ };
1214
+ }
1071
1215
  case LM_GGML_OP_DIAG_MASK_INF:
1072
1216
  case LM_GGML_OP_GET_ROWS:
1073
1217
  {
@@ -1928,30 +2072,180 @@ static void lm_ggml_metal_encode_node(
1928
2072
 
1929
2073
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1930
2074
  // to the matrix-vector kernel
1931
- int ne11_mm_min = 4;
2075
+ const int ne11_mm_min = 4;
2076
+
2077
+ // first try to use small-batch mat-mv kernels
2078
+ // these should be efficient for BS [2, ~8]
2079
+ if (src1t == LM_GGML_TYPE_F32 && (ne00%256 == 0) &&
2080
+ (
2081
+ (
2082
+ (
2083
+ src0t == LM_GGML_TYPE_F16 || // TODO: helper function
2084
+ src0t == LM_GGML_TYPE_Q4_0 ||
2085
+ src0t == LM_GGML_TYPE_Q4_1 ||
2086
+ src0t == LM_GGML_TYPE_Q5_0 ||
2087
+ src0t == LM_GGML_TYPE_Q5_1 ||
2088
+ src0t == LM_GGML_TYPE_Q8_0 ||
2089
+ src0t == LM_GGML_TYPE_IQ4_NL ||
2090
+ false) && (ne11 >= 2 && ne11 <= 8)
2091
+ ) ||
2092
+ (
2093
+ (
2094
+ src0t == LM_GGML_TYPE_Q4_K ||
2095
+ src0t == LM_GGML_TYPE_Q5_K ||
2096
+ src0t == LM_GGML_TYPE_Q6_K ||
2097
+ false) && (ne11 >= 4 && ne11 <= 8)
2098
+ )
2099
+ )
2100
+ ) {
2101
+ // TODO: determine the optimal parameters based on grid utilization
2102
+ // I still don't know why we should not always use the maximum available threads:
2103
+ //
2104
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
2105
+ //
2106
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
2107
+ // values and there can be some tail effects when nsg is high. need to confirm this
2108
+ //
2109
+ const int nsg = 2; // num simdgroups per threadgroup
2110
+ const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
2111
+ const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
2112
+ const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
2113
+ int r1ptg = 4; // num src1 rows per threadgroup
2114
+
2115
+ // note: not sure how optimal are those across all different hardware. there might be someting cleverer
2116
+ switch (ne11) {
2117
+ case 2:
2118
+ r1ptg = 2; break;
2119
+ case 3:
2120
+ case 6:
2121
+ r1ptg = 3; break;
2122
+ case 4:
2123
+ case 7:
2124
+ case 8:
2125
+ r1ptg = 4; break;
2126
+ case 5:
2127
+ r1ptg = 5; break;
2128
+ };
1932
2129
 
1933
- #if 0
1934
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1935
- // these numbers do not translate to other devices or model sizes
1936
- // TODO: need to find a better approach
1937
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1938
- switch (src0t) {
1939
- case LM_GGML_TYPE_F16: ne11_mm_min = 2; break;
1940
- case LM_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1941
- case LM_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1942
- case LM_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1943
- case LM_GGML_TYPE_Q4_0:
1944
- case LM_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1945
- case LM_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1946
- case LM_GGML_TYPE_Q5_0: // not tested yet
1947
- case LM_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1948
- case LM_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1949
- case LM_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1950
- default: ne11_mm_min = 1; break;
1951
- }
1952
- }
1953
- #endif
2130
+ id<MTLComputePipelineState> pipeline = nil;
1954
2131
 
2132
+ switch (src0->type) {
2133
+ case LM_GGML_TYPE_F16:
2134
+ switch (r1ptg) {
2135
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
2136
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
2137
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
2138
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
2139
+ default: LM_GGML_ABORT("not implemented");
2140
+ } break;
2141
+ case LM_GGML_TYPE_Q4_0:
2142
+ switch (r1ptg) {
2143
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
2144
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
2145
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
2146
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
2147
+ default: LM_GGML_ABORT("not implemented");
2148
+ } break;
2149
+ case LM_GGML_TYPE_Q4_1:
2150
+ switch (r1ptg) {
2151
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
2152
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
2153
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
2154
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
2155
+ default: LM_GGML_ABORT("not implemented");
2156
+ } break;
2157
+ case LM_GGML_TYPE_Q5_0:
2158
+ switch (r1ptg) {
2159
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
2160
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
2161
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
2162
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
2163
+ default: LM_GGML_ABORT("not implemented");
2164
+ } break;
2165
+ case LM_GGML_TYPE_Q5_1:
2166
+ switch (r1ptg) {
2167
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
2168
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
2169
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
2170
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
2171
+ default: LM_GGML_ABORT("not implemented");
2172
+ } break;
2173
+ case LM_GGML_TYPE_Q8_0:
2174
+ switch (r1ptg) {
2175
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
2176
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
2177
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
2178
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
2179
+ default: LM_GGML_ABORT("not implemented");
2180
+ } break;
2181
+ case LM_GGML_TYPE_Q4_K:
2182
+ switch (r1ptg) {
2183
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
2184
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
2185
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
2186
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
2187
+ default: LM_GGML_ABORT("not implemented");
2188
+ } break;
2189
+ case LM_GGML_TYPE_Q5_K:
2190
+ switch (r1ptg) {
2191
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
2192
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
2193
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
2194
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
2195
+ default: LM_GGML_ABORT("not implemented");
2196
+ } break;
2197
+ case LM_GGML_TYPE_Q6_K:
2198
+ switch (r1ptg) {
2199
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
2200
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
2201
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
2202
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
2203
+ default: LM_GGML_ABORT("not implemented");
2204
+ } break;
2205
+ case LM_GGML_TYPE_IQ4_NL:
2206
+ switch (r1ptg) {
2207
+ case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
2208
+ case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
2209
+ case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
2210
+ case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
2211
+ default: LM_GGML_ABORT("not implemented");
2212
+ } break;
2213
+ default: LM_GGML_ABORT("not implemented");
2214
+ }
2215
+
2216
+ lm_ggml_metal_kargs_mul_mv_ext args = {
2217
+ /*.ne00 =*/ ne00,
2218
+ /*.ne01 =*/ ne01,
2219
+ /*.ne02 =*/ ne02,
2220
+ /*.nb00 =*/ nb00,
2221
+ /*.nb01 =*/ nb01,
2222
+ /*.nb02 =*/ nb02,
2223
+ /*.nb03 =*/ nb03,
2224
+ /*.ne10 =*/ ne10,
2225
+ /*.ne11 =*/ ne11,
2226
+ /*.ne12 =*/ ne12,
2227
+ /*.nb10 =*/ nb10,
2228
+ /*.nb11 =*/ nb11,
2229
+ /*.nb12 =*/ nb12,
2230
+ /*.nb13 =*/ nb13,
2231
+ /*.ne0 =*/ ne0,
2232
+ /*.ne1 =*/ ne1,
2233
+ /*.r2 =*/ r2,
2234
+ /*.r3 =*/ r3,
2235
+ /*.nsg =*/ nsg,
2236
+ /*.nxpsg =*/ nxpsg,
2237
+ /*.r1ptg =*/ r1ptg,
2238
+ };
2239
+
2240
+ [encoder setComputePipelineState:pipeline];
2241
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2242
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2243
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2244
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2245
+
2246
+ //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2247
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2248
+ } else
1955
2249
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1956
2250
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1957
2251
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
@@ -2742,7 +3036,9 @@ static void lm_ggml_metal_encode_node(
2742
3036
  } break;
2743
3037
  case LM_GGML_OP_ROPE:
2744
3038
  {
2745
- LM_GGML_ASSERT(ne10 == ne02);
3039
+ // make sure we have one or more position id(ne10) per token(ne02)
3040
+ LM_GGML_ASSERT(ne10 % ne02 == 0);
3041
+ LM_GGML_ASSERT(ne10 >= ne02);
2746
3042
 
2747
3043
  const int nth = MIN(1024, ne00);
2748
3044
 
@@ -2908,6 +3204,49 @@ static void lm_ggml_metal_encode_node(
2908
3204
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2909
3205
  }
2910
3206
  } break;
3207
+ case LM_GGML_OP_CONV_TRANSPOSE_1D:
3208
+ {
3209
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
3210
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
3211
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16 || src0->type == LM_GGML_TYPE_F32);
3212
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
3213
+ LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
3214
+
3215
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
3216
+
3217
+ const int32_t IC = src1->ne[1];
3218
+ const int32_t IL = src1->ne[0];
3219
+
3220
+ const int32_t K = src0->ne[0];
3221
+
3222
+ const int32_t OL = dst->ne[0];
3223
+ const int32_t OC = dst->ne[1];
3224
+
3225
+ id<MTLComputePipelineState> pipeline;
3226
+
3227
+ switch (src0->type) {
3228
+ case LM_GGML_TYPE_F32: {
3229
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
3230
+ } break;
3231
+ case LM_GGML_TYPE_F16: {
3232
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
3233
+ } break;
3234
+ default: LM_GGML_ABORT("fatal error");
3235
+ };
3236
+
3237
+ [encoder setComputePipelineState:pipeline];
3238
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3239
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3240
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3241
+ [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
3242
+ [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
3243
+ [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
3244
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
3245
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
3246
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
3247
+
3248
+ [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3249
+ } break;
2911
3250
  case LM_GGML_OP_UPSCALE:
2912
3251
  {
2913
3252
  LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
@@ -2977,6 +3316,38 @@ static void lm_ggml_metal_encode_node(
2977
3316
 
2978
3317
  const int nth = MIN(1024, ne0);
2979
3318
 
3319
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3320
+ } break;
3321
case LM_GGML_OP_PAD_REFLECT_1D:
    {
        // Reflection padding along dim 0 of an F32 tensor. The two padding amounts
        // are read from op_params[0] and op_params[1] and forwarded to the kernel.
        LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);

        const int32_t * pad_params = (const int32_t *)(dst->op_params);
        const int32_t pad0 = pad_params[0];
        const int32_t pad1 = pad_params[1];

        id<MTLComputePipelineState> pipeline =
            ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;

        [encoder setComputePipelineState:pipeline];

        // buffers
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

        // source extents
        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];

        // destination row length
        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];

        // source strides (bytes)
        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];

        // destination strides (bytes)
        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];

        // padding amounts
        [encoder setBytes:&pad0 length:sizeof(pad0) atIndex:15];
        [encoder setBytes:&pad1 length:sizeof(pad1) atIndex:16];

        // one threadgroup per destination row, at most 1024 threads across the row
        const int row_threads = MIN(1024, ne0);

        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3)
                threadsPerThreadgroup:MTLSizeMake(row_threads, 1, 1)];
    } break;
2982
3353
  case LM_GGML_OP_ARANGE:
@@ -3508,6 +3879,68 @@ static void lm_ggml_metal_encode_node(
3508
3879
 
3509
3880
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3510
3881
  } break;
3882
case LM_GGML_OP_SET:
    {
        // Writes src1 into a strided view of dst (dst has src0's shape). The view's
        // byte strides nb1..nb3, its byte offset and the in-place flag are read from
        // dst->op_params[0..4]. When not in place, dst is first filled from src0.
        LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
        LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0));

        // src0 and dst as viewed during set
        const size_t dst_nb0 = lm_ggml_element_size(src0);

        const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
        const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
        const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
        const size_t offset  = ((int32_t *) dst->op_params)[3];
        const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];

        if (!inplace) {
            // NOTE(review): this is a host-side memcpy performed while the command
            // buffer is still being encoded, i.e. before any previously encoded GPU
            // work has executed. If src0 is produced by an earlier node of the same
            // graph (or dst's memory is still read by one), this may copy/clobber
            // stale data. Consider doing this copy on the GPU instead - confirm.
            memcpy(((char *) dst->data), ((char *) src0->data), lm_ggml_nbytes(dst));
        }

        // index of the last src1 element along each dimension (0 for an empty dim);
        // kept as int64_t so large extents are not truncated
        const int64_t im0 = (ne10 == 0 ? 0 : ne10-1);
        const int64_t im1 = (ne11 == 0 ? 0 : ne11-1);
        const int64_t im2 = (ne12 == 0 ? 0 : ne12-1);
        const int64_t im3 = (ne13 == 0 ? 0 : ne13-1);

        // the last element written through the view must stay inside dst
        LM_GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= lm_ggml_nbytes(dst));

        id<MTLComputePipelineState> pipeline = nil;

        switch (src0t) {
            case LM_GGML_TYPE_F32:
                LM_GGML_ASSERT(nb10 == sizeof(float));
                pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
            case LM_GGML_TYPE_I32:
                LM_GGML_ASSERT(nb10 == sizeof(int32_t));
                pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
            default: LM_GGML_ABORT("fatal error");
        }

        lm_ggml_metal_kargs_set args = {
            /*.ne10    =*/ ne10,
            /*.ne11    =*/ ne11,
            /*.ne12    =*/ ne12,
            /*.nb10    =*/ nb10,
            /*.nb11    =*/ nb11,
            /*.nb12    =*/ nb12,
            /*.nb13    =*/ nb13,
            /*.nb1     =*/ dst_nb1,
            /*.nb2     =*/ dst_nb2,
            /*.nb3     =*/ dst_nb3,
            /*.offs    =*/ offset,
            /*.inplace =*/ inplace,
        };

        // nothing to write from src1: skip the dispatch instead of launching
        // threadgroups with zero threads (nth would be 0 when ne10 == 0)
        if (lm_ggml_nelements(src1) == 0) {
            break;
        }

        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);

        [encoder setComputePipelineState:pipeline];
        [encoder setBytes:&args length:sizeof(args) atIndex:0];
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

        // one threadgroup per src1 row; threads span the row
        [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
    } break;
3511
3944
  case LM_GGML_OP_POOL_2D:
3512
3945
  {
3513
3946
  LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
@@ -3567,6 +4000,31 @@ static void lm_ggml_metal_encode_node(
3567
4000
 
3568
4001
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3569
4002
  } break;
4003
case LM_GGML_OP_ARGMAX:
    {
        // Dispatches the ARGMAX kernel over an F32 src0 whose rows are contiguous,
        // with one threadgroup per row.
        LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
        LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
        LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));

        const int64_t n_rows = lm_ggml_nrows(src0);

        // Start from one SIMD width (32 threads) and keep doubling while the row
        // is longer than the threadgroup and the total thread count across all
        // rows stays below 256.
        int n_threads = 32;
        while (n_threads < ne00 && n_threads*n_rows < 256) {
            n_threads <<= 1;
        }

        id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;

        [encoder setComputePipelineState:pipeline];

        // buffers and scalar arguments
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];

        // two 32-entry threadgroup scratch arrays (float values, int32 indices);
        // presumably per-SIMD-group partial results - confirm against the kernel
        [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
        [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];

        [encoder dispatchThreadgroups:MTLSizeMake(n_rows, 1, 1)
                threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
    } break;
3570
4028
  default:
3571
4029
  {
3572
4030
  LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));