whisper.rn 0.4.0-rc.5 → 0.4.0-rc.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +5 -5
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +7 -2
- package/android/src/main/jni.cpp +3 -2
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1497 -169
- package/cpp/ggml-metal.m +530 -53
- package/cpp/ggml-quants.c +2 -2
- package/cpp/ggml.c +264 -99
- package/cpp/ggml.h +21 -7
- package/cpp/rn-whisper.cpp +3 -0
- package/cpp/rn-whisper.h +3 -2
- package/ios/RNWhisperContext.mm +10 -6
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/index.d.ts +5 -0
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/index.ts +5 -0
- package/src/version.json +1 -1
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml-metal.m
CHANGED
|
@@ -66,9 +66,11 @@ struct wsp_ggml_metal_context {
|
|
|
66
66
|
WSP_GGML_METAL_DECL_KERNEL(div_row);
|
|
67
67
|
WSP_GGML_METAL_DECL_KERNEL(scale);
|
|
68
68
|
WSP_GGML_METAL_DECL_KERNEL(scale_4);
|
|
69
|
-
WSP_GGML_METAL_DECL_KERNEL(
|
|
69
|
+
WSP_GGML_METAL_DECL_KERNEL(tanh);
|
|
70
70
|
WSP_GGML_METAL_DECL_KERNEL(relu);
|
|
71
71
|
WSP_GGML_METAL_DECL_KERNEL(gelu);
|
|
72
|
+
WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
|
|
73
|
+
WSP_GGML_METAL_DECL_KERNEL(silu);
|
|
72
74
|
WSP_GGML_METAL_DECL_KERNEL(soft_max);
|
|
73
75
|
WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
|
|
74
76
|
WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
|
@@ -86,6 +88,7 @@ struct wsp_ggml_metal_context {
|
|
|
86
88
|
WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
|
87
89
|
WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
|
88
90
|
WSP_GGML_METAL_DECL_KERNEL(rms_norm);
|
|
91
|
+
WSP_GGML_METAL_DECL_KERNEL(group_norm);
|
|
89
92
|
WSP_GGML_METAL_DECL_KERNEL(norm);
|
|
90
93
|
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
|
91
94
|
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
|
@@ -102,6 +105,21 @@ struct wsp_ggml_metal_context {
|
|
|
102
105
|
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
|
103
106
|
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
|
104
107
|
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
|
108
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
|
109
|
+
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
|
110
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
|
111
|
+
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
|
112
|
+
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
|
113
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
|
114
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
|
115
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
|
116
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
|
117
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
|
118
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
|
119
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
|
120
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
|
121
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
|
122
|
+
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
|
105
123
|
WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
|
106
124
|
WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
|
107
125
|
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
|
@@ -130,8 +148,11 @@ struct wsp_ggml_metal_context {
|
|
|
130
148
|
WSP_GGML_METAL_DECL_KERNEL(rope_f16);
|
|
131
149
|
WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
|
|
132
150
|
WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
|
|
151
|
+
WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
|
|
152
|
+
WSP_GGML_METAL_DECL_KERNEL(pad_f32);
|
|
133
153
|
WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
|
134
154
|
WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
|
155
|
+
WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
|
135
156
|
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
|
136
157
|
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
|
137
158
|
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
|
@@ -140,6 +161,7 @@ struct wsp_ggml_metal_context {
|
|
|
140
161
|
//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
|
141
162
|
//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
|
142
163
|
WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
|
164
|
+
WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
|
143
165
|
WSP_GGML_METAL_DECL_KERNEL(concat);
|
|
144
166
|
WSP_GGML_METAL_DECL_KERNEL(sqr);
|
|
145
167
|
WSP_GGML_METAL_DECL_KERNEL(sum_rows);
|
|
@@ -318,9 +340,11 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
318
340
|
WSP_GGML_METAL_ADD_KERNEL(div_row);
|
|
319
341
|
WSP_GGML_METAL_ADD_KERNEL(scale);
|
|
320
342
|
WSP_GGML_METAL_ADD_KERNEL(scale_4);
|
|
321
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
343
|
+
WSP_GGML_METAL_ADD_KERNEL(tanh);
|
|
322
344
|
WSP_GGML_METAL_ADD_KERNEL(relu);
|
|
323
345
|
WSP_GGML_METAL_ADD_KERNEL(gelu);
|
|
346
|
+
WSP_GGML_METAL_ADD_KERNEL(gelu_quick);
|
|
347
|
+
WSP_GGML_METAL_ADD_KERNEL(silu);
|
|
324
348
|
WSP_GGML_METAL_ADD_KERNEL(soft_max);
|
|
325
349
|
WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
|
|
326
350
|
WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
|
@@ -338,6 +362,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
338
362
|
WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
|
339
363
|
WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
|
340
364
|
WSP_GGML_METAL_ADD_KERNEL(rms_norm);
|
|
365
|
+
WSP_GGML_METAL_ADD_KERNEL(group_norm);
|
|
341
366
|
WSP_GGML_METAL_ADD_KERNEL(norm);
|
|
342
367
|
WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
|
343
368
|
WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
|
@@ -354,6 +379,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
354
379
|
WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
|
355
380
|
WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
|
356
381
|
WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
|
382
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
|
383
|
+
//WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
|
384
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
|
385
|
+
//WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
|
|
386
|
+
//WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
|
|
387
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
|
|
388
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
|
|
389
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
|
|
390
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
|
|
391
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
|
|
392
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
|
|
393
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
|
|
394
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
|
|
395
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
|
396
|
+
WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
|
357
397
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
|
358
398
|
WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
|
359
399
|
WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
|
@@ -384,8 +424,11 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
384
424
|
WSP_GGML_METAL_ADD_KERNEL(rope_f16);
|
|
385
425
|
WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
|
|
386
426
|
WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
|
|
427
|
+
WSP_GGML_METAL_ADD_KERNEL(upscale_f32);
|
|
428
|
+
WSP_GGML_METAL_ADD_KERNEL(pad_f32);
|
|
387
429
|
WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
|
388
430
|
WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
|
431
|
+
WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
|
389
432
|
WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
|
390
433
|
WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
|
391
434
|
WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
|
@@ -394,6 +437,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
394
437
|
//WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
|
395
438
|
//WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
|
396
439
|
WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
|
440
|
+
WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
|
397
441
|
WSP_GGML_METAL_ADD_KERNEL(concat);
|
|
398
442
|
WSP_GGML_METAL_ADD_KERNEL(sqr);
|
|
399
443
|
WSP_GGML_METAL_ADD_KERNEL(sum_rows);
|
|
@@ -416,9 +460,11 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
|
416
460
|
WSP_GGML_METAL_DEL_KERNEL(div_row);
|
|
417
461
|
WSP_GGML_METAL_DEL_KERNEL(scale);
|
|
418
462
|
WSP_GGML_METAL_DEL_KERNEL(scale_4);
|
|
419
|
-
WSP_GGML_METAL_DEL_KERNEL(
|
|
463
|
+
WSP_GGML_METAL_DEL_KERNEL(tanh);
|
|
420
464
|
WSP_GGML_METAL_DEL_KERNEL(relu);
|
|
421
465
|
WSP_GGML_METAL_DEL_KERNEL(gelu);
|
|
466
|
+
WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
|
|
467
|
+
WSP_GGML_METAL_DEL_KERNEL(silu);
|
|
422
468
|
WSP_GGML_METAL_DEL_KERNEL(soft_max);
|
|
423
469
|
WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
|
|
424
470
|
WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
|
@@ -436,6 +482,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
|
436
482
|
WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
|
437
483
|
WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
|
438
484
|
WSP_GGML_METAL_DEL_KERNEL(rms_norm);
|
|
485
|
+
WSP_GGML_METAL_DEL_KERNEL(group_norm);
|
|
439
486
|
WSP_GGML_METAL_DEL_KERNEL(norm);
|
|
440
487
|
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
|
441
488
|
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
|
@@ -452,6 +499,21 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
|
452
499
|
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
|
453
500
|
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
|
454
501
|
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
|
502
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
|
503
|
+
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
|
504
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
|
505
|
+
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
|
506
|
+
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
|
507
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
|
508
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
|
509
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
|
510
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
|
511
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
|
512
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
|
513
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
|
514
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
|
515
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
|
516
|
+
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
|
455
517
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
|
456
518
|
WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
|
457
519
|
WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
|
@@ -482,8 +544,11 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
|
482
544
|
WSP_GGML_METAL_DEL_KERNEL(rope_f16);
|
|
483
545
|
WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
|
|
484
546
|
WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
|
|
547
|
+
WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
|
|
548
|
+
WSP_GGML_METAL_DEL_KERNEL(pad_f32);
|
|
485
549
|
WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
|
486
550
|
WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
|
551
|
+
WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
|
487
552
|
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
|
488
553
|
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
|
489
554
|
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
|
@@ -492,6 +557,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
|
492
557
|
//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
|
493
558
|
//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
|
494
559
|
WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
|
560
|
+
WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
|
495
561
|
WSP_GGML_METAL_DEL_KERNEL(concat);
|
|
496
562
|
WSP_GGML_METAL_DEL_KERNEL(sqr);
|
|
497
563
|
WSP_GGML_METAL_DEL_KERNEL(sum_rows);
|
|
@@ -783,9 +849,11 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
783
849
|
switch (op->op) {
|
|
784
850
|
case WSP_GGML_OP_UNARY:
|
|
785
851
|
switch (wsp_ggml_get_unary_op(op)) {
|
|
786
|
-
case
|
|
852
|
+
case WSP_GGML_UNARY_OP_TANH:
|
|
787
853
|
case WSP_GGML_UNARY_OP_RELU:
|
|
788
854
|
case WSP_GGML_UNARY_OP_GELU:
|
|
855
|
+
case WSP_GGML_UNARY_OP_GELU_QUICK:
|
|
856
|
+
case WSP_GGML_UNARY_OP_SILU:
|
|
789
857
|
return true;
|
|
790
858
|
default:
|
|
791
859
|
return false;
|
|
@@ -797,6 +865,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
797
865
|
case WSP_GGML_OP_PERMUTE:
|
|
798
866
|
case WSP_GGML_OP_CONCAT:
|
|
799
867
|
case WSP_GGML_OP_ADD:
|
|
868
|
+
case WSP_GGML_OP_ACC:
|
|
800
869
|
case WSP_GGML_OP_MUL:
|
|
801
870
|
case WSP_GGML_OP_DIV:
|
|
802
871
|
case WSP_GGML_OP_SCALE:
|
|
@@ -804,21 +873,50 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
804
873
|
case WSP_GGML_OP_SUM_ROWS:
|
|
805
874
|
case WSP_GGML_OP_SOFT_MAX:
|
|
806
875
|
case WSP_GGML_OP_RMS_NORM:
|
|
876
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
807
877
|
case WSP_GGML_OP_NORM:
|
|
808
878
|
case WSP_GGML_OP_ALIBI:
|
|
809
879
|
case WSP_GGML_OP_ROPE:
|
|
810
880
|
case WSP_GGML_OP_IM2COL:
|
|
881
|
+
case WSP_GGML_OP_UPSCALE:
|
|
882
|
+
case WSP_GGML_OP_PAD:
|
|
811
883
|
case WSP_GGML_OP_ARGSORT:
|
|
812
|
-
case
|
|
813
|
-
case WSP_GGML_OP_CPY:
|
|
814
|
-
case WSP_GGML_OP_CONT:
|
|
884
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
815
885
|
case WSP_GGML_OP_MUL_MAT:
|
|
816
886
|
case WSP_GGML_OP_MUL_MAT_ID:
|
|
817
887
|
return true;
|
|
888
|
+
case WSP_GGML_OP_CPY:
|
|
889
|
+
case WSP_GGML_OP_DUP:
|
|
890
|
+
case WSP_GGML_OP_CONT:
|
|
891
|
+
{
|
|
892
|
+
switch (op->src[0]->type) {
|
|
893
|
+
case WSP_GGML_TYPE_F32:
|
|
894
|
+
switch (op->type) {
|
|
895
|
+
case WSP_GGML_TYPE_F16:
|
|
896
|
+
case WSP_GGML_TYPE_F32:
|
|
897
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
898
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
899
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
900
|
+
return true;
|
|
901
|
+
default:
|
|
902
|
+
return false;
|
|
903
|
+
}
|
|
904
|
+
case WSP_GGML_TYPE_F16:
|
|
905
|
+
switch (op->type) {
|
|
906
|
+
case WSP_GGML_TYPE_F16:
|
|
907
|
+
case WSP_GGML_TYPE_F32:
|
|
908
|
+
return true;
|
|
909
|
+
default:
|
|
910
|
+
return false;
|
|
911
|
+
}
|
|
912
|
+
default:
|
|
913
|
+
return false;
|
|
914
|
+
};
|
|
915
|
+
}
|
|
818
916
|
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
819
917
|
case WSP_GGML_OP_GET_ROWS:
|
|
820
918
|
{
|
|
821
|
-
return op->ne[
|
|
919
|
+
return op->ne[3] == 1;
|
|
822
920
|
}
|
|
823
921
|
default:
|
|
824
922
|
return false;
|
|
@@ -894,7 +992,10 @@ void wsp_ggml_metal_graph_compute(
|
|
|
894
992
|
} break;
|
|
895
993
|
}
|
|
896
994
|
|
|
897
|
-
|
|
995
|
+
if (!wsp_ggml_metal_supports_op(dst)) {
|
|
996
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
|
|
997
|
+
WSP_GGML_ASSERT(!"unsupported op");
|
|
998
|
+
}
|
|
898
999
|
|
|
899
1000
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
900
1001
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
@@ -991,34 +1092,39 @@ void wsp_ggml_metal_graph_compute(
|
|
|
991
1092
|
case WSP_GGML_OP_MUL:
|
|
992
1093
|
case WSP_GGML_OP_DIV:
|
|
993
1094
|
{
|
|
994
|
-
|
|
995
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
1095
|
+
const size_t offs = 0;
|
|
996
1096
|
|
|
997
1097
|
bool bcast_row = false;
|
|
998
1098
|
|
|
999
1099
|
int64_t nb = ne00;
|
|
1000
1100
|
|
|
1001
|
-
|
|
1101
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1102
|
+
|
|
1103
|
+
if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
1104
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
1105
|
+
|
|
1002
1106
|
// src1 is a row
|
|
1003
1107
|
WSP_GGML_ASSERT(ne11 == 1);
|
|
1004
1108
|
|
|
1005
1109
|
nb = ne00 / 4;
|
|
1006
1110
|
switch (dst->op) {
|
|
1007
|
-
case WSP_GGML_OP_ADD:
|
|
1008
|
-
case WSP_GGML_OP_MUL:
|
|
1009
|
-
case WSP_GGML_OP_DIV:
|
|
1111
|
+
case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
|
1112
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
|
1113
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
|
1010
1114
|
default: WSP_GGML_ASSERT(false);
|
|
1011
1115
|
}
|
|
1012
1116
|
|
|
1013
1117
|
bcast_row = true;
|
|
1014
1118
|
} else {
|
|
1015
1119
|
switch (dst->op) {
|
|
1016
|
-
case WSP_GGML_OP_ADD:
|
|
1017
|
-
case WSP_GGML_OP_MUL:
|
|
1018
|
-
case WSP_GGML_OP_DIV:
|
|
1120
|
+
case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
|
1121
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
|
1122
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
|
1019
1123
|
default: WSP_GGML_ASSERT(false);
|
|
1020
1124
|
}
|
|
1021
1125
|
}
|
|
1126
|
+
|
|
1127
|
+
[encoder setComputePipelineState:pipeline];
|
|
1022
1128
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1023
1129
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1024
1130
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
@@ -1046,18 +1152,99 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1046
1152
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
1047
1153
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
1048
1154
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
1049
|
-
[encoder setBytes:&
|
|
1155
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
1156
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
|
1050
1157
|
|
|
1051
1158
|
if (bcast_row) {
|
|
1052
1159
|
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
1053
1160
|
|
|
1054
1161
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1055
1162
|
} else {
|
|
1056
|
-
const int nth = MIN(
|
|
1163
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
1057
1164
|
|
|
1058
1165
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1059
1166
|
}
|
|
1060
1167
|
} break;
|
|
1168
|
+
case WSP_GGML_OP_ACC:
|
|
1169
|
+
{
|
|
1170
|
+
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
|
|
1171
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1172
|
+
WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
|
|
1173
|
+
|
|
1174
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
1175
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
1176
|
+
|
|
1177
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
|
1178
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
|
1179
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
|
1180
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
|
1181
|
+
|
|
1182
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
|
1183
|
+
|
|
1184
|
+
if (!inplace) {
|
|
1185
|
+
// run a separete kernel to cpy src->dst
|
|
1186
|
+
// not sure how to avoid this
|
|
1187
|
+
// TODO: make a simpler cpy_bytes kernel
|
|
1188
|
+
|
|
1189
|
+
const int nth = MIN(1024, ne00);
|
|
1190
|
+
|
|
1191
|
+
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
|
1192
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1193
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1194
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1195
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1196
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1197
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1198
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1199
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1200
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1201
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1202
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1203
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1204
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1205
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1206
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1207
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1208
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1209
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1210
|
+
|
|
1211
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1212
|
+
}
|
|
1213
|
+
|
|
1214
|
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
|
1215
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1216
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1217
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1218
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1219
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1220
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1221
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
1222
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
1223
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
|
1224
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
|
1225
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
|
1226
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
1227
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
1228
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
1229
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
1230
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
1231
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
1232
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
1233
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
1234
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
1235
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
1236
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
1237
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
1238
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
1239
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
|
1240
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
|
1241
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
|
1242
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
1243
|
+
|
|
1244
|
+
const int nth = MIN(1024, ne0);
|
|
1245
|
+
|
|
1246
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1247
|
+
} break;
|
|
1061
1248
|
case WSP_GGML_OP_SCALE:
|
|
1062
1249
|
{
|
|
1063
1250
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
@@ -1081,16 +1268,15 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1081
1268
|
} break;
|
|
1082
1269
|
case WSP_GGML_OP_UNARY:
|
|
1083
1270
|
switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
|
|
1084
|
-
case
|
|
1271
|
+
case WSP_GGML_UNARY_OP_TANH:
|
|
1085
1272
|
{
|
|
1086
|
-
[encoder setComputePipelineState:ctx->
|
|
1273
|
+
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
|
1087
1274
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1088
1275
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1089
1276
|
|
|
1090
1277
|
const int64_t n = wsp_ggml_nelements(dst);
|
|
1091
|
-
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1092
1278
|
|
|
1093
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n
|
|
1279
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1094
1280
|
} break;
|
|
1095
1281
|
case WSP_GGML_UNARY_OP_RELU:
|
|
1096
1282
|
{
|
|
@@ -1111,6 +1297,28 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1111
1297
|
const int64_t n = wsp_ggml_nelements(dst);
|
|
1112
1298
|
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1113
1299
|
|
|
1300
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1301
|
+
} break;
|
|
1302
|
+
case WSP_GGML_UNARY_OP_GELU_QUICK:
|
|
1303
|
+
{
|
|
1304
|
+
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
|
1305
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1306
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1307
|
+
|
|
1308
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1309
|
+
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1310
|
+
|
|
1311
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1312
|
+
} break;
|
|
1313
|
+
case WSP_GGML_UNARY_OP_SILU:
|
|
1314
|
+
{
|
|
1315
|
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
|
1316
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1317
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1318
|
+
|
|
1319
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1320
|
+
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1321
|
+
|
|
1114
1322
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1115
1323
|
} break;
|
|
1116
1324
|
default:
|
|
@@ -1185,6 +1393,8 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1185
1393
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1186
1394
|
if (id_src1) {
|
|
1187
1395
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1396
|
+
} else {
|
|
1397
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
1188
1398
|
}
|
|
1189
1399
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1190
1400
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
@@ -1436,7 +1646,7 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1436
1646
|
else if (src0t == WSP_GGML_TYPE_Q6_K) {
|
|
1437
1647
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1438
1648
|
} else {
|
|
1439
|
-
int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1649
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1440
1650
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1441
1651
|
}
|
|
1442
1652
|
}
|
|
@@ -1448,7 +1658,7 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1448
1658
|
|
|
1449
1659
|
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
|
|
1450
1660
|
|
|
1451
|
-
const int n_as =
|
|
1661
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
|
1452
1662
|
|
|
1453
1663
|
// TODO: make this more general
|
|
1454
1664
|
WSP_GGML_ASSERT(n_as <= 8);
|
|
@@ -1480,14 +1690,22 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1480
1690
|
|
|
1481
1691
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1482
1692
|
// to the matrix-vector kernel
|
|
1483
|
-
int ne11_mm_min =
|
|
1693
|
+
int ne11_mm_min = 1;
|
|
1484
1694
|
|
|
1485
1695
|
const int idx = ((int32_t *) dst->op_params)[0];
|
|
1486
1696
|
|
|
1697
|
+
// batch size
|
|
1698
|
+
WSP_GGML_ASSERT(ne01 == ne11);
|
|
1699
|
+
|
|
1700
|
+
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
|
1701
|
+
|
|
1487
1702
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1488
1703
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1489
|
-
|
|
1490
|
-
|
|
1704
|
+
// !!!
|
|
1705
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
1706
|
+
// indirect matrix multiplication
|
|
1707
|
+
// !!!
|
|
1708
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
|
1491
1709
|
switch (src2->type) {
|
|
1492
1710
|
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
|
1493
1711
|
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
|
@@ -1506,19 +1724,22 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1506
1724
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1507
1725
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1508
1726
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1509
|
-
[encoder setBytes:&
|
|
1510
|
-
[encoder setBytes:&
|
|
1511
|
-
[encoder setBytes:&
|
|
1512
|
-
[encoder setBytes:&
|
|
1513
|
-
[encoder setBytes:&
|
|
1514
|
-
[encoder setBytes:&
|
|
1515
|
-
[encoder setBytes:&
|
|
1516
|
-
[encoder setBytes:&
|
|
1517
|
-
[encoder setBytes:&
|
|
1518
|
-
[encoder setBytes:&
|
|
1519
|
-
[encoder setBytes:&
|
|
1520
|
-
[encoder setBytes:&
|
|
1521
|
-
[encoder setBytes:&
|
|
1727
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1728
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1729
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
|
1730
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
1731
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
|
1732
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
|
1733
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
|
1734
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
|
1735
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
|
1736
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
|
1737
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
1738
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
|
1739
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
1740
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
|
1741
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
|
1742
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
|
1522
1743
|
// TODO: how to make this an array? read Metal docs
|
|
1523
1744
|
for (int j = 0; j < n_as; ++j) {
|
|
1524
1745
|
struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
|
|
@@ -1526,11 +1747,157 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1526
1747
|
size_t offs_src_cur = 0;
|
|
1527
1748
|
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1528
1749
|
|
|
1529
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:
|
|
1750
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
1530
1751
|
}
|
|
1531
1752
|
|
|
1532
1753
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1533
|
-
|
|
1754
|
+
|
|
1755
|
+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
|
1756
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1757
|
+
} else {
|
|
1758
|
+
int nth0 = 32;
|
|
1759
|
+
int nth1 = 1;
|
|
1760
|
+
int nrows = 1;
|
|
1761
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1762
|
+
|
|
1763
|
+
// use custom matrix x vector kernel
|
|
1764
|
+
switch (src2t) {
|
|
1765
|
+
case WSP_GGML_TYPE_F32:
|
|
1766
|
+
{
|
|
1767
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1768
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
|
1769
|
+
} break;
|
|
1770
|
+
case WSP_GGML_TYPE_F16:
|
|
1771
|
+
{
|
|
1772
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1773
|
+
nth0 = 32;
|
|
1774
|
+
nth1 = 1;
|
|
1775
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
|
1776
|
+
} break;
|
|
1777
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1778
|
+
{
|
|
1779
|
+
nth0 = 8;
|
|
1780
|
+
nth1 = 8;
|
|
1781
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
|
1782
|
+
} break;
|
|
1783
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1784
|
+
{
|
|
1785
|
+
nth0 = 8;
|
|
1786
|
+
nth1 = 8;
|
|
1787
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
|
1788
|
+
} break;
|
|
1789
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1790
|
+
{
|
|
1791
|
+
nth0 = 8;
|
|
1792
|
+
nth1 = 8;
|
|
1793
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
|
1794
|
+
} break;
|
|
1795
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1796
|
+
{
|
|
1797
|
+
nth0 = 8;
|
|
1798
|
+
nth1 = 8;
|
|
1799
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
|
1800
|
+
} break;
|
|
1801
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1802
|
+
{
|
|
1803
|
+
nth0 = 8;
|
|
1804
|
+
nth1 = 8;
|
|
1805
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
|
1806
|
+
} break;
|
|
1807
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
1808
|
+
{
|
|
1809
|
+
nth0 = 2;
|
|
1810
|
+
nth1 = 32;
|
|
1811
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
|
1812
|
+
} break;
|
|
1813
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
1814
|
+
{
|
|
1815
|
+
nth0 = 2;
|
|
1816
|
+
nth1 = 32;
|
|
1817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
|
1818
|
+
} break;
|
|
1819
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
1820
|
+
{
|
|
1821
|
+
nth0 = 4; //1;
|
|
1822
|
+
nth1 = 8; //32;
|
|
1823
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
|
1824
|
+
} break;
|
|
1825
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
1826
|
+
{
|
|
1827
|
+
nth0 = 2;
|
|
1828
|
+
nth1 = 32;
|
|
1829
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
|
1830
|
+
} break;
|
|
1831
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
1832
|
+
{
|
|
1833
|
+
nth0 = 2;
|
|
1834
|
+
nth1 = 32;
|
|
1835
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
|
1836
|
+
} break;
|
|
1837
|
+
default:
|
|
1838
|
+
{
|
|
1839
|
+
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1840
|
+
WSP_GGML_ASSERT(false && "not implemented");
|
|
1841
|
+
}
|
|
1842
|
+
};
|
|
1843
|
+
|
|
1844
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1845
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1846
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1847
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1848
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1849
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
1850
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
|
1851
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
|
1852
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
|
1853
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
|
1854
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1855
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
|
1856
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1857
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1858
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1859
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1860
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1861
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
1862
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
|
1863
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
1864
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
|
1865
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
|
1866
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
|
1867
|
+
// TODO: how to make this an array? read Metal docs
|
|
1868
|
+
for (int j = 0; j < n_as; ++j) {
|
|
1869
|
+
struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
|
|
1870
|
+
|
|
1871
|
+
size_t offs_src_cur = 0;
|
|
1872
|
+
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1873
|
+
|
|
1874
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1875
|
+
}
|
|
1876
|
+
|
|
1877
|
+
if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
|
|
1878
|
+
src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
|
|
1879
|
+
src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1880
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1881
|
+
}
|
|
1882
|
+
else if (src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1883
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1884
|
+
}
|
|
1885
|
+
else if (src2t == WSP_GGML_TYPE_Q3_K) {
|
|
1886
|
+
#ifdef WSP_GGML_QKK_64
|
|
1887
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1888
|
+
#else
|
|
1889
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1890
|
+
#endif
|
|
1891
|
+
}
|
|
1892
|
+
else if (src2t == WSP_GGML_TYPE_Q5_K) {
|
|
1893
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1894
|
+
}
|
|
1895
|
+
else if (src2t == WSP_GGML_TYPE_Q6_K) {
|
|
1896
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1897
|
+
} else {
|
|
1898
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1899
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1900
|
+
}
|
|
1534
1901
|
}
|
|
1535
1902
|
} break;
|
|
1536
1903
|
case WSP_GGML_OP_GET_ROWS:
|
|
@@ -1551,16 +1918,19 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1551
1918
|
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1552
1919
|
}
|
|
1553
1920
|
|
|
1554
|
-
[encoder setBuffer:id_src0
|
|
1555
|
-
[encoder setBuffer:id_src1
|
|
1556
|
-
[encoder setBuffer:id_dst
|
|
1921
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1922
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1923
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1557
1924
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1558
1925
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1559
|
-
[encoder setBytes:&
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
[encoder
|
|
1926
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
|
1927
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
|
1928
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
|
1929
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
|
1930
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
|
1931
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
|
1932
|
+
|
|
1933
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
1564
1934
|
} break;
|
|
1565
1935
|
case WSP_GGML_OP_RMS_NORM:
|
|
1566
1936
|
{
|
|
@@ -1587,6 +1957,38 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1587
1957
|
|
|
1588
1958
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1589
1959
|
} break;
|
|
1960
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
1961
|
+
{
|
|
1962
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1963
|
+
|
|
1964
|
+
//float eps;
|
|
1965
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
1966
|
+
|
|
1967
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
|
1968
|
+
|
|
1969
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
|
1970
|
+
|
|
1971
|
+
int nth = 32; // SIMD width
|
|
1972
|
+
|
|
1973
|
+
//while (nth < ne00/4 && nth < 1024) {
|
|
1974
|
+
// nth *= 2;
|
|
1975
|
+
//}
|
|
1976
|
+
|
|
1977
|
+
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
|
1978
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1979
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1980
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1981
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1982
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1983
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
|
1984
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
|
1985
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
|
1986
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
|
1987
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
|
1988
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1989
|
+
|
|
1990
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1991
|
+
} break;
|
|
1590
1992
|
case WSP_GGML_OP_NORM:
|
|
1591
1993
|
{
|
|
1592
1994
|
float eps;
|
|
@@ -1756,6 +2158,65 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1756
2158
|
|
|
1757
2159
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
1758
2160
|
} break;
|
|
2161
|
+
case WSP_GGML_OP_UPSCALE:
|
|
2162
|
+
{
|
|
2163
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2164
|
+
|
|
2165
|
+
const int sf = dst->op_params[0];
|
|
2166
|
+
|
|
2167
|
+
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
|
2168
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2169
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2170
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2171
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2172
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2173
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2174
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2175
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2176
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2177
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2178
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2179
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2180
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2181
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2182
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2183
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2184
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2185
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2186
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
|
2187
|
+
|
|
2188
|
+
const int nth = MIN(1024, ne0);
|
|
2189
|
+
|
|
2190
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2191
|
+
} break;
|
|
2192
|
+
case WSP_GGML_OP_PAD:
|
|
2193
|
+
{
|
|
2194
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2195
|
+
|
|
2196
|
+
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
|
2197
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2198
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2199
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2200
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2201
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2202
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2203
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2204
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2205
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2206
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2207
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2208
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2209
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2210
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2211
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2212
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2213
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2214
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2215
|
+
|
|
2216
|
+
const int nth = MIN(1024, ne0);
|
|
2217
|
+
|
|
2218
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2219
|
+
} break;
|
|
1759
2220
|
case WSP_GGML_OP_ARGSORT:
|
|
1760
2221
|
{
|
|
1761
2222
|
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
@@ -1777,6 +2238,22 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1777
2238
|
|
|
1778
2239
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
|
1779
2240
|
} break;
|
|
2241
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
2242
|
+
{
|
|
2243
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2244
|
+
|
|
2245
|
+
float slope;
|
|
2246
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
|
2247
|
+
|
|
2248
|
+
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
|
2249
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2250
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2251
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
|
2252
|
+
|
|
2253
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2254
|
+
|
|
2255
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2256
|
+
} break;
|
|
1780
2257
|
case WSP_GGML_OP_DUP:
|
|
1781
2258
|
case WSP_GGML_OP_CPY:
|
|
1782
2259
|
case WSP_GGML_OP_CONT:
|
|
@@ -1805,7 +2282,7 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1805
2282
|
{
|
|
1806
2283
|
switch (dstt) {
|
|
1807
2284
|
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
|
1808
|
-
case WSP_GGML_TYPE_F32:
|
|
2285
|
+
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
|
1809
2286
|
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1810
2287
|
};
|
|
1811
2288
|
} break;
|