whisper.rn 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/android/src/main/jni.cpp +12 -3
  4. package/cpp/ggml-alloc.c +292 -130
  5. package/cpp/ggml-backend-impl.h +4 -4
  6. package/cpp/ggml-backend-reg.cpp +13 -5
  7. package/cpp/ggml-backend.cpp +207 -17
  8. package/cpp/ggml-backend.h +19 -1
  9. package/cpp/ggml-cpu/amx/amx.cpp +5 -2
  10. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  11. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  12. package/cpp/ggml-cpu/common.h +14 -0
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
  14. package/cpp/ggml-cpu/ggml-cpu.c +65 -44
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  16. package/cpp/ggml-cpu/ops.cpp +542 -775
  17. package/cpp/ggml-cpu/ops.h +2 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  19. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  20. package/cpp/ggml-cpu/unary-ops.h +5 -0
  21. package/cpp/ggml-cpu/vec.cpp +227 -20
  22. package/cpp/ggml-cpu/vec.h +407 -56
  23. package/cpp/ggml-cpu.h +1 -1
  24. package/cpp/ggml-impl.h +94 -12
  25. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  26. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  27. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  28. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  29. package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
  30. package/cpp/ggml-metal/ggml-metal-device.h +244 -0
  31. package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
  32. package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
  33. package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
  34. package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
  35. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  36. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  37. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +40 -40
  39. package/cpp/ggml-metal.h +1 -6
  40. package/cpp/ggml-quants.c +1 -0
  41. package/cpp/ggml.c +341 -15
  42. package/cpp/ggml.h +150 -5
  43. package/cpp/jsi/RNWhisperJSI.cpp +9 -2
  44. package/cpp/jsi/ThreadPool.h +3 -3
  45. package/cpp/rn-whisper.h +1 -0
  46. package/cpp/whisper.cpp +89 -72
  47. package/cpp/whisper.h +1 -0
  48. package/ios/CMakeLists.txt +6 -1
  49. package/ios/RNWhisperContext.mm +3 -1
  50. package/ios/RNWhisperVadContext.mm +14 -13
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  61. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  70. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  72. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  74. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  80. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  81. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  82. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  83. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  86. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  87. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  92. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  93. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  94. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  95. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  96. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  97. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  98. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  99. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  101. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  102. package/lib/commonjs/version.json +1 -1
  103. package/lib/module/NativeRNWhisper.js.map +1 -1
  104. package/lib/module/version.json +1 -1
  105. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  106. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  107. package/package.json +1 -1
  108. package/src/NativeRNWhisper.ts +2 -0
  109. package/src/version.json +1 -1
  110. package/whisper-rn.podspec +8 -9
  111. package/cpp/ggml-metal.m +0 -6779
  112. package/cpp/ggml-whisper-sim.metallib +0 -0
  113. package/cpp/ggml-whisper.metallib +0 -0
@@ -0,0 +1,3583 @@
1
+ #include "ggml-metal-ops.h"
2
+
3
+ #include "ggml.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-backend-impl.h"
6
+
7
+ #include "ggml-metal-impl.h"
8
+ #include "ggml-metal-common.h"
9
+ #include "ggml-metal-device.h"
10
+
11
+ #include <cassert>
12
+ #include <algorithm>
13
+
14
+ static wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const wsp_ggml_tensor * t) {
15
+ if (!t) {
16
+ return { nullptr, 0 };
17
+ }
18
+
19
+ wsp_ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
20
+
21
+ wsp_ggml_metal_buffer_t ctx = (wsp_ggml_metal_buffer_t) buffer->context;
22
+
23
+ return wsp_ggml_metal_buffer_get_id(ctx, t);
24
+ }
25
+
26
+ struct wsp_ggml_metal_op {
27
+ wsp_ggml_metal_op(
28
+ wsp_ggml_metal_device_t dev,
29
+ wsp_ggml_metal_cmd_buf_t cmd_buf,
30
+ wsp_ggml_cgraph * gf,
31
+ int idx_start,
32
+ int idx_end,
33
+ bool use_fusion,
34
+ bool use_concurrency,
35
+ bool use_capture,
36
+ int debug_graph,
37
+ int debug_fusion) {
38
+ this->dev = dev;
39
+ this->lib = wsp_ggml_metal_device_get_library(dev);
40
+ this->enc = wsp_ggml_metal_encoder_init(cmd_buf, use_concurrency);
41
+ this->mem_ranges = wsp_ggml_mem_ranges_init(debug_graph);
42
+ this->idx_start = idx_start;
43
+ this->idx_end = idx_end;
44
+ this->use_fusion = use_fusion;
45
+ this->use_concurrency = use_concurrency;
46
+ this->use_capture = use_capture;
47
+ this->debug_graph = debug_graph;
48
+ this->debug_fusion = debug_fusion;
49
+ this->gf = gf;
50
+
51
+ idxs.reserve(gf->n_nodes);
52
+
53
+ // filter empty nodes
54
+ // TODO: this can be removed when the allocator starts filtering them earlier
55
+ // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
56
+ for (int i = idx_start; i < idx_end; i++) {
57
+ if (!wsp_ggml_op_is_empty(gf->nodes[i]->op) && !wsp_ggml_is_empty(gf->nodes[i])) {
58
+ idxs.push_back(i);
59
+ }
60
+ }
61
+ }
62
+
63
+ ~wsp_ggml_metal_op() {
64
+ wsp_ggml_metal_encoder_end_encoding(this->enc);
65
+ wsp_ggml_metal_encoder_free(this->enc);
66
+ wsp_ggml_mem_ranges_free(this->mem_ranges);
67
+ }
68
+
69
+ int n_nodes() const {
70
+ return idxs.size();
71
+ }
72
+
73
+ wsp_ggml_tensor * node(int i) const {
74
+ assert(i >= 0 && i < (int) idxs.size());
75
+ return wsp_ggml_graph_node(gf, idxs[i]);
76
+ }
77
+
78
+ bool can_fuse(int i0, const wsp_ggml_op * ops, int n_ops) const {
79
+ assert(use_fusion);
80
+ assert(i0 >= 0 && i0 < n_nodes());
81
+
82
+ if (i0 + n_ops > n_nodes()) {
83
+ return false;
84
+ }
85
+
86
+ return wsp_ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
87
+ }
88
+
89
+ wsp_ggml_metal_device_t dev;
90
+ wsp_ggml_metal_library_t lib;
91
+ wsp_ggml_metal_encoder_t enc;
92
+ wsp_ggml_mem_ranges_t mem_ranges;
93
+
94
+ bool use_fusion;
95
+ bool use_concurrency;
96
+ bool use_capture;
97
+
98
+ int debug_graph;
99
+ int debug_fusion;
100
+
101
+ private:
102
+ wsp_ggml_cgraph * gf;
103
+
104
+ int idx_start;
105
+ int idx_end;
106
+
107
+ // non-empty node indices
108
+ std::vector<int> idxs;
109
+ };
110
+
111
+ wsp_ggml_metal_op_t wsp_ggml_metal_op_init(
112
+ wsp_ggml_metal_device_t dev,
113
+ wsp_ggml_metal_cmd_buf_t cmd_buf,
114
+ wsp_ggml_cgraph * gf,
115
+ int idx_start,
116
+ int idx_end,
117
+ bool use_fusion,
118
+ bool use_concurrency,
119
+ bool use_capture,
120
+ int debug_graph,
121
+ int debug_fusion) {
122
+ wsp_ggml_metal_op_t res = new wsp_ggml_metal_op(
123
+ dev,
124
+ cmd_buf,
125
+ gf,
126
+ idx_start,
127
+ idx_end,
128
+ use_fusion,
129
+ use_concurrency,
130
+ use_capture,
131
+ debug_graph,
132
+ debug_fusion);
133
+
134
+ return res;
135
+ }
136
+
137
+ void wsp_ggml_metal_op_free(wsp_ggml_metal_op_t ctx) {
138
+ delete ctx;
139
+ }
140
+
141
+ int wsp_ggml_metal_op_n_nodes(wsp_ggml_metal_op_t ctx) {
142
+ return ctx->n_nodes();
143
+ }
144
+
145
+ static bool wsp_ggml_metal_op_concurrency_reset(wsp_ggml_metal_op_t ctx) {
146
+ if (!ctx->mem_ranges) {
147
+ return true;
148
+ }
149
+
150
+ wsp_ggml_metal_encoder_memory_barrier(ctx->enc);
151
+
152
+ wsp_ggml_mem_ranges_reset(ctx->mem_ranges);
153
+
154
+ return true;
155
+ }
156
+
157
+ static bool wsp_ggml_metal_op_concurrency_check(wsp_ggml_metal_op_t ctx, const wsp_ggml_tensor * node) {
158
+ if (!ctx->mem_ranges) {
159
+ return false;
160
+ }
161
+
162
+ return wsp_ggml_mem_ranges_check(ctx->mem_ranges, node);
163
+ }
164
+
165
+ static bool wsp_ggml_metal_op_concurrency_add(wsp_ggml_metal_op_t ctx, const wsp_ggml_tensor * node) {
166
+ if (!ctx->mem_ranges) {
167
+ return true;
168
+ }
169
+
170
+ return wsp_ggml_mem_ranges_add(ctx->mem_ranges, node);
171
+ }
172
+
173
+ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
174
+ struct wsp_ggml_tensor * node = ctx->node(idx);
175
+
176
+ //WSP_GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, wsp_ggml_op_name(node->op));
177
+
178
+ if (wsp_ggml_is_empty(node)) {
179
+ return 1;
180
+ }
181
+
182
+ switch (node->op) {
183
+ case WSP_GGML_OP_NONE:
184
+ case WSP_GGML_OP_RESHAPE:
185
+ case WSP_GGML_OP_VIEW:
186
+ case WSP_GGML_OP_TRANSPOSE:
187
+ case WSP_GGML_OP_PERMUTE:
188
+ {
189
+ // noop -> next node
190
+ if (ctx->debug_graph > 0) {
191
+ WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), "(noop)");
192
+ }
193
+ } return 1;
194
+ default:
195
+ {
196
+ } break;
197
+ }
198
+
199
+ if (!wsp_ggml_metal_device_supports_op(ctx->dev, node)) {
200
+ WSP_GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(node));
201
+ WSP_GGML_ABORT("unsupported op");
202
+ }
203
+
204
+ int n_fuse = 1;
205
+
206
+ // check if the current node can run concurrently with other nodes before it
207
+ // the condition is that:
208
+ // - the current node cannot write to any previous src or dst ranges
209
+ // - the current node cannot read from any previous dst ranges
210
+ //
211
+ // if the condition is not satisfied, we put a memory barrier and clear all ranges
212
+ // otherwise, we add the new ranges to the encoding context and process the node concurrently
213
+ //
214
+ {
215
+ const bool is_concurrent = wsp_ggml_metal_op_concurrency_check(ctx, node);
216
+
217
+ if (!is_concurrent) {
218
+ wsp_ggml_metal_op_concurrency_reset(ctx);
219
+ }
220
+
221
+ if (ctx->debug_graph > 0) {
222
+ WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
223
+ }
224
+ if (ctx->debug_graph > 1) {
225
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
226
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
227
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
228
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
229
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
230
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
231
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
232
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
233
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
234
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
235
+
236
+ if (node->src[0]) {
237
+ WSP_GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
238
+ wsp_ggml_is_contiguous(node->src[0]), node->src[0]->name);
239
+ }
240
+ if (node->src[1]) {
241
+ WSP_GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
242
+ wsp_ggml_is_contiguous(node->src[1]), node->src[1]->name);
243
+ }
244
+ if (node->src[2]) {
245
+ WSP_GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
246
+ wsp_ggml_is_contiguous(node->src[2]), node->src[2]->name);
247
+ }
248
+ if (node->src[3]) {
249
+ WSP_GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
250
+ wsp_ggml_is_contiguous(node->src[3]), node->src[3]->name);
251
+ }
252
+ if (node) {
253
+ WSP_GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
254
+ node->name);
255
+ }
256
+ }
257
+ }
258
+
259
+ switch (node->op) {
260
+ case WSP_GGML_OP_CONCAT:
261
+ {
262
+ n_fuse = wsp_ggml_metal_op_concat(ctx, idx);
263
+ } break;
264
+ case WSP_GGML_OP_ADD:
265
+ case WSP_GGML_OP_SUB:
266
+ case WSP_GGML_OP_MUL:
267
+ case WSP_GGML_OP_DIV:
268
+ {
269
+ n_fuse = wsp_ggml_metal_op_bin(ctx, idx);
270
+ } break;
271
+ case WSP_GGML_OP_ADD_ID:
272
+ {
273
+ n_fuse = wsp_ggml_metal_op_add_id(ctx, idx);
274
+ } break;
275
+ case WSP_GGML_OP_REPEAT:
276
+ {
277
+ n_fuse = wsp_ggml_metal_op_repeat(ctx, idx);
278
+ } break;
279
+ case WSP_GGML_OP_ACC:
280
+ {
281
+ n_fuse = wsp_ggml_metal_op_acc(ctx, idx);
282
+ } break;
283
+ case WSP_GGML_OP_SCALE:
284
+ {
285
+ n_fuse = wsp_ggml_metal_op_scale(ctx, idx);
286
+ } break;
287
+ case WSP_GGML_OP_CLAMP:
288
+ {
289
+ n_fuse = wsp_ggml_metal_op_clamp(ctx, idx);
290
+ } break;
291
+ case WSP_GGML_OP_SQR:
292
+ case WSP_GGML_OP_SQRT:
293
+ case WSP_GGML_OP_SIN:
294
+ case WSP_GGML_OP_COS:
295
+ case WSP_GGML_OP_LOG:
296
+ case WSP_GGML_OP_UNARY:
297
+ {
298
+ n_fuse = wsp_ggml_metal_op_unary(ctx, idx);
299
+ } break;
300
+ case WSP_GGML_OP_GLU:
301
+ {
302
+ n_fuse = wsp_ggml_metal_op_glu(ctx, idx);
303
+ } break;
304
+ case WSP_GGML_OP_SUM:
305
+ {
306
+ n_fuse = wsp_ggml_metal_op_sum(ctx, idx);
307
+ } break;
308
+ case WSP_GGML_OP_SUM_ROWS:
309
+ case WSP_GGML_OP_MEAN:
310
+ {
311
+ n_fuse = wsp_ggml_metal_op_sum_rows(ctx, idx);
312
+ } break;
313
+ case WSP_GGML_OP_SOFT_MAX:
314
+ {
315
+ n_fuse = wsp_ggml_metal_op_soft_max(ctx, idx);
316
+ } break;
317
+ case WSP_GGML_OP_SSM_CONV:
318
+ {
319
+ n_fuse = wsp_ggml_metal_op_ssm_conv(ctx, idx);
320
+ } break;
321
+ case WSP_GGML_OP_SSM_SCAN:
322
+ {
323
+ n_fuse = wsp_ggml_metal_op_ssm_scan(ctx, idx);
324
+ } break;
325
+ case WSP_GGML_OP_RWKV_WKV6:
326
+ case WSP_GGML_OP_RWKV_WKV7:
327
+ {
328
+ n_fuse = wsp_ggml_metal_op_rwkv(ctx, idx);
329
+ } break;
330
+ case WSP_GGML_OP_MUL_MAT:
331
+ {
332
+ n_fuse = wsp_ggml_metal_op_mul_mat(ctx, idx);
333
+ } break;
334
+ case WSP_GGML_OP_MUL_MAT_ID:
335
+ {
336
+ n_fuse = wsp_ggml_metal_op_mul_mat_id(ctx, idx);
337
+ } break;
338
+ case WSP_GGML_OP_GET_ROWS:
339
+ {
340
+ n_fuse = wsp_ggml_metal_op_get_rows(ctx, idx);
341
+ } break;
342
+ case WSP_GGML_OP_SET_ROWS:
343
+ {
344
+ n_fuse = wsp_ggml_metal_op_set_rows(ctx, idx);
345
+ } break;
346
+ case WSP_GGML_OP_L2_NORM:
347
+ {
348
+ n_fuse = wsp_ggml_metal_op_l2_norm(ctx, idx);
349
+ } break;
350
+ case WSP_GGML_OP_GROUP_NORM:
351
+ {
352
+ n_fuse = wsp_ggml_metal_op_group_norm(ctx, idx);
353
+ } break;
354
+ case WSP_GGML_OP_NORM:
355
+ case WSP_GGML_OP_RMS_NORM:
356
+ {
357
+ n_fuse = wsp_ggml_metal_op_norm(ctx, idx);
358
+ } break;
359
+ case WSP_GGML_OP_ROPE:
360
+ {
361
+ n_fuse = wsp_ggml_metal_op_rope(ctx, idx);
362
+ } break;
363
+ case WSP_GGML_OP_IM2COL:
364
+ {
365
+ n_fuse = wsp_ggml_metal_op_im2col(ctx, idx);
366
+ } break;
367
+ case WSP_GGML_OP_CONV_TRANSPOSE_1D:
368
+ {
369
+ n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
370
+ } break;
371
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
372
+ {
373
+ n_fuse = wsp_ggml_metal_op_conv_transpose_2d(ctx, idx);
374
+ } break;
375
+ case WSP_GGML_OP_UPSCALE:
376
+ {
377
+ n_fuse = wsp_ggml_metal_op_upscale(ctx, idx);
378
+ } break;
379
+ case WSP_GGML_OP_PAD:
380
+ {
381
+ n_fuse = wsp_ggml_metal_op_pad(ctx, idx);
382
+ } break;
383
+ case WSP_GGML_OP_PAD_REFLECT_1D:
384
+ {
385
+ n_fuse = wsp_ggml_metal_op_pad_reflect_1d(ctx, idx);
386
+ } break;
387
+ case WSP_GGML_OP_ARANGE:
388
+ {
389
+ n_fuse = wsp_ggml_metal_op_arange(ctx, idx);
390
+ } break;
391
+ case WSP_GGML_OP_TIMESTEP_EMBEDDING:
392
+ {
393
+ n_fuse = wsp_ggml_metal_op_timestep_embedding(ctx, idx);
394
+ } break;
395
+ case WSP_GGML_OP_ARGSORT:
396
+ {
397
+ n_fuse = wsp_ggml_metal_op_argsort(ctx, idx);
398
+ } break;
399
+ case WSP_GGML_OP_LEAKY_RELU:
400
+ {
401
+ n_fuse = wsp_ggml_metal_op_leaky_relu(ctx, idx);
402
+ } break;
403
+ case WSP_GGML_OP_FLASH_ATTN_EXT:
404
+ {
405
+ n_fuse = wsp_ggml_metal_op_flash_attn_ext(ctx, idx);
406
+ } break;
407
+ case WSP_GGML_OP_DUP:
408
+ case WSP_GGML_OP_CPY:
409
+ case WSP_GGML_OP_CONT:
410
+ {
411
+ n_fuse = wsp_ggml_metal_op_cpy(ctx, idx);
412
+ } break;
413
+ case WSP_GGML_OP_POOL_2D:
414
+ {
415
+ n_fuse = wsp_ggml_metal_op_pool_2d(ctx, idx);
416
+ } break;
417
+ case WSP_GGML_OP_ARGMAX:
418
+ {
419
+ n_fuse = wsp_ggml_metal_op_argmax(ctx, idx);
420
+ } break;
421
+ case WSP_GGML_OP_OPT_STEP_ADAMW:
422
+ {
423
+ n_fuse = wsp_ggml_metal_op_opt_step_adamw(ctx, idx);
424
+ } break;
425
+ case WSP_GGML_OP_OPT_STEP_SGD:
426
+ {
427
+ n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
428
+ } break;
429
+ default:
430
+ {
431
+ WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
432
+ WSP_GGML_ABORT("fatal error");
433
+ }
434
+ }
435
+
436
+ if (ctx->debug_graph > 0) {
437
+ if (n_fuse > 1) {
438
+ WSP_GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
439
+ }
440
+ }
441
+
442
+ // update the mem ranges in the encoding context
443
+ for (int i = 0; i < n_fuse; ++i) {
444
+ if (!wsp_ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
445
+ wsp_ggml_metal_op_concurrency_reset(ctx);
446
+ }
447
+ }
448
+
449
+ return n_fuse;
450
+ }
451
+
452
+ int wsp_ggml_metal_op_encode(wsp_ggml_metal_op_t ctx, int idx) {
453
+ if (ctx->use_capture) {
454
+ wsp_ggml_metal_encoder_debug_group_push(ctx->enc, wsp_ggml_op_desc(ctx->node(idx)));
455
+ }
456
+
457
+ int res = wsp_ggml_metal_op_encode_impl(ctx, idx);
458
+ if (idx + res > ctx->n_nodes()) {
459
+ WSP_GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
460
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
461
+ }
462
+
463
+ if (ctx->use_capture) {
464
+ wsp_ggml_metal_encoder_debug_group_pop(ctx->enc);
465
+ }
466
+
467
+ return res;
468
+ }
469
+
470
+ int wsp_ggml_metal_op_concat(wsp_ggml_metal_op_t ctx, int idx) {
471
+ wsp_ggml_tensor * op = ctx->node(idx);
472
+
473
+ wsp_ggml_metal_library_t lib = ctx->lib;
474
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
475
+
476
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
477
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
478
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
479
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
480
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
481
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
482
+
483
+ const int32_t dim = ((const int32_t *) op->op_params)[0];
484
+
485
+ wsp_ggml_metal_kargs_concat args = {
486
+ /*.ne00 =*/ ne00,
487
+ /*.ne01 =*/ ne01,
488
+ /*.ne02 =*/ ne02,
489
+ /*.ne03 =*/ ne03,
490
+ /*.nb00 =*/ nb00,
491
+ /*.nb01 =*/ nb01,
492
+ /*.nb02 =*/ nb02,
493
+ /*.nb03 =*/ nb03,
494
+ /*.ne10 =*/ ne10,
495
+ /*.ne11 =*/ ne11,
496
+ /*.ne12 =*/ ne12,
497
+ /*.ne13 =*/ ne13,
498
+ /*.nb10 =*/ nb10,
499
+ /*.nb11 =*/ nb11,
500
+ /*.nb12 =*/ nb12,
501
+ /*.nb13 =*/ nb13,
502
+ /*.ne0 =*/ ne0,
503
+ /*.ne1 =*/ ne1,
504
+ /*.ne2 =*/ ne2,
505
+ /*.ne3 =*/ ne3,
506
+ /*.nb0 =*/ nb0,
507
+ /*.nb1 =*/ nb1,
508
+ /*.nb2 =*/ nb2,
509
+ /*.nb3 =*/ nb3,
510
+ /*.dim =*/ dim,
511
+ };
512
+
513
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_CONCAT);
514
+
515
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
516
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
517
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
518
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
519
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
520
+
521
+ const int nth = std::min(1024, ne0);
522
+
523
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
524
+
525
+ return 1;
526
+ }
527
+
528
+ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
529
+ wsp_ggml_tensor * op = ctx->node(idx);
530
+
531
+ wsp_ggml_metal_library_t lib = ctx->lib;
532
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
533
+
534
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
535
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
536
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
537
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
538
+
539
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
540
+
541
+ wsp_ggml_metal_kargs_repeat args = {
542
+ /*.ne00 =*/ ne00,
543
+ /*.ne01 =*/ ne01,
544
+ /*.ne02 =*/ ne02,
545
+ /*.ne03 =*/ ne03,
546
+ /*.nb00 =*/ nb00,
547
+ /*.nb01 =*/ nb01,
548
+ /*.nb02 =*/ nb02,
549
+ /*.nb03 =*/ nb03,
550
+ /*.ne0 =*/ ne0,
551
+ /*.ne1 =*/ ne1,
552
+ /*.ne2 =*/ ne2,
553
+ /*.ne3 =*/ ne3,
554
+ /*.nb0 =*/ nb0,
555
+ /*.nb1 =*/ nb1,
556
+ /*.nb2 =*/ nb2,
557
+ /*.nb3 =*/ nb3,
558
+ };
559
+
560
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
561
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
562
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
563
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
564
+
565
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
566
+
567
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
568
+
569
+ return 1;
570
+ }
571
+
572
+ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
573
+ wsp_ggml_tensor * op = ctx->node(idx);
574
+
575
+ wsp_ggml_metal_library_t lib = ctx->lib;
576
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
577
+
578
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
579
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
580
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
581
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
582
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
583
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
584
+
585
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
586
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
587
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
588
+
589
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
590
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
591
+
592
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
593
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
594
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
595
+ const size_t offs = ((const int32_t *) op->op_params)[3];
596
+
597
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
598
+
599
+ if (!inplace) {
600
+ // run a separete kernel to cpy src->dst
601
+ // not sure how to avoid this
602
+ // TODO: make a simpler cpy_bytes kernel
603
+
604
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[WSP_GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
605
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
606
+
607
+ wsp_ggml_metal_kargs_cpy args = {
608
+ /*.nk0 =*/ ne00,
609
+ /*.ne00 =*/ ne00,
610
+ /*.ne01 =*/ ne01,
611
+ /*.ne02 =*/ ne02,
612
+ /*.ne03 =*/ ne03,
613
+ /*.nb00 =*/ nb00,
614
+ /*.nb01 =*/ nb01,
615
+ /*.nb02 =*/ nb02,
616
+ /*.nb03 =*/ nb03,
617
+ /*.ne0 =*/ ne0,
618
+ /*.ne1 =*/ ne1,
619
+ /*.ne2 =*/ ne2,
620
+ /*.ne3 =*/ ne3,
621
+ /*.nb0 =*/ nb0,
622
+ /*.nb1 =*/ nb1,
623
+ /*.nb2 =*/ nb2,
624
+ /*.nb3 =*/ nb3,
625
+ };
626
+
627
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
628
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
629
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
630
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
631
+
632
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
633
+
634
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
635
+
636
+ wsp_ggml_metal_op_concurrency_reset(ctx);
637
+ }
638
+
639
+ wsp_ggml_metal_kargs_bin args = {
640
+ /*.ne00 =*/ ne00,
641
+ /*.ne01 =*/ ne01,
642
+ /*.ne02 =*/ ne02,
643
+ /*.ne03 =*/ ne03,
644
+ /*.nb00 =*/ nb00,
645
+ /*.nb01 =*/ pnb1,
646
+ /*.nb02 =*/ pnb2,
647
+ /*.nb03 =*/ pnb3,
648
+ /*.ne10 =*/ ne10,
649
+ /*.ne11 =*/ ne11,
650
+ /*.ne12 =*/ ne12,
651
+ /*.ne13 =*/ ne13,
652
+ /*.nb10 =*/ nb10,
653
+ /*.nb11 =*/ nb11,
654
+ /*.nb12 =*/ nb12,
655
+ /*.nb13 =*/ nb13,
656
+ /*.ne0 =*/ ne0,
657
+ /*.ne1 =*/ ne1,
658
+ /*.ne2 =*/ ne2,
659
+ /*.ne3 =*/ ne3,
660
+ /*.nb0 =*/ nb0,
661
+ /*.nb1 =*/ pnb1,
662
+ /*.nb2 =*/ pnb2,
663
+ /*.nb3 =*/ pnb3,
664
+ /*.offs =*/ offs,
665
+ /*.o1 =*/ { 0 },
666
+ };
667
+
668
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, WSP_GGML_OP_ADD, 1, false);
669
+
670
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
671
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
672
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
673
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
674
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
675
+
676
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
677
+
678
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
679
+
680
+ return 1;
681
+ }
682
+
683
+ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
684
+ wsp_ggml_tensor * op = ctx->node(idx);
685
+
686
+ wsp_ggml_metal_library_t lib = ctx->lib;
687
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
688
+
689
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
690
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
691
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
692
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
693
+
694
+ float scale;
695
+ float bias;
696
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
697
+ memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
698
+
699
+ wsp_ggml_metal_kargs_scale args = {
700
+ /*.scale =*/ scale,
701
+ /*.bias =*/ bias,
702
+ };
703
+
704
+ int64_t n = wsp_ggml_nelements(op);
705
+
706
+ if (n % 4 == 0) {
707
+ n /= 4;
708
+ }
709
+
710
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
711
+
712
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
713
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
714
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
715
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
716
+
717
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
718
+
719
+ return 1;
720
+ }
721
+
722
+ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
723
+ wsp_ggml_tensor * op = ctx->node(idx);
724
+
725
+ wsp_ggml_metal_library_t lib = ctx->lib;
726
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
727
+
728
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
729
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
730
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
731
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
732
+
733
+ float min;
734
+ float max;
735
+ memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
736
+ memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
737
+
738
+ wsp_ggml_metal_kargs_clamp args = {
739
+ /*.min =*/ min,
740
+ /*.max =*/ max,
741
+ };
742
+
743
+ int64_t n = wsp_ggml_nelements(op);
744
+
745
+ if (n % 4 == 0) {
746
+ n /= 4;
747
+ }
748
+
749
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
750
+
751
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
752
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
753
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
754
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
755
+
756
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
757
+
758
+ return 1;
759
+ }
760
+
761
+ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
762
+ wsp_ggml_tensor * op = ctx->node(idx);
763
+
764
+ wsp_ggml_metal_library_t lib = ctx->lib;
765
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
766
+
767
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
768
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
769
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
770
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
771
+
772
+ int64_t n = wsp_ggml_nelements(op);
773
+
774
+ if (n % 4 == 0) {
775
+ n /= 4;
776
+ }
777
+
778
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
779
+
780
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
781
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 0);
782
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
783
+
784
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
785
+
786
+ return 1;
787
+ }
788
+
789
+ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
790
+ wsp_ggml_tensor * op = ctx->node(idx);
791
+
792
+ wsp_ggml_metal_library_t lib = ctx->lib;
793
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
794
+
795
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
796
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
797
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
798
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
799
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
800
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
801
+
802
+ if (op->src[1]) {
803
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
804
+ }
805
+
806
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_glu(lib, op);
807
+
808
+ const int32_t swp = wsp_ggml_get_op_params_i32(op, 1);
809
+ const float alpha = wsp_ggml_get_op_params_f32(op, 2);
810
+ const float limit = wsp_ggml_get_op_params_f32(op, 3);
811
+
812
+ const int32_t i00 = swp ? ne0 : 0;
813
+ const int32_t i10 = swp ? 0 : ne0;
814
+
815
+ wsp_ggml_metal_kargs_glu args = {
816
+ /*.ne00 =*/ ne00,
817
+ /*.nb01 =*/ nb01,
818
+ /*.ne10 =*/ op->src[1] ? ne10 : ne00,
819
+ /*.nb11 =*/ op->src[1] ? nb11 : nb01,
820
+ /*.ne0 =*/ ne0,
821
+ /*.nb1 =*/ nb1,
822
+ /*.i00 =*/ op->src[1] ? 0 : i00,
823
+ /*.i10 =*/ op->src[1] ? 0 : i10,
824
+ /*.alpha=*/ alpha,
825
+ /*.limit=*/ limit
826
+ };
827
+
828
+ const int64_t nrows = wsp_ggml_nrows(op->src[0]);
829
+
830
+ const int32_t nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
831
+
832
+ //[encoder setComputePipelineState:pipeline];
833
+ //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
834
+ //if (src1) {
835
+ // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
836
+ //} else {
837
+ // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
838
+ //}
839
+ //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
840
+ //[encoder setBytes:&args length:sizeof(args) atIndex:3];
841
+
842
+ //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
843
+
844
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
845
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
846
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
847
+ if (op->src[1]) {
848
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
849
+ } else {
850
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 2);
851
+ }
852
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
853
+
854
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
855
+
856
+ return 1;
857
+ }
858
+
859
+ int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
860
+ wsp_ggml_tensor * op = ctx->node(idx);
861
+
862
+ wsp_ggml_metal_library_t lib = ctx->lib;
863
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
864
+
865
+ const uint64_t n = (uint64_t) wsp_ggml_nelements(op->src[0]);
866
+
867
+ wsp_ggml_metal_kargs_sum args = {
868
+ /*.np =*/ n,
869
+ };
870
+
871
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
872
+
873
+ int nth = 32; // SIMD width
874
+
875
+ while (nth < (int) n && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
876
+ nth *= 2;
877
+ }
878
+
879
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
880
+ nth = std::min(nth, (int) n);
881
+
882
+ const int nsg = (nth + 31) / 32;
883
+
884
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
885
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
886
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
887
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
888
+
889
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
890
+
891
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
892
+
893
+ return 1;
894
+ }
895
+
896
+ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
897
+ wsp_ggml_tensor * op = ctx->node(idx);
898
+
899
+ wsp_ggml_metal_library_t lib = ctx->lib;
900
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
901
+
902
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
903
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
904
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
905
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
906
+
907
+ wsp_ggml_metal_kargs_sum_rows args = {
908
+ /*.ne00 =*/ ne00,
909
+ /*.ne01 =*/ ne01,
910
+ /*.ne02 =*/ ne02,
911
+ /*.ne03 =*/ ne03,
912
+ /*.nb00 =*/ nb00,
913
+ /*.nb01 =*/ nb01,
914
+ /*.nb02 =*/ nb02,
915
+ /*.nb03 =*/ nb03,
916
+ /*.ne0 =*/ ne0,
917
+ /*.ne1 =*/ ne1,
918
+ /*.ne2 =*/ ne2,
919
+ /*.ne3 =*/ ne3,
920
+ /*.nb0 =*/ nb0,
921
+ /*.nb1 =*/ nb1,
922
+ /*.nb2 =*/ nb2,
923
+ /*.nb3 =*/ nb3,
924
+ };
925
+
926
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum_rows(lib, op);
927
+
928
+ int nth = 32; // SIMD width
929
+
930
+ while (nth < ne00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
931
+ nth *= 2;
932
+ }
933
+
934
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
935
+ nth = std::min(nth, ne00);
936
+
937
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
938
+
939
+ //[encoder setComputePipelineState:pipeline];
940
+ //[encoder setBytes:&args length:sizeof(args) atIndex:0];
941
+ //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
942
+ //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
943
+ //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
944
+
945
+ //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
946
+
947
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
948
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
949
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
950
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
951
+
952
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
953
+
954
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
955
+
956
+ return 1;
957
+ }
958
+
959
+ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
960
+ wsp_ggml_tensor * op = ctx->node(idx);
961
+
962
+ wsp_ggml_metal_library_t lib = ctx->lib;
963
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
964
+
965
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
966
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
967
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
968
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
969
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
970
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
971
+
972
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
973
+
974
+ wsp_ggml_metal_kargs_get_rows args = {
975
+ /*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
976
+ /*.ne00 =*/ ne00,
977
+ /*.nb01 =*/ nb01,
978
+ /*.nb02 =*/ nb02,
979
+ /*.nb03 =*/ nb03,
980
+ /*.ne10 =*/ ne10,
981
+ /*.nb10 =*/ nb10,
982
+ /*.nb11 =*/ nb11,
983
+ /*.nb12 =*/ nb12,
984
+ /*.nb1 =*/ nb1,
985
+ /*.nb2 =*/ nb2,
986
+ /*.nb3 =*/ nb3,
987
+ };
988
+
989
+ const int nth = std::min(args.ne00t, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
990
+
991
+ const int nw0 = (args.ne00t + nth - 1)/nth;
992
+
993
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
994
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
995
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
996
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
997
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
998
+
999
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
1000
+
1001
+ return 1;
1002
+ }
1003
+
1004
+ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
1005
+ wsp_ggml_tensor * op = ctx->node(idx);
1006
+
1007
+ wsp_ggml_metal_library_t lib = ctx->lib;
1008
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1009
+
1010
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1011
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1012
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1013
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1014
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1015
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1016
+
1017
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1018
+
1019
+ const int32_t nk0 = ne0/wsp_ggml_blck_size(op->type);
1020
+
1021
+ int nth = 32; // SIMD width
1022
+
1023
+ while (nth < nk0 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1024
+ nth *= 2;
1025
+ }
1026
+
1027
+ int nrptg = 1;
1028
+ if (nth > nk0) {
1029
+ nrptg = (nth + nk0 - 1)/nk0;
1030
+ nth = nk0;
1031
+
1032
+ if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1033
+ nrptg--;
1034
+ }
1035
+ }
1036
+
1037
+ nth = std::min(nth, nk0);
1038
+
1039
+ wsp_ggml_metal_kargs_set_rows args = {
1040
+ /*.nk0 =*/ nk0,
1041
+ /*.ne01 =*/ ne01,
1042
+ /*.nb01 =*/ nb01,
1043
+ /*.nb02 =*/ nb02,
1044
+ /*.nb03 =*/ nb03,
1045
+ /*.ne11 =*/ ne11,
1046
+ /*.ne12 =*/ ne12,
1047
+ /*.nb10 =*/ nb10,
1048
+ /*.nb11 =*/ nb11,
1049
+ /*.nb12 =*/ nb12,
1050
+ /*.nb1 =*/ nb1,
1051
+ /*.nb2 =*/ nb2,
1052
+ /*.nb3 =*/ nb3,
1053
+ };
1054
+
1055
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1056
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1057
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1058
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1059
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
1060
+
1061
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1062
+
1063
+ return 1;
1064
+ }
1065
+
1066
+ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
1067
+ wsp_ggml_tensor * op = ctx->node(idx);
1068
+
1069
+ wsp_ggml_metal_library_t lib = ctx->lib;
1070
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1071
+
1072
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1073
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1074
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1075
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1076
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1077
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1078
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1079
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1080
+
1081
+ float scale;
1082
+ float max_bias;
1083
+
1084
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
1085
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
1086
+
1087
+ const uint32_t n_head = op->src[0]->ne[2];
1088
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1089
+
1090
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1091
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1092
+
1093
+ // softmax
1094
+
1095
+ wsp_ggml_metal_kargs_soft_max args = {
1096
+ /*.ne00 =*/ ne00,
1097
+ /*.ne01 =*/ ne01,
1098
+ /*.ne02 =*/ ne02,
1099
+ /*.nb01 =*/ nb01,
1100
+ /*.nb02 =*/ nb02,
1101
+ /*.nb03 =*/ nb03,
1102
+ /*.ne11 =*/ ne11,
1103
+ /*.ne12 =*/ ne12,
1104
+ /*.ne13 =*/ ne13,
1105
+ /*.nb11 =*/ nb11,
1106
+ /*.nb12 =*/ nb12,
1107
+ /*.nb13 =*/ nb13,
1108
+ /*.nb1 =*/ nb1,
1109
+ /*.nb2 =*/ nb2,
1110
+ /*.nb3 =*/ nb3,
1111
+ /*.scale =*/ scale,
1112
+ /*.max_bias =*/ max_bias,
1113
+ /*.m0 =*/ m0,
1114
+ /*.m1 =*/ m1,
1115
+ /*.n_head_log2 =*/ n_head_log2,
1116
+ };
1117
+
1118
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_soft_max(lib, op);
1119
+
1120
+ int nth = 32; // SIMD width
1121
+
1122
+ if (ne00%4 == 0) {
1123
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1124
+ nth *= 2;
1125
+ }
1126
+ } else {
1127
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1128
+ nth *= 2;
1129
+ }
1130
+ }
1131
+
1132
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1133
+
1134
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1135
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1136
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1137
+ if (op->src[1]) {
1138
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1139
+ } else {
1140
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 2);
1141
+ }
1142
+ if (op->src[2]) {
1143
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
1144
+ } else {
1145
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 3);
1146
+ }
1147
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 4);
1148
+
1149
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1150
+
1151
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1152
+
1153
+ return 1;
1154
+ }
1155
+
1156
+ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1157
+ wsp_ggml_tensor * op = ctx->node(idx);
1158
+
1159
+ wsp_ggml_metal_library_t lib = ctx->lib;
1160
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1161
+
1162
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1163
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1164
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1165
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1166
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1167
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1168
+
1169
+ wsp_ggml_metal_kargs_ssm_conv args = {
1170
+ /*.ne00 =*/ ne00,
1171
+ /*.ne01 =*/ ne01,
1172
+ /*.ne02 =*/ ne02,
1173
+ /*.nb00 =*/ nb00,
1174
+ /*.nb01 =*/ nb01,
1175
+ /*.nb02 =*/ nb02,
1176
+ /*.ne10 =*/ ne10,
1177
+ /*.ne11 =*/ ne11,
1178
+ /*.nb10 =*/ nb10,
1179
+ /*.nb11 =*/ nb11,
1180
+ /*.ne0 =*/ ne0,
1181
+ /*.ne1 =*/ ne1,
1182
+ /*.ne2 =*/ ne2,
1183
+ /*.nb0 =*/ nb0,
1184
+ /*.nb1 =*/ nb1,
1185
+ /*.nb2 =*/ nb2,
1186
+ };
1187
+
1188
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1189
+
1190
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1191
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1192
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1193
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1194
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1195
+
1196
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1197
+
1198
+ return 1;
1199
+ }
1200
+
1201
+ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1202
+ wsp_ggml_tensor * op = ctx->node(idx);
1203
+
1204
+ wsp_ggml_metal_library_t lib = ctx->lib;
1205
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1206
+
1207
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1208
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1209
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1210
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1211
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1212
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1213
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1214
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1215
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
1216
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
1217
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
1218
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
1219
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1220
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1221
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1222
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1223
+
1224
+ const wsp_ggml_tensor * src3 = op->src[3];
1225
+ const wsp_ggml_tensor * src4 = op->src[4];
1226
+ const wsp_ggml_tensor * src5 = op->src[5];
1227
+ const wsp_ggml_tensor * src6 = op->src[6];
1228
+
1229
+ WSP_GGML_ASSERT(src3);
1230
+ WSP_GGML_ASSERT(src4);
1231
+ WSP_GGML_ASSERT(src5);
1232
+ WSP_GGML_ASSERT(src6);
1233
+
1234
+ const int64_t d_state = ne00;
1235
+ const int64_t d_inner = ne01;
1236
+ const int64_t n_head = ne02;
1237
+ const int64_t n_group = ne41;
1238
+ const int64_t n_seq_tokens = ne12;
1239
+ const int64_t n_seqs = ne13;
1240
+
1241
+ wsp_ggml_metal_kargs_ssm_scan args = {
1242
+ /*.d_state =*/ d_state,
1243
+ /*.d_inner =*/ d_inner,
1244
+ /*.n_head =*/ n_head,
1245
+ /*.n_group =*/ n_group,
1246
+ /*.n_seq_tokens =*/ n_seq_tokens,
1247
+ /*.n_seqs =*/ n_seqs,
1248
+ /*.s_off =*/ wsp_ggml_nelements(op->src[1]) * sizeof(float),
1249
+ /*.nb00 =*/ nb00,
1250
+ /*.nb01 =*/ nb01,
1251
+ /*.nb02 =*/ nb02,
1252
+ /*.nb03 =*/ nb03,
1253
+ /*.nb10 =*/ nb10,
1254
+ /*.nb11 =*/ nb11,
1255
+ /*.nb12 =*/ nb12,
1256
+ /*.ns12 =*/ nb12/nb10,
1257
+ /*.nb13 =*/ nb13,
1258
+ /*.nb20 =*/ nb20,
1259
+ /*.nb21 =*/ nb21,
1260
+ /*.ns21 =*/ nb21/nb20,
1261
+ /*.nb22 =*/ nb22,
1262
+ /*.ne30 =*/ ne30,
1263
+ /*.nb31 =*/ nb31,
1264
+ /*.nb41 =*/ nb41,
1265
+ /*.nb42 =*/ nb42,
1266
+ /*.ns42 =*/ nb42/nb40,
1267
+ /*.nb43 =*/ nb43,
1268
+ /*.nb51 =*/ nb51,
1269
+ /*.nb52 =*/ nb52,
1270
+ /*.ns52 =*/ nb52/nb50,
1271
+ /*.nb53 =*/ nb53,
1272
+ /*.nb0 =*/ nb0,
1273
+ };
1274
+
1275
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1276
+
1277
+ WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1278
+
1279
+ const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);
1280
+
1281
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1282
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1283
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1284
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1285
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
1286
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), 4);
1287
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
1288
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[5]), 6);
1289
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[6]), 7);
1290
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 8);
1291
+
1292
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1293
+
1294
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1295
+
1296
+ return 1;
1297
+ }
1298
+
1299
+ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
1300
+ wsp_ggml_tensor * op = ctx->node(idx);
1301
+
1302
+ wsp_ggml_metal_library_t lib = ctx->lib;
1303
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1304
+
1305
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1306
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1307
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1308
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1309
+
1310
+ const int64_t B = op->op == WSP_GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1311
+ const int64_t T = op->src[0]->ne[2];
1312
+ const int64_t C = op->ne[0];
1313
+ const int64_t H = op->src[0]->ne[1];
1314
+
1315
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rwkv(lib, op);
1316
+
1317
+ int ida = 0;
1318
+
1319
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1320
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
1321
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
1322
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
1323
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
1324
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
1325
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[5]), ida++);
1326
+ if (op->op == WSP_GGML_OP_RWKV_WKV7) {
1327
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[6]), ida++);
1328
+ }
1329
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), ida++);
1330
+ wsp_ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1331
+ wsp_ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1332
+ wsp_ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1333
+ wsp_ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1334
+
1335
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1336
+
1337
+ return 1;
1338
+ }
1339
+
1340
+ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1341
+ wsp_ggml_tensor * op = ctx->node(idx);
1342
+
1343
+ wsp_ggml_metal_library_t lib = ctx->lib;
1344
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1345
+
1346
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1347
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1348
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1349
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1350
+
1351
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1352
+
1353
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
1354
+
1355
+ int64_t nk0 = ne00;
1356
+ if (wsp_ggml_is_quantized(op->src[0]->type)) {
1357
+ nk0 = ne00/16;
1358
+ } else if (wsp_ggml_is_quantized(op->type)) {
1359
+ nk0 = ne00/wsp_ggml_blck_size(op->type);
1360
+ }
1361
+
1362
+ int nth = std::min<int>(nk0, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1363
+
1364
+ // when rows are small, we can batch them together in a single threadgroup
1365
+ int nrptg = 1;
1366
+
1367
+ // TODO: relax this constraint in the future
1368
+ if (wsp_ggml_blck_size(op->src[0]->type) == 1 && wsp_ggml_blck_size(op->type) == 1) {
1369
+ if (nth > nk0) {
1370
+ nrptg = (nth + nk0 - 1)/nk0;
1371
+ nth = nk0;
1372
+
1373
+ if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1374
+ nrptg--;
1375
+ }
1376
+ }
1377
+ }
1378
+
1379
+ nth = std::min<int>(nth, nk0);
1380
+
1381
+ wsp_ggml_metal_kargs_cpy args = {
1382
+ /*.nk0 =*/ nk0,
1383
+ /*.ne00 =*/ ne00,
1384
+ /*.ne01 =*/ ne01,
1385
+ /*.ne02 =*/ ne02,
1386
+ /*.ne03 =*/ ne03,
1387
+ /*.nb00 =*/ nb00,
1388
+ /*.nb01 =*/ nb01,
1389
+ /*.nb02 =*/ nb02,
1390
+ /*.nb03 =*/ nb03,
1391
+ /*.ne0 =*/ ne0,
1392
+ /*.ne1 =*/ ne1,
1393
+ /*.ne2 =*/ ne2,
1394
+ /*.ne3 =*/ ne3,
1395
+ /*.nb0 =*/ nb0,
1396
+ /*.nb1 =*/ nb1,
1397
+ /*.nb2 =*/ nb2,
1398
+ /*.nb3 =*/ nb3,
1399
+ };
1400
+
1401
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1402
+
1403
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1404
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1405
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1406
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
1407
+
1408
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1409
+
1410
+ return 1;
1411
+ }
1412
+
1413
+ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
1414
+ wsp_ggml_tensor * op = ctx->node(idx);
1415
+
1416
+ wsp_ggml_metal_library_t lib = ctx->lib;
1417
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1418
+
1419
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1420
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1421
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1422
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1423
+
1424
+ const int32_t * opts = op->op_params;
1425
+ wsp_ggml_op_pool op_pool = (wsp_ggml_op_pool) opts[0];
1426
+
1427
+ const int32_t k0 = opts[1];
1428
+ const int32_t k1 = opts[2];
1429
+ const int32_t s0 = opts[3];
1430
+ const int32_t s1 = opts[4];
1431
+ const int32_t p0 = opts[5];
1432
+ const int32_t p1 = opts[6];
1433
+
1434
+ const int64_t IH = op->src[0]->ne[1];
1435
+ const int64_t IW = op->src[0]->ne[0];
1436
+
1437
+ const int64_t N = op->ne[3];
1438
+ const int64_t OC = op->ne[2];
1439
+ const int64_t OH = op->ne[1];
1440
+ const int64_t OW = op->ne[0];
1441
+
1442
+ const int64_t np = N * OC * OH * OW;
1443
+
1444
+ wsp_ggml_metal_kargs_pool_2d args_pool_2d = {
1445
+ /* .k0 = */ k0,
1446
+ /* .k1 = */ k1,
1447
+ /* .s0 = */ s0,
1448
+ /* .s1 = */ s1,
1449
+ /* .p0 = */ p0,
1450
+ /* .p1 = */ p1,
1451
+ /* .IH = */ IH,
1452
+ /* .IW = */ IW,
1453
+ /* .OH = */ OH,
1454
+ /* .OW = */ OW,
1455
+ /* .np = */ np
1456
+ };
1457
+
1458
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1459
+
1460
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1461
+ const int ntg = (np + nth - 1) / nth;
1462
+
1463
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1464
+ wsp_ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
1465
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1466
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
1467
+
1468
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1469
+
1470
+ return 1;
1471
+ }
1472
+
1473
+ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1474
+ wsp_ggml_tensor * op = ctx->node(idx);
1475
+
1476
+ wsp_ggml_metal_library_t lib = ctx->lib;
1477
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1478
+
1479
+ const wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(ctx->dev);
1480
+
1481
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1482
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1483
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1484
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1485
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1486
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1487
+
1488
+ WSP_GGML_ASSERT(ne00 == ne10);
1489
+
1490
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
1491
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
1492
+
1493
+ const int16_t r2 = ne12/ne02;
1494
+ const int16_t r3 = ne13/ne03;
1495
+
1496
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1497
+ // to the matrix-vector kernel
1498
+ const int ne11_mm_min = 8;
1499
+
1500
+ // first try to use small-batch mat-mv kernels
1501
+ // these should be efficient for BS [2, ~8]
1502
+ if (op->src[1]->type == WSP_GGML_TYPE_F32 && (ne00%128 == 0) &&
1503
+ (
1504
+ (
1505
+ (
1506
+ op->src[0]->type == WSP_GGML_TYPE_F32 || // TODO: helper function
1507
+ op->src[0]->type == WSP_GGML_TYPE_F16 ||
1508
+ op->src[0]->type == WSP_GGML_TYPE_Q4_0 ||
1509
+ op->src[0]->type == WSP_GGML_TYPE_Q4_1 ||
1510
+ op->src[0]->type == WSP_GGML_TYPE_Q5_0 ||
1511
+ op->src[0]->type == WSP_GGML_TYPE_Q5_1 ||
1512
+ op->src[0]->type == WSP_GGML_TYPE_Q8_0 ||
1513
+ op->src[0]->type == WSP_GGML_TYPE_MXFP4 ||
1514
+ op->src[0]->type == WSP_GGML_TYPE_IQ4_NL ||
1515
+ false) && (ne11 >= 2 && ne11 <= 8)
1516
+ ) ||
1517
+ (
1518
+ (
1519
+ op->src[0]->type == WSP_GGML_TYPE_Q4_K ||
1520
+ op->src[0]->type == WSP_GGML_TYPE_Q5_K ||
1521
+ op->src[0]->type == WSP_GGML_TYPE_Q6_K ||
1522
+ false) && (ne11 >= 4 && ne11 <= 8)
1523
+ )
1524
+ )
1525
+ ) {
1526
+ // TODO: determine the optimal parameters based on grid utilization
1527
+ // I still don't know why we should not always use the maximum available threads:
1528
+ //
1529
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
1530
+ //
1531
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
1532
+ // values and there can be some tail effects when nsg is high. need to confirm this
1533
+ //
1534
+ const int nsg = 2; // num simdgroups per threadgroup
1535
+
1536
+ // num threads along row per simdgroup
1537
+ int16_t nxpsg = 0;
1538
+ if (ne00 % 256 == 0 && ne11 < 3) {
1539
+ nxpsg = 16;
1540
+ } else if (ne00 % 128 == 0) {
1541
+ nxpsg = 8;
1542
+ } else {
1543
+ nxpsg = 4;
1544
+ }
1545
+
1546
+ const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
1547
+ const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1548
+ int16_t r1ptg = 4; // num src1 rows per threadgroup
1549
+
1550
+ // note: not sure how optimal those are across all different hardware. there might be something cleverer
1551
+ switch (ne11) {
1552
+ case 2:
1553
+ r1ptg = 2; break;
1554
+ case 3:
1555
+ case 6:
1556
+ r1ptg = 3; break;
1557
+ case 4:
1558
+ case 7:
1559
+ case 8:
1560
+ r1ptg = 4; break;
1561
+ case 5:
1562
+ r1ptg = 5; break;
1563
+ default:
1564
+ WSP_GGML_ABORT("unsupported ne11");
1565
+ };
1566
+
1567
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1568
+
1569
+ wsp_ggml_metal_kargs_mul_mv_ext args = {
1570
+ /*.ne00 =*/ ne00,
1571
+ /*.ne01 =*/ ne01,
1572
+ /*.ne02 =*/ ne02,
1573
+ /*.nb00 =*/ nb00,
1574
+ /*.nb01 =*/ nb01,
1575
+ /*.nb02 =*/ nb02,
1576
+ /*.nb03 =*/ nb03,
1577
+ /*.ne10 =*/ ne10,
1578
+ /*.ne11 =*/ ne11,
1579
+ /*.ne12 =*/ ne12,
1580
+ /*.nb10 =*/ nb10,
1581
+ /*.nb11 =*/ nb11,
1582
+ /*.nb12 =*/ nb12,
1583
+ /*.nb13 =*/ nb13,
1584
+ /*.ne0 =*/ ne0,
1585
+ /*.ne1 =*/ ne1,
1586
+ /*.r2 =*/ r2,
1587
+ /*.r3 =*/ r3,
1588
+ };
1589
+
1590
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1591
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1592
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1593
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1594
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
1595
+
1596
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
1597
+ } else if (
1598
+ !wsp_ggml_is_transposed(op->src[0]) &&
1599
+ !wsp_ggml_is_transposed(op->src[1]) &&
1600
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1601
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1602
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1603
+ //WSP_GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1604
+
1605
+ // some Metal matrix data types require aligned pointers
1606
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1607
+ //switch (op->src[0]->type) {
1608
+ // case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(nb01 % 16 == 0); break;
1609
+ // case WSP_GGML_TYPE_F16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1610
+ // case WSP_GGML_TYPE_BF16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1611
+ // default: break;
1612
+ //}
1613
+
1614
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm(lib, op);
1615
+
1616
+ wsp_ggml_metal_kargs_mul_mm args = {
1617
+ /*.ne00 =*/ ne00,
1618
+ /*.ne02 =*/ ne02,
1619
+ /*.nb01 =*/ nb01,
1620
+ /*.nb02 =*/ nb02,
1621
+ /*.nb03 =*/ nb03,
1622
+ /*.ne12 =*/ ne12,
1623
+ /*.nb10 =*/ nb10,
1624
+ /*.nb11 =*/ nb11,
1625
+ /*.nb12 =*/ nb12,
1626
+ /*.nb13 =*/ nb13,
1627
+ /*.ne0 =*/ ne0,
1628
+ /*.ne1 =*/ ne1,
1629
+ /*.r2 =*/ r2,
1630
+ /*.r3 =*/ r3,
1631
+ };
1632
+
1633
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1634
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1635
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1636
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1637
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
1638
+
1639
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1640
+
1641
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1642
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1643
+ } else {
1644
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv(lib, op);
1645
+
1646
+ const int nr0 = wsp_ggml_metal_pipeline_get_nr0(pipeline);
1647
+ const int nr1 = wsp_ggml_metal_pipeline_get_nr1(pipeline);
1648
+ const int nsg = wsp_ggml_metal_pipeline_get_nsg(pipeline);
1649
+
1650
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1651
+
1652
+ wsp_ggml_metal_kargs_mul_mv args = {
1653
+ /*.ne00 =*/ ne00,
1654
+ /*.ne01 =*/ ne01,
1655
+ /*.ne02 =*/ ne02,
1656
+ /*.nb00 =*/ nb00,
1657
+ /*.nb01 =*/ nb01,
1658
+ /*.nb02 =*/ nb02,
1659
+ /*.nb03 =*/ nb03,
1660
+ /*.ne10 =*/ ne10,
1661
+ /*.ne11 =*/ ne11,
1662
+ /*.ne12 =*/ ne12,
1663
+ /*.nb10 =*/ nb10,
1664
+ /*.nb11 =*/ nb11,
1665
+ /*.nb12 =*/ nb12,
1666
+ /*.nb13 =*/ nb13,
1667
+ /*.ne0 =*/ ne0,
1668
+ /*.ne1 =*/ ne1,
1669
+ /*.nr0 =*/ nr0,
1670
+ /*.r2 =*/ r2,
1671
+ /*.r3 =*/ r3,
1672
+ };
1673
+
1674
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1675
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1676
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1677
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1678
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
1679
+
1680
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1681
+
1682
+ if (op->src[0]->type == WSP_GGML_TYPE_F32 ||
1683
+ op->src[0]->type == WSP_GGML_TYPE_F16 ||
1684
+ op->src[0]->type == WSP_GGML_TYPE_BF16 ||
1685
+ op->src[0]->type == WSP_GGML_TYPE_Q8_0) {
1686
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1687
+ } else {
1688
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1689
+ }
1690
+ }
1691
+
1692
+ return 1;
1693
+ }
1694
+
1695
// Returns the size in bytes of the extra "tpe" scratch region for a MUL_MAT_ID op:
// one I32 element per expert, where the expert count is src[0]->ne[2].
// wsp_ggml_metal_op_mul_mat_id uses this size as the offset between its "tpe" and
// "ids" scratch regions, both of which are placed after the destination tensor data.
// NOTE(review): "tpe" presumably stands for "tokens per expert" - confirm against the kernel.
+ size_t wsp_ggml_metal_op_mul_mat_id_extra_tpe(const wsp_ggml_tensor * op) {
1695
+ assert(op->op == WSP_GGML_OP_MUL_MAT_ID);
1696
+
1697
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
1698
+
1699
// one 32-bit integer per expert
+ return wsp_ggml_type_size(WSP_GGML_TYPE_I32)*ne02;
1700
+ }
1702
+
1703
// Returns the size in bytes of the extra "ids" scratch region for a MUL_MAT_ID op:
// one I32 element per (expert, token) pair, i.e. src[0]->ne[2] * src[2]->ne[1] entries.
// NOTE(review): the caller that reserves this space is not visible in this part of the
// file - presumably the buffer-size computation for the op; confirm.
+ size_t wsp_ggml_metal_op_mul_mat_id_extra_ids(const wsp_ggml_tensor * op) {
1703
+ assert(op->op == WSP_GGML_OP_MUL_MAT_ID);
1704
+
1705
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
1706
+ const int64_t ne21 = op->src[2]->ne[1]; // n_token
1707
+
1708
// one 32-bit integer for every expert/token combination
+ return wsp_ggml_type_size(WSP_GGML_TYPE_I32)*ne02*ne21;
1709
+ }
1711
+
1712
// Encodes the Metal work for graph node `idx` (the MUL_MAT_ID handler).
//
// Inputs, as used below: src[0] holds the per-expert matrices (ne02 experts), src[1] the
// right-hand operand, and src[2] the I32 expert ids (ne20 used experts x ne21 rows).
//
// Two code paths are visible:
//   - matrix-matrix: taken when the device reports has_simdgroup_mm, ne00 >= 64 and
//     ne21 >= 32. A "map0" kernel first writes the id maps into scratch space located
//     right after the destination data; after a concurrency reset the mul_mm_id kernel
//     consumes those maps.
//   - matrix-vector: otherwise a single mul_mv_id kernel is dispatched.
//
// Always returns 1.
// NOTE(review): presumably the number of graph nodes consumed - confirm with the caller.
+ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1712
+ wsp_ggml_tensor * op = ctx->node(idx);
1713
+
1714
+ wsp_ggml_metal_library_t lib = ctx->lib;
1715
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1716
+
1717
+ const wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(ctx->dev);
1718
+
1719
// Bring the shape (ne*) and stride (nb*) locals into scope: ne0x/nb0x for src[0],
// ne1x/nb1x for src[1], ne2x/nb2x for src[2], and nex/nbx for the destination.
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1720
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1721
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1722
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1723
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1724
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1725
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1726
// NOTE(review): the destination strides are declared as uint32_t while all source
// strides are uint64_t - confirm that the narrower type is intended.
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1727
+
1728
+ // src2 = ids
1729
+ WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
1730
+
1731
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(op->src[0]));
1732
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(op->src[1]));
1733
+
1734
// only a single slice along the outermost dimension of src[0] and src[1] is supported
+ WSP_GGML_ASSERT(ne03 == 1);
1735
+ WSP_GGML_ASSERT(ne13 == 1);
1736
+
1737
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
1738
+ wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
1739
+ wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
1740
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
1741
+
1742
// r2/r3 are fixed to 1 for this op; they are only forwarded to the mul_mm_id kernel args
+ const uint32_t r2 = 1;
1743
+ const uint32_t r3 = 1;
1744
+
1745
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1746
+ // to the matrix-vector kernel
1747
+ // ne20 = n_used_experts
1748
+ // ne21 = n_rows (batch size)
1749
+ const int ne21_mm_id_min = 32;
1750
+
1751
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
1752
+ // some Metal matrix data types require aligned pointers
1753
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1754
+ //switch (op->src[0]->type) {
1755
+ // case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(nb01 % 16 == 0); break;
1756
+ // case WSP_GGML_TYPE_F16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1757
+ // case WSP_GGML_TYPE_BF16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1758
+ // default: break;
1759
+ //}
1760
+
1761
+ // extra buffers for intermediate id mapping
// the "tpe" region starts right after the destination tensor data ...
1762
+ wsp_ggml_metal_buffer_id bid_tpe = bid_dst;
1763
+ bid_tpe.offs += wsp_ggml_nbytes(op);
1764
+
1765
// ... and the "ids" region follows it, offset by the size of the "tpe" region
+ wsp_ggml_metal_buffer_id bid_ids = bid_tpe;
1766
+ bid_ids.offs += wsp_ggml_metal_op_mul_mat_id_extra_tpe(op);
1767
+
1768
// step 1: build the id maps (reads src[2], writes the tpe and ids scratch regions)
+ {
1769
+ wsp_ggml_metal_kargs_mul_mm_id_map0 args = {
1770
+ ne02,
1771
+ ne10,
1772
+ ne11, // n_expert_used (bcast)
1773
+ nb11,
1774
+ nb12,
1775
+ ne21, // n_tokens
1776
+ ne20, // n_expert_used
1777
+ nb21,
1778
+ };
1779
+
1780
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1781
+
1782
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1783
+
1784
// ne02 is used as the thread count of the single threadgroup dispatched below
+ WSP_GGML_ASSERT(ne02 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1785
+
1786
+ WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
1787
+
1788
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1789
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1790
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 1);
1791
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tpe, 2);
1792
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_ids, 3);
1793
+
1794
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1795
+
1796
// NOTE(review): assumes the argument order is (grid x, y, z, threads x, y, z) - confirm
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
1797
+ }
1798
+
1799
+ // this barrier is always needed because the next kernel has to wait for the id maps to be computed
1800
+ wsp_ggml_metal_op_concurrency_reset(ctx);
1801
+
1802
// step 2: the actual matrix-matrix multiplication, driven by the maps written in step 1
+ {
1803
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
1804
+
1805
+ wsp_ggml_metal_kargs_mul_mm_id args = {
1806
+ /*.ne00 =*/ ne00,
1807
+ /*.ne02 =*/ ne02,
1808
+ /*.nb01 =*/ nb01,
1809
+ /*.nb02 =*/ nb02,
1810
+ /*.nb03 =*/ nb03,
1811
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
1812
+ /*.nb10 =*/ nb10,
1813
+ /*.nb11 =*/ nb11,
1814
+ /*.nb12 =*/ nb12,
1815
+ /*.nb13 =*/ nb13,
1816
+ /*.ne20 =*/ ne20, // n_expert_used
1817
+ /*.ne21 =*/ ne21, // n_tokens
1818
+ /*.ne0 =*/ ne0,
1819
+ /*.ne1 =*/ ne1,
1820
+ /*.r2 =*/ r2,
1821
+ /*.r3 =*/ r3,
1822
+ };
1823
+
1824
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1825
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1826
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1827
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
1828
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tpe, 3);
1829
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
1830
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
1831
+
1832
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1833
+
1834
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1835
+
1836
// grid: ceil(ne21/32) x ceil(ne01/64) x ne02, with 128 as the first threadgroup size value
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
1837
+ }
1838
+ } else {
// matrix-vector fallback: a single kernel that reads the ids directly from src[2]
1839
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1840
+
1841
// launch parameters are queried from the selected pipeline
+ const int nr0 = wsp_ggml_metal_pipeline_get_nr0(pipeline);
1842
+ const int nr1 = wsp_ggml_metal_pipeline_get_nr1(pipeline);
1843
+ const int nsg = wsp_ggml_metal_pipeline_get_nsg(pipeline);
1844
+
1845
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1846
+
1847
+ wsp_ggml_metal_kargs_mul_mv_id args = {
1848
+ /*.nei0 =*/ ne20,
1849
+ /*.nei1 =*/ ne21,
1850
+ /*.nbi1 =*/ nb21,
1851
+ /*.ne00 =*/ ne00,
1852
+ /*.ne01 =*/ ne01,
1853
+ /*.ne02 =*/ ne02,
1854
+ /*.nb00 =*/ nb00,
1855
+ /*.nb01 =*/ nb01,
1856
+ /*.nb02 =*/ nb02,
1857
+ /*.ne10 =*/ ne10,
1858
+ /*.ne11 =*/ ne11,
1859
+ /*.ne12 =*/ ne12,
1860
+ /*.ne13 =*/ ne13,
1861
+ /*.nb10 =*/ nb10,
1862
+ /*.nb11 =*/ nb11,
1863
+ /*.nb12 =*/ nb12,
1864
+ /*.ne0 =*/ ne0,
1865
+ /*.ne1 =*/ ne1,
1866
+ /*.nb1 =*/ nb1,
1867
+ /*.nr0 =*/ nr0,
1868
+ };
1869
+
1870
// quantized src[0] requires the row length to be at least nsg*nr0
+ if (wsp_ggml_is_quantized(op->src[0]->type)) {
1871
+ WSP_GGML_ASSERT(ne00 >= nsg*nr0);
1872
+ }
1873
+
1874
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1875
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1876
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
1877
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
1878
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 3);
1879
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_src2, 4);
1880
+
1881
// the second grid dimension is derived from _ne1 = 1; the third covers every
// (used expert, row) pair: ne20*ne21
+ const int64_t _ne1 = 1;
1882
+ const int64_t ne123 = ne20*ne21;
1883
+
1884
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1885
+
1886
// for these src[0] types the first grid dimension is ceil(ne01/nr0);
// for all other types it is ceil(ne01/(nr0*nsg))
+ if (op->src[0]->type == WSP_GGML_TYPE_F32 ||
1887
+ op->src[0]->type == WSP_GGML_TYPE_F16 ||
1888
+ op->src[0]->type == WSP_GGML_TYPE_BF16 ||
1889
+ op->src[0]->type == WSP_GGML_TYPE_Q8_0) {
1890
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
1891
+ } else {
1892
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
1893
+ }
1894
+ }
1895
+
1896
+ return 1;
1897
+ }
1899
+
1900
// Encodes the Metal work for graph node `idx` (the ADD_ID handler).
//
// Requirements checked below: src[0], src[1] and the destination are F32, src[2] is I32,
// and src[0] has contiguous rows. The kernel receives src[0], src[1], src[2] and the
// destination at buffer indices 1..4 and is dispatched over an ne01 x ne02 grid.
//
// Always returns 1.
// NOTE(review): presumably the number of graph nodes consumed - confirm with the caller.
+ int wsp_ggml_metal_op_add_id(wsp_ggml_metal_op_t ctx, int idx) {
1900
+ wsp_ggml_tensor * op = ctx->node(idx);
1901
+
1902
+ wsp_ggml_metal_library_t lib = ctx->lib;
1903
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
1904
+
1905
// Bring the shape (ne*) and stride (nb*) locals into scope: ne0x/nb0x for src[0],
// ne1x/nb1x for src[1], ne2x/nb2x for src[2], and nex for the destination shape.
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1906
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1907
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1908
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1909
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1910
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1911
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1912
+
1913
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
1914
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
1915
+ WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
1916
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
1917
+
1918
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
1919
+
1920
// kernel arguments: destination dims ne0/ne1 plus the row strides of the three sources
+ wsp_ggml_metal_kargs_add_id args = {
1921
+ /*.ne0 =*/ ne0,
1922
+ /*.ne1 =*/ ne1,
1923
+ /*.nb01 =*/ nb01,
1924
+ /*.nb02 =*/ nb02,
1925
+ /*.nb11 =*/ nb11,
1926
+ /*.nb21 =*/ nb21,
1927
+ };
1928
+
1929
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_ADD_ID);
1930
+
1931
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1932
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1933
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1934
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1935
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
1936
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 4);
1937
+
1938
// thread count: the row length ne00, capped at the pipeline's maximum
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
1939
+
1940
// grid: ne01 x ne02 x 1, with nth as the first threadgroup size value
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
1941
+
1942
+ return 1;
1943
+ }
1945
+
1946
// Decides which FLASH_ATTN_EXT kernel variant is used for `op`.
// Returns true (vector kernel) when src[0]->ne[1] is below 20 and src[0]->ne[0] is a
// multiple of 32; returns false otherwise. The scratch-size helpers and the encoder for
// this op all branch on this result, so they stay consistent with each other.
+ bool wsp_ggml_metal_op_flash_attn_ext_use_vec(const wsp_ggml_tensor * op) {
1946
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1947
+
1948
+ const int64_t ne00 = op->src[0]->ne[0]; // head size
1949
+ const int64_t ne01 = op->src[0]->ne[1]; // batch size
1950
+
1951
+ // use vec kernel if the batch size is small and if the head size is supported
1952
+ return (ne01 < 20) && (ne00 % 32 == 0);
1953
+ }
1955
+
1956
// Returns the size in bytes of the extra "pad" scratch region for a FLASH_ATTN_EXT op.
//
// The result is 0 when ne11 is a multiple of the NCPSG constant of the selected kernel
// variant (vector or non-vector). Otherwise it is NCPSG times the sum of:
//   - nb11*ne12*ne13 (src[1] row stride times its two outer dims),
//   - nb21*ne22*ne23 (src[2] row stride times its two outer dims),
//   - sizeof(F16)*ne31*ne32*ne33 when a mask (src[3]) is present.
// Both branches use the same formula and differ only in the NCPSG constant.
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
1956
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1957
+
1958
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1959
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1960
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1961
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1962
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1963
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
// NOTE(review): src[3] may be null (see has_mask below) - assumes the macro tolerates
// a null tensor pointer; confirm against its definition.
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1964
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1965
+
1966
+ size_t res = 0;
1967
+
1968
+ const bool has_mask = op->src[3] != nullptr;
1969
+
1970
+ if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
// vector kernel: padding is needed only when ne11 is not a multiple of the vec NCPSG
1971
+ const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1972
+
1973
+ if (has_kvpad) {
1974
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
1975
+ nb11*ne12*ne13 +
1976
+ nb21*ne22*ne23 +
1977
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
1978
+ }
1979
+ } else {
// non-vector kernel: same formula with the non-vec NCPSG constant
1980
+ const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1981
+
1982
+ if (has_kvpad) {
1983
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
1984
+ nb11*ne12*ne13 +
1985
+ nb21*ne22*ne23 +
1986
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
1987
+ }
1988
+ }
1989
+
1990
+ return res;
1991
+ }
1994
+
1995
// Returns the size in bytes of the extra "blk" scratch region for a FLASH_ATTN_EXT op.
//
// The result is 0 when there is no mask (src[3] is null) or when the vector kernel is
// selected. Otherwise it is one I8 element per block, rounded up to a multiple of 32:
//   ceil(ne30/NCPSG) * ceil(ne01/NQPTG) * ne32 * ne33
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
1996
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1997
+
1998
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1999
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2000
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2001
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2002
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2003
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2004
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2005
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2006
+
2007
+ size_t res = 0;
2008
+
2009
+ const bool has_mask = op->src[3] != nullptr;
2010
+
2011
// without a mask there is nothing to pre-process
+ if (!has_mask) {
2012
+ return res;
2013
+ }
2014
+
2015
+ const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
2016
+
2017
+ // this optimization is not useful for the vector kernels
2018
+ if (is_vec) {
2019
+ return res;
2020
+ }
2021
+
2022
// is_vec is always false past this point, so both ternaries select the non-vector constants
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2023
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2024
+
2025
// block counts: ne1 along src[0] dim 1, ne0 along mask dim 0 (both rounded up)
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2026
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2027
+
2028
+ res += WSP_GGML_PAD(wsp_ggml_type_size(WSP_GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2029
+
2030
+ return res;
2031
+ }
2032
+
2033
// Returns the size in bytes of the extra "tmp" scratch region for a FLASH_ATTN_EXT op.
//
// The result is 0 for the non-vector kernel. For the vector kernel it is one F32 element
// per entry of a buffer of ne01*ne02*ne03*nwg rows, each holding ne20 + 2 values.
// nwg is hard-coded to 32 here, the same value the vector branch of
// wsp_ggml_metal_op_flash_attn_ext assigns to its own nwg; the two must be kept in sync.
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
2034
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
2035
+
2036
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2037
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2038
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2039
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2040
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2041
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2042
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2043
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2044
+
2045
+ size_t res = 0;
2046
+
2047
// only the vector kernel needs the temporary buffer
+ if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2048
+ const int64_t nwg = 32;
2049
+
2050
+ // temp buffer for writing the results from each workgroup
2051
+ // - ne20: the size of the Value head
2052
+ // - + 2: the S and M values for each intermediate result
2053
+ res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2054
+ }
2055
+
2056
+ return res;
2057
+ }
2058
+
2059
+ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2060
+ wsp_ggml_tensor * op = ctx->node(idx);
2061
+
2062
+ wsp_ggml_metal_library_t lib = ctx->lib;
2063
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2064
+
2065
+ const wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(ctx->dev);
2066
+
2067
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2068
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2069
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2070
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2071
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2072
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2073
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2074
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2075
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2076
+ WSP_GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
2077
+
2078
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
2079
+
2080
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
2081
+ WSP_GGML_ASSERT(op->src[1]->type == op->src[2]->type);
2082
+
2083
+ //WSP_GGML_ASSERT(wsp_ggml_are_same_shape (src1, src2));
2084
+ WSP_GGML_ASSERT(ne11 == ne21);
2085
+ WSP_GGML_ASSERT(ne12 == ne22);
2086
+
2087
+ WSP_GGML_ASSERT(!op->src[3] || op->src[3]->type == WSP_GGML_TYPE_F16);
2088
+ WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2089
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
2090
+
2091
+ float scale;
2092
+ float max_bias;
2093
+ float logit_softcap;
2094
+
2095
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
2096
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
2097
+ memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
2098
+
2099
+ if (logit_softcap != 0.0f) {
2100
+ scale /= logit_softcap;
2101
+ }
2102
+
2103
+ const bool has_mask = op->src[3] != NULL;
2104
+ const bool has_sinks = op->src[4] != NULL;
2105
+ const bool has_bias = max_bias != 0.0f;
2106
+ const bool has_scap = logit_softcap != 0.0f;
2107
+
2108
+ const uint32_t n_head = op->src[0]->ne[2];
2109
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2110
+
2111
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2112
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2113
+
2114
+ WSP_GGML_ASSERT(ne01 < 65536);
2115
+
2116
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
2117
+ wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
2118
+ wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
2119
+ wsp_ggml_metal_buffer_id bid_src3 = has_mask ? wsp_ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2120
+ wsp_ggml_metal_buffer_id bid_src4 = has_sinks ? wsp_ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2121
+
2122
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2123
+
2124
+ wsp_ggml_metal_buffer_id bid_pad = bid_dst;
2125
+ bid_pad.offs += wsp_ggml_nbytes(op);
2126
+
2127
+ wsp_ggml_metal_buffer_id bid_blk = bid_pad;
2128
+ bid_blk.offs += wsp_ggml_metal_op_flash_attn_ext_extra_pad(op);
2129
+
2130
+ wsp_ggml_metal_buffer_id bid_tmp = bid_blk;
2131
+ bid_tmp.offs += wsp_ggml_metal_op_flash_attn_ext_extra_blk(op);
2132
+
2133
+ if (!wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2134
+ // half8x8 kernel
2135
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2136
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2137
+
2138
+ WSP_GGML_ASSERT(nqptg <= 32);
2139
+ WSP_GGML_ASSERT(nqptg % 8 == 0);
2140
+ WSP_GGML_ASSERT(ncpsg % 32 == 0);
2141
+
2142
+ bool need_sync = false;
2143
+
2144
+ const bool has_kvpad = ne11 % ncpsg != 0;
2145
+
2146
+ if (has_kvpad) {
2147
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2148
+
2149
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2150
+ /*.ne11 =*/ne11,
2151
+ /*.ne_12_2 =*/ne12,
2152
+ /*.ne_12_3 =*/ne13,
2153
+ /*.nb11 =*/nb11,
2154
+ /*.nb12 =*/nb12,
2155
+ /*.nb13 =*/nb13,
2156
+ /*.nb21 =*/nb21,
2157
+ /*.nb22 =*/nb22,
2158
+ /*.nb23 =*/nb23,
2159
+ /*.ne31 =*/ne31,
2160
+ /*.ne32 =*/ne32,
2161
+ /*.ne33 =*/ne33,
2162
+ /*.nb31 =*/nb31,
2163
+ /*.nb32 =*/nb32,
2164
+ /*.nb33 =*/nb33,
2165
+ };
2166
+
2167
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2168
+
2169
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2170
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2171
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2172
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2173
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2174
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2175
+
2176
+ assert(ne12 == ne22);
2177
+ assert(ne13 == ne23);
2178
+
2179
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2180
+
2181
+ need_sync = true;
2182
+ } else {
2183
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2184
+ }
2185
+
2186
+ if (has_mask) {
2187
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2188
+
2189
+ wsp_ggml_metal_kargs_flash_attn_ext_blk args0 = {
2190
+ /*.ne01 =*/ ne01,
2191
+ /*.ne30 =*/ ne30,
2192
+ /*.ne31 =*/ ne31,
2193
+ /*.ne32 =*/ ne32,
2194
+ /*.ne33 =*/ ne33,
2195
+ /*.nb31 =*/ nb31,
2196
+ /*.nb32 =*/ nb32,
2197
+ /*.nb33 =*/ nb33,
2198
+ };
2199
+
2200
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2201
+
2202
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2203
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2204
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2205
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2206
+
2207
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2208
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2209
+
2210
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2211
+
2212
+ need_sync = true;
2213
+ } else {
2214
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
2215
+ }
2216
+
2217
+ if (need_sync) {
2218
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2219
+ }
2220
+
2221
+ const int is_q = wsp_ggml_is_quantized(op->src[1]->type) ? 1 : 0;
2222
+
2223
+ // 2*(2*ncpsg)
2224
+ // ncpsg soft_max values + ncpsg mask values
2225
+ //
2226
+ // 16*32*(nsg)
2227
+ // the shared memory needed for the simdgroups to load the KV cache
2228
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
2229
+ //
2230
+ #define FATTN_SMEM(nsg) (WSP_GGML_PAD((nqptg*(ne00 + 2*WSP_GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
2231
+
2232
+ //int64_t nsgmax = 4;
2233
+ //
2234
+ //if (is_q) {
2235
+ // nsgmax = 2;
2236
+ // while (true) {
2237
+ // const size_t smem = FATTN_SMEM(nsgmax);
2238
+ // if (smem > props_dev->max_theadgroup_memory_size) {
2239
+ // break;
2240
+ // }
2241
+ // nsgmax *= 2;
2242
+ // }
2243
+ // nsgmax /= 2;
2244
+ //}
2245
+
2246
+ // simdgroups per threadgroup (a.k.a. warps)
2247
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2248
+ int32_t nsg = 4;
2249
+
2250
+ const size_t smem = FATTN_SMEM(nsg);
2251
+
2252
+ wsp_ggml_metal_kargs_flash_attn_ext args = {
2253
+ /*.ne01 =*/ ne01,
2254
+ /*.ne02 =*/ ne02,
2255
+ /*.ne03 =*/ ne03,
2256
+ /*.nb01 =*/ nb01,
2257
+ /*.nb02 =*/ nb02,
2258
+ /*.nb03 =*/ nb03,
2259
+ /*.ne11 =*/ ne11,
2260
+ /*.ne_12_2 =*/ ne12,
2261
+ /*.ne_12_3 =*/ ne13,
2262
+ /*.ns10 =*/ int32_t(nb11/nb10),
2263
+ /*.nb11 =*/ nb11,
2264
+ /*.nb12 =*/ nb12,
2265
+ /*.nb13 =*/ nb13,
2266
+ /*.ns20 =*/ int32_t(nb21/nb20),
2267
+ /*.nb21 =*/ nb21,
2268
+ /*.nb22 =*/ nb22,
2269
+ /*.nb23 =*/ nb23,
2270
+ /*.ne31 =*/ ne31,
2271
+ /*.ne32 =*/ ne32,
2272
+ /*.ne33 =*/ ne33,
2273
+ /*.nb31 =*/ nb31,
2274
+ /*.nb32 =*/ nb32,
2275
+ /*.nb33 =*/ nb33,
2276
+ /*.ne1 =*/ ne1,
2277
+ /*.ne2 =*/ ne2,
2278
+ /*.ne3 =*/ ne3,
2279
+ /*.scale =*/ scale,
2280
+ /*.max_bias =*/ max_bias,
2281
+ /*.m0 =*/ m0,
2282
+ /*.m1 =*/ m1,
2283
+ /*.n_head_log2 =*/ n_head_log2,
2284
+ /*.logit_softcap =*/ logit_softcap,
2285
+ };
2286
+
2287
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2288
+
2289
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2290
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2291
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2292
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2293
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2294
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2295
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2296
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2297
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2298
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2299
+
2300
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2301
+
2302
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
2303
+ #undef FATTN_SMEM
2304
+ } else {
2305
+ // half4x4 kernel
2306
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2307
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2308
+ const int nkpsg = 1*ncpsg;
2309
+
2310
+ WSP_GGML_ASSERT(nqptg <= 32);
2311
+ WSP_GGML_ASSERT(nqptg % 1 == 0);
2312
+ WSP_GGML_ASSERT(ncpsg % 32 == 0);
2313
+
2314
+ bool need_sync = false;
2315
+
2316
+ const bool has_kvpad = ne11 % ncpsg != 0;
2317
+
2318
+ if (has_kvpad) {
2319
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2320
+
2321
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2322
+ /*.ne11 =*/ne11,
2323
+ /*.ne_12_2 =*/ne12,
2324
+ /*.ne_12_3 =*/ne13,
2325
+ /*.nb11 =*/nb11,
2326
+ /*.nb12 =*/nb12,
2327
+ /*.nb13 =*/nb13,
2328
+ /*.nb21 =*/nb21,
2329
+ /*.nb22 =*/nb22,
2330
+ /*.nb23 =*/nb23,
2331
+ /*.ne31 =*/ne31,
2332
+ /*.ne32 =*/ne32,
2333
+ /*.ne33 =*/ne33,
2334
+ /*.nb31 =*/nb31,
2335
+ /*.nb32 =*/nb32,
2336
+ /*.nb33 =*/nb33,
2337
+ };
2338
+
2339
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2340
+
2341
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2342
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2343
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2344
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2345
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2346
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2347
+
2348
+ assert(ne12 == ne22);
2349
+ assert(ne13 == ne23);
2350
+
2351
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2352
+
2353
+ need_sync = true;
2354
+ } else {
2355
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2356
+ }
2357
+
2358
+ if (need_sync) {
2359
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2360
+ }
2361
+
2362
+ // ne00 + 2*ncpsg*(nsg)
2363
+ // for each query, we load it as f16 in shared memory (ne00)
2364
+ // and store the soft_max values and the mask
2365
+ //
2366
+ // ne20*(nsg)
2367
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
2368
+ //
2369
+ #define FATTN_SMEM(nsg) (WSP_GGML_PAD((nqptg*(WSP_GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*WSP_GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
2370
+
2371
+ int64_t nsgmax = 2;
2372
+ while (true) {
2373
+ const size_t smem = FATTN_SMEM(nsgmax);
2374
+ // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
2375
+ if (smem > props_dev->max_theadgroup_memory_size/2) {
2376
+ break;
2377
+ }
2378
+ nsgmax *= 2;
2379
+ }
2380
+ nsgmax /= 2;
2381
+
2382
+ // simdgroups per threadgroup (a.k.a. warps)
2383
+ //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
2384
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
2385
+
2386
+ int64_t nsg = 1;
2387
+ while (nsg <= nsgt) {
2388
+ nsg *= 2;
2389
+ }
2390
+ nsg /= 2;
2391
+
2392
+ // workgroups
2393
+ // each workgroup handles nsg*nkpsg cache values
2394
+ int32_t nwg = 1;
2395
+ if (false) {
2396
+ // for small KV caches, we could launch a single workgroup and write the results directly to dst/
2397
+ // however, this does not lead to significant improvement, so disabled
2398
+ nwg = 1;
2399
+ nsg = 4;
2400
+ } else {
2401
+ nwg = 32;
2402
+ nsg = 1;
2403
+ while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
2404
+ nsg *= 2;
2405
+ }
2406
+ }
2407
+
2408
+ wsp_ggml_metal_kargs_flash_attn_ext_vec args = {
2409
+ /*.ne01 =*/ ne01,
2410
+ /*.ne02 =*/ ne02,
2411
+ /*.ne03 =*/ ne03,
2412
+ /*.nb01 =*/ nb01,
2413
+ /*.nb02 =*/ nb02,
2414
+ /*.nb03 =*/ nb03,
2415
+ /*.ne11 =*/ ne11,
2416
+ /*.ne_12_2 =*/ ne12,
2417
+ /*.ne_12_3 =*/ ne13,
2418
+ /*.ns10 =*/ int32_t(nb11/nb10),
2419
+ /*.nb11 =*/ nb11,
2420
+ /*.nb12 =*/ nb12,
2421
+ /*.nb13 =*/ nb13,
2422
+ /*.ns20 =*/ int32_t(nb21/nb20),
2423
+ /*.nb21 =*/ nb21,
2424
+ /*.nb22 =*/ nb22,
2425
+ /*.nb23 =*/ nb23,
2426
+ /*.ne31 =*/ ne31,
2427
+ /*.ne32 =*/ ne32,
2428
+ /*.ne33 =*/ ne33,
2429
+ /*.nb31 =*/ nb31,
2430
+ /*.nb32 =*/ nb32,
2431
+ /*.nb33 =*/ nb33,
2432
+ /*.ne1 =*/ ne1,
2433
+ /*.ne2 =*/ ne2,
2434
+ /*.ne3 =*/ ne3,
2435
+ /*.scale =*/ scale,
2436
+ /*.max_bias =*/ max_bias,
2437
+ /*.m0 =*/ m0,
2438
+ /*.m1 =*/ m1,
2439
+ /*.n_head_log2 =*/ n_head_log2,
2440
+ /*.logit_softcap =*/ logit_softcap,
2441
+ };
2442
+
2443
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2444
+
2445
+ WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2446
+
2447
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2448
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2449
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2450
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2451
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2452
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2453
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2454
+
2455
+ const size_t smem = FATTN_SMEM(nsg);
2456
+
2457
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2458
+ WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2459
+
2460
+ if (nwg == 1) {
2461
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2462
+
2463
+ // using 1 workgroup -> write the result directly into dst
2464
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2465
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2466
+
2467
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2468
+
2469
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2470
+ } else {
2471
+ // sanity checks
2472
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2473
+
2474
+ WSP_GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2475
+ WSP_GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2476
+
2477
+ // write the results from each workgroup into a temp buffer
2478
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2479
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2480
+
2481
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2482
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2483
+
2484
+ // sync the 2 kernels
2485
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2486
+
2487
+ // reduce the results from the workgroups
2488
+ {
2489
+ const int32_t nrows = ne1*ne2*ne3;
2490
+
2491
+ wsp_ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
2492
+ nrows,
2493
+ };
2494
+
2495
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2496
+
2497
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2498
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2499
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
2500
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2501
+
2502
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
2503
+ }
2504
+ }
2505
+ #undef FATTN_SMEM
2506
+ }
2507
+
2508
+ return 1;
2509
+ }
2510
+
2511
+ int wsp_ggml_metal_op_bin(wsp_ggml_metal_op_t ctx, int idx) {
2512
+ wsp_ggml_tensor * op = ctx->node(idx);
2513
+
2514
+ wsp_ggml_metal_library_t lib = ctx->lib;
2515
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2516
+
2517
+ const bool use_fusion = ctx->use_fusion;
2518
+
2519
+ const int debug_fusion = ctx->debug_fusion;
2520
+
2521
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2522
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2523
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2524
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2525
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2526
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2527
+
2528
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
2529
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
2530
+
2531
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
2532
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[1]));
2533
+
2534
+ bool bcast_row = false;
2535
+
2536
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
2537
+ wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
2538
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2539
+
2540
+ wsp_ggml_metal_kargs_bin args = {
2541
+ /*.ne00 =*/ ne00,
2542
+ /*.ne01 =*/ ne01,
2543
+ /*.ne02 =*/ ne02,
2544
+ /*.ne03 =*/ ne03,
2545
+ /*.nb00 =*/ nb00,
2546
+ /*.nb01 =*/ nb01,
2547
+ /*.nb02 =*/ nb02,
2548
+ /*.nb03 =*/ nb03,
2549
+ /*.ne10 =*/ ne10,
2550
+ /*.ne11 =*/ ne11,
2551
+ /*.ne12 =*/ ne12,
2552
+ /*.ne13 =*/ ne13,
2553
+ /*.nb10 =*/ nb10,
2554
+ /*.nb11 =*/ nb11,
2555
+ /*.nb12 =*/ nb12,
2556
+ /*.nb13 =*/ nb13,
2557
+ /*.ne0 =*/ ne0,
2558
+ /*.ne1 =*/ ne1,
2559
+ /*.ne2 =*/ ne2,
2560
+ /*.ne3 =*/ ne3,
2561
+ /*.nb0 =*/ nb0,
2562
+ /*.nb1 =*/ nb1,
2563
+ /*.nb2 =*/ nb2,
2564
+ /*.nb3 =*/ nb3,
2565
+ /*.offs =*/ 0,
2566
+ /*.o1 =*/ { bid_src1.offs },
2567
+ };
2568
+
2569
+ wsp_ggml_op fops[8];
2570
+
2571
+ int n_fuse = 1;
2572
+
2573
+ // c[0] = add(a, b[0])
2574
+ // c[1] = add(c[0], b[1])
2575
+ // c[2] = add(c[1], b[2])
2576
+ // ...
2577
+ if (use_fusion) {
2578
+ fops[0] = WSP_GGML_OP_ADD;
2579
+ fops[1] = WSP_GGML_OP_ADD;
2580
+ fops[2] = WSP_GGML_OP_ADD;
2581
+ fops[3] = WSP_GGML_OP_ADD;
2582
+ fops[4] = WSP_GGML_OP_ADD;
2583
+ fops[5] = WSP_GGML_OP_ADD;
2584
+ fops[6] = WSP_GGML_OP_ADD;
2585
+ fops[7] = WSP_GGML_OP_ADD;
2586
+
2587
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
2588
+ // across splits. idx_end indicates the last node in the current split
2589
+ for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
2590
+ if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
2591
+ break;
2592
+ }
2593
+
2594
+ wsp_ggml_tensor * f0 = ctx->node(idx + n_fuse);
2595
+ wsp_ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
2596
+
2597
+ if (f0 != f1->src[0]) {
2598
+ break;
2599
+ }
2600
+
2601
+ // b[0] === b[1] === ...
2602
+ if (!wsp_ggml_are_same_layout(f0->src[1], f1->src[1])) {
2603
+ break;
2604
+ }
2605
+
2606
+ // only fuse ops if src1 is in the same Metal buffer
2607
+ wsp_ggml_metal_buffer_id bid_fuse = wsp_ggml_metal_get_buffer_id(f1->src[1]);
2608
+ if (bid_fuse.metal != bid_src1.metal) {
2609
+ break;
2610
+ }
2611
+
2612
+ //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
2613
+
2614
+ args.o1[n_fuse + 1] = bid_fuse.offs;
2615
+ }
2616
+
2617
+ ++n_fuse;
2618
+
2619
+ if (debug_fusion > 1 && n_fuse > 1) {
2620
+ WSP_GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2621
+ }
2622
+ }
2623
+
2624
+ // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2625
+ bid_src1.offs = 0;
2626
+
2627
+ wsp_ggml_metal_pipeline_t pipeline = nullptr;
2628
+
2629
+ if (wsp_ggml_nelements(op->src[1]) == ne10 && wsp_ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2630
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
2631
+
2632
+ // src1 is a row
2633
+ WSP_GGML_ASSERT(ne11 == 1);
2634
+
2635
+ pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
2636
+
2637
+ bcast_row = true;
2638
+ } else {
2639
+ pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
2640
+ }
2641
+
2642
+ if (n_fuse > 1) {
2643
+ bid_dst = wsp_ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
2644
+
2645
+ for (int i = 1; i < n_fuse; ++i) {
2646
+ if (!wsp_ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
2647
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2648
+
2649
+ break;
2650
+ }
2651
+ }
2652
+ }
2653
+
2654
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2655
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2656
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2657
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2658
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2659
+
2660
+ if (bcast_row) {
2661
+ const int64_t n = wsp_ggml_nelements(op)/4;
2662
+
2663
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
2664
+ } else {
2665
+ int nth = 32;
2666
+
2667
+ while (16*nth < ne0 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2668
+ nth *= 2;
2669
+ }
2670
+
2671
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2672
+ }
2673
+
2674
+ return n_fuse;
2675
+ }
2676
+
2677
+ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
2678
+ wsp_ggml_tensor * op = ctx->node(idx);
2679
+
2680
+ wsp_ggml_metal_library_t lib = ctx->lib;
2681
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2682
+
2683
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2684
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2685
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2686
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2687
+
2688
+ float eps;
2689
+ memcpy(&eps, op->op_params, sizeof(float));
2690
+
2691
+ int nth = 32; // SIMD width
2692
+
2693
+ wsp_ggml_metal_kargs_l2_norm args = {
2694
+ /*.ne00 =*/ ne00,
2695
+ /*.ne00_4 =*/ ne00/4,
2696
+ /*.nb01 =*/ nb01,
2697
+ /*.eps =*/ eps,
2698
+ };
2699
+
2700
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_l2_norm(lib, op);
2701
+
2702
+ while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2703
+ nth *= 2;
2704
+ }
2705
+
2706
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2707
+ nth = std::min(nth, ne00/4);
2708
+
2709
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2710
+
2711
+ const int64_t nrows = wsp_ggml_nrows(op->src[0]);
2712
+
2713
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2714
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2715
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2716
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
2717
+
2718
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2719
+
2720
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
2721
+
2722
+ return 1;
2723
+ }
2724
+
2725
+ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
2726
+ wsp_ggml_tensor * op = ctx->node(idx);
2727
+
2728
+ wsp_ggml_metal_library_t lib = ctx->lib;
2729
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2730
+
2731
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2732
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2733
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2734
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2735
+
2736
+ const int32_t ngrp = ((const int32_t *) op->op_params)[0];
2737
+
2738
+ float eps;
2739
+ memcpy(&eps, op->op_params + 1, sizeof(float));
2740
+
2741
+ wsp_ggml_metal_kargs_group_norm args = {
2742
+ /*.ne00 =*/ ne00,
2743
+ /*.ne01 =*/ ne01,
2744
+ /*.ne02 =*/ ne02,
2745
+ /*.nb00 =*/ nb00,
2746
+ /*.nb01 =*/ nb01,
2747
+ /*.nb02 =*/ nb02,
2748
+ /*.ngrp =*/ ngrp,
2749
+ /*.eps =*/ eps,
2750
+ };
2751
+
2752
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_group_norm(lib, op);
2753
+
2754
+ int nth = 32; // SIMD width
2755
+ //while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2756
+ // nth *= 2;
2757
+ //}
2758
+
2759
+ //nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2760
+ //nth = std::min(nth, ne00/4);
2761
+
2762
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2763
+
2764
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2765
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2766
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2767
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
2768
+
2769
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2770
+
2771
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
2772
+
2773
+ return 1;
2774
+ }
2775
+
2776
+ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
2777
+ wsp_ggml_tensor * op = ctx->node(idx);
2778
+
2779
+ wsp_ggml_metal_library_t lib = ctx->lib;
2780
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2781
+
2782
+ const bool use_fusion = ctx->use_fusion;
2783
+
2784
+ const int debug_fusion = ctx->debug_fusion;
2785
+
2786
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2787
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2788
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2789
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2790
+
2791
+ float eps;
2792
+ memcpy(&eps, op->op_params, sizeof(float));
2793
+
2794
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
2795
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2796
+
2797
+ wsp_ggml_metal_kargs_norm args = {
2798
+ /*.ne00 =*/ ne00,
2799
+ /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
2800
+ /*.nb1 =*/ nb1,
2801
+ /*.nb2 =*/ nb2,
2802
+ /*.nb3 =*/ nb3,
2803
+ /*.eps =*/ eps,
2804
+ /*.nef1 =*/ { ne01 },
2805
+ /*.nef2 =*/ { ne02 },
2806
+ /*.nef3 =*/ { ne03 },
2807
+ /*.nbf1 =*/ { nb01 },
2808
+ /*.nbf2 =*/ { nb02 },
2809
+ /*.nbf3 =*/ { nb03 },
2810
+ };
2811
+
2812
+ wsp_ggml_op fops[8];
2813
+
2814
+ int n_fuse = 1;
2815
+
2816
+ wsp_ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
2817
+
2818
+ // d[0] = norm(a)
2819
+ // d[1] = mul(d[0], b)
2820
+ // d[2] = add(d[1], c)
2821
+ if (use_fusion) {
2822
+ fops[0] = op->op;
2823
+ fops[1] = WSP_GGML_OP_MUL;
2824
+ fops[2] = WSP_GGML_OP_ADD;
2825
+
2826
+ for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
2827
+ if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
2828
+ break;
2829
+ }
2830
+
2831
+ wsp_ggml_tensor * f0 = ctx->node(idx + n_fuse);
2832
+ wsp_ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
2833
+
2834
+ if (f0 != f1->src[0]) {
2835
+ break;
2836
+ }
2837
+
2838
+ if (f1->src[1]->ne[0] != op->ne[0]) {
2839
+ break;
2840
+ }
2841
+
2842
+ if (!wsp_ggml_is_contiguous_rows(f1->src[1])) {
2843
+ break;
2844
+ }
2845
+
2846
+ if (f1->type != WSP_GGML_TYPE_F32) {
2847
+ break;
2848
+ }
2849
+
2850
+ //ctx->fuse_cnt[f1->op]++;
2851
+
2852
+ bid_fuse[n_fuse] = wsp_ggml_metal_get_buffer_id(f1->src[1]);
2853
+
2854
+ args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
2855
+ args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
2856
+ args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
2857
+
2858
+ args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
2859
+ args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
2860
+ args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
2861
+ }
2862
+
2863
+ ++n_fuse;
2864
+
2865
+ if (debug_fusion > 1 && n_fuse > 1) {
2866
+ if (n_fuse == 2) {
2867
+ WSP_GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, wsp_ggml_op_name(op->op));
2868
+ }
2869
+ if (n_fuse == 3) {
2870
+ WSP_GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, wsp_ggml_op_name(op->op));
2871
+ }
2872
+ }
2873
+ }
2874
+
2875
+ if (n_fuse > 1) {
2876
+ bid_dst = wsp_ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
2877
+
2878
+ for (int i = 1; i < n_fuse; ++i) {
2879
+ if (!wsp_ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
2880
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2881
+
2882
+ break;
2883
+ }
2884
+ }
2885
+ }
2886
+
2887
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
2888
+
2889
+ int nth = 32; // SIMD width
2890
+
2891
+ while (nth < args.ne00_t && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2892
+ nth *= 2;
2893
+ }
2894
+
2895
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2896
+ nth = std::min(nth, args.ne00_t);
2897
+
2898
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2899
+
2900
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2901
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2902
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2903
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
2904
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
2905
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
2906
+
2907
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2908
+
2909
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2910
+
2911
+ return n_fuse;
2912
+ }
2913
+
2914
+ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
2915
+ wsp_ggml_tensor * op = ctx->node(idx);
2916
+
2917
+ wsp_ggml_metal_library_t lib = ctx->lib;
2918
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
2919
+
2920
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2921
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2922
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2923
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2924
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2925
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2926
+
2927
+ // make sure we have one or more position id(ne10) per token(ne02)
2928
+ WSP_GGML_ASSERT(ne10 % ne02 == 0);
2929
+ WSP_GGML_ASSERT(ne10 >= ne02);
2930
+
2931
+ const int nth = std::min(1024, ne00);
2932
+
2933
+ const int n_past = ((const int32_t *) op->op_params)[0];
2934
+ const int n_dims = ((const int32_t *) op->op_params)[1];
2935
+ //const int mode = ((const int32_t *) op->op_params)[2];
2936
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2937
+ const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
2938
+
2939
+ float freq_base;
2940
+ float freq_scale;
2941
+ float ext_factor;
2942
+ float attn_factor;
2943
+ float beta_fast;
2944
+ float beta_slow;
2945
+
2946
+ memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float));
2947
+ memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float));
2948
+ memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float));
2949
+ memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float));
2950
+ memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float));
2951
+ memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float));
2952
+
2953
+ // mrope
2954
+ const int sect_0 = ((const int32_t *) op->op_params)[11];
2955
+ const int sect_1 = ((const int32_t *) op->op_params)[12];
2956
+ const int sect_2 = ((const int32_t *) op->op_params)[13];
2957
+ const int sect_3 = ((const int32_t *) op->op_params)[14];
2958
+
2959
+ wsp_ggml_metal_kargs_rope args = {
2960
+ /*.ne00 =*/ ne00,
2961
+ /*.ne01 =*/ ne01,
2962
+ /*.ne02 =*/ ne02,
2963
+ /*.ne03 =*/ ne03,
2964
+ /*.nb00 =*/ nb00,
2965
+ /*.nb01 =*/ nb01,
2966
+ /*.nb02 =*/ nb02,
2967
+ /*.nb03 =*/ nb03,
2968
+ /*.ne0 =*/ ne0,
2969
+ /*.ne1 =*/ ne1,
2970
+ /*.ne2 =*/ ne2,
2971
+ /*.ne3 =*/ ne3,
2972
+ /*.nb0 =*/ nb0,
2973
+ /*.nb1 =*/ nb1,
2974
+ /*.nb2 =*/ nb2,
2975
+ /*.nb3 =*/ nb3,
2976
+ /*.n_past =*/ n_past,
2977
+ /*.n_dims =*/ n_dims,
2978
+ /*.n_ctx_orig =*/ n_ctx_orig,
2979
+ /*.freq_base =*/ freq_base,
2980
+ /*.freq_scale =*/ freq_scale,
2981
+ /*.ext_factor =*/ ext_factor,
2982
+ /*.attn_factor =*/ attn_factor,
2983
+ /*.beta_fast =*/ beta_fast,
2984
+ /*.beta_slow =*/ beta_slow,
2985
+ /* sect_0 =*/ sect_0,
2986
+ /* sect_1 =*/ sect_1,
2987
+ /* sect_2 =*/ sect_2,
2988
+ /* sect_3 =*/ sect_3,
2989
+ /* src2 =*/ op->src[2] != nullptr,
2990
+ };
2991
+
2992
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
2993
+
2994
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2995
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2996
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2997
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
2998
+ if (op->src[2]) {
2999
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
3000
+ } else {
3001
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 3);
3002
+ }
3003
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 4);
3004
+
3005
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3006
+
3007
+ return 1;
3008
+ }
3009
+
3010
+ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
3011
+ wsp_ggml_tensor * op = ctx->node(idx);
3012
+
3013
+ wsp_ggml_metal_library_t lib = ctx->lib;
3014
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3015
+
3016
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3017
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3018
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3019
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3020
+
3021
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3022
+ const int32_t s1 = ((const int32_t *)(op->op_params))[1];
3023
+ const int32_t p0 = ((const int32_t *)(op->op_params))[2];
3024
+ const int32_t p1 = ((const int32_t *)(op->op_params))[3];
3025
+ const int32_t d0 = ((const int32_t *)(op->op_params))[4];
3026
+ const int32_t d1 = ((const int32_t *)(op->op_params))[5];
3027
+
3028
+ const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
3029
+
3030
+ const int32_t N = op->src[1]->ne[is_2D ? 3 : 2];
3031
+ const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
3032
+ const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
3033
+ const int32_t IW = op->src[1]->ne[0];
3034
+
3035
+ const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
3036
+ const int32_t KW = op->src[0]->ne[0];
3037
+
3038
+ const int32_t OH = is_2D ? op->ne[2] : 1;
3039
+ const int32_t OW = op->ne[1];
3040
+
3041
+ const int32_t CHW = IC * KH * KW;
3042
+
3043
+ const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
3044
+ const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
3045
+
3046
+ wsp_ggml_metal_kargs_im2col args = {
3047
+ /*.ofs0 =*/ ofs0,
3048
+ /*.ofs1 =*/ ofs1,
3049
+ /*.IW =*/ IW,
3050
+ /*.IH =*/ IH,
3051
+ /*.CHW =*/ CHW,
3052
+ /*.s0 =*/ s0,
3053
+ /*.s1 =*/ s1,
3054
+ /*.p0 =*/ p0,
3055
+ /*.p1 =*/ p1,
3056
+ /*.d0 =*/ d0,
3057
+ /*.d1 =*/ d1,
3058
+ /*.N =*/ N,
3059
+ /*.KH =*/ KH,
3060
+ /*.KW =*/ KW,
3061
+ /*.KHW =*/ KH * KW,
3062
+ };
3063
+
3064
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_im2col(lib, op);
3065
+
3066
+ WSP_GGML_ASSERT(KH*KW <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3067
+
3068
+ const uint64_t ntptg0 = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3069
+
3070
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3071
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3072
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 1);
3073
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3074
+
3075
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3076
+
3077
+ return 1;
3078
+ }
3079
+
3080
+ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
3081
+ wsp_ggml_tensor * op = ctx->node(idx);
3082
+
3083
+ wsp_ggml_metal_library_t lib = ctx->lib;
3084
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3085
+
3086
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3087
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3088
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3089
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3090
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3091
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3092
+
3093
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3094
+
3095
+ const int32_t IC = op->src[1]->ne[1];
3096
+ const int32_t IL = op->src[1]->ne[0];
3097
+
3098
+ const int32_t K = op->src[0]->ne[0];
3099
+
3100
+ const int32_t OL = op->ne[0];
3101
+ const int32_t OC = op->ne[1];
3102
+
3103
+ wsp_ggml_metal_kargs_conv_transpose_1d args = {
3104
+ /*.IC =*/ IC,
3105
+ /*.IL =*/ IL,
3106
+ /*.K =*/ K,
3107
+ /*.s0 =*/ s0,
3108
+ /*.nb0 =*/ nb0,
3109
+ /*.nb1 =*/ nb1,
3110
+ };
3111
+
3112
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3113
+
3114
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3115
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3116
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3117
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3118
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3119
+
3120
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
3121
+
3122
+ return 1;
3123
+ }
3124
+
3125
+ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
3126
+ wsp_ggml_tensor * op = ctx->node(idx);
3127
+
3128
+ wsp_ggml_metal_library_t lib = ctx->lib;
3129
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3130
+
3131
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3132
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3133
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3134
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3135
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3136
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3137
+
3138
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3139
+
3140
+ const int32_t IC = op->src[1]->ne[2];
3141
+ const int32_t IH = op->src[1]->ne[1];
3142
+ const int32_t IW = op->src[1]->ne[0];
3143
+
3144
+ const int32_t KH = op->src[0]->ne[1];
3145
+ const int32_t KW = op->src[0]->ne[0];
3146
+
3147
+ const int32_t OW = op->ne[0];
3148
+ const int32_t OH = op->ne[1];
3149
+ const int32_t OC = op->ne[2];
3150
+
3151
+ wsp_ggml_metal_kargs_conv_transpose_2d args = {
3152
+ /*.IC =*/ IC,
3153
+ /*.IH =*/ IH,
3154
+ /*.IW =*/ IW,
3155
+ /*.KH =*/ KH,
3156
+ /*.KW =*/ KW,
3157
+ /*.OC =*/ OC,
3158
+ /*.s0 =*/ s0,
3159
+ /*.nb0 =*/ nb0,
3160
+ /*.nb1 =*/ nb1,
3161
+ /*.nb2 =*/ nb2,
3162
+ };
3163
+
3164
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3165
+
3166
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3167
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3168
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3169
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3170
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3171
+
3172
+ // Metal requires buffer size to be multiple of 16 bytes
3173
+ const size_t smem = WSP_GGML_PAD(KW * KH * sizeof(float), 16);
3174
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3175
+
3176
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3177
+
3178
+ return 1;
3179
+ }
3180
+
3181
+ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
3182
+ wsp_ggml_tensor * op = ctx->node(idx);
3183
+
3184
+ wsp_ggml_metal_library_t lib = ctx->lib;
3185
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3186
+
3187
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3188
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3189
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3190
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3191
+
3192
+ const float sf0 = (float)ne0/op->src[0]->ne[0];
3193
+ const float sf1 = (float)ne1/op->src[0]->ne[1];
3194
+ const float sf2 = (float)ne2/op->src[0]->ne[2];
3195
+ const float sf3 = (float)ne3/op->src[0]->ne[3];
3196
+
3197
+ wsp_ggml_metal_kargs_upscale args = {
3198
+ /*.ne00 =*/ ne00,
3199
+ /*.ne01 =*/ ne01,
3200
+ /*.ne02 =*/ ne02,
3201
+ /*.ne03 =*/ ne03,
3202
+ /*.nb00 =*/ nb00,
3203
+ /*.nb01 =*/ nb01,
3204
+ /*.nb02 =*/ nb02,
3205
+ /*.nb03 =*/ nb03,
3206
+ /*.ne0 =*/ ne0,
3207
+ /*.ne1 =*/ ne1,
3208
+ /*.ne2 =*/ ne2,
3209
+ /*.ne3 =*/ ne3,
3210
+ /*.nb0 =*/ nb0,
3211
+ /*.nb1 =*/ nb1,
3212
+ /*.nb2 =*/ nb2,
3213
+ /*.nb3 =*/ nb3,
3214
+ /*.sf0 =*/ sf0,
3215
+ /*.sf1 =*/ sf1,
3216
+ /*.sf2 =*/ sf2,
3217
+ /*.sf3 =*/ sf3
3218
+ };
3219
+
3220
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_upscale(lib, op);
3221
+
3222
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3223
+
3224
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3225
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3226
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3227
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3228
+
3229
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3230
+
3231
+ return 1;
3232
+ }
3233
+
3234
+ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
3235
+ wsp_ggml_tensor * op = ctx->node(idx);
3236
+
3237
+ wsp_ggml_metal_library_t lib = ctx->lib;
3238
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3239
+
3240
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3241
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3242
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3243
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3244
+
3245
+ wsp_ggml_metal_kargs_pad args = {
3246
+ /*.ne00 =*/ ne00,
3247
+ /*.ne01 =*/ ne01,
3248
+ /*.ne02 =*/ ne02,
3249
+ /*.ne03 =*/ ne03,
3250
+ /*.nb00 =*/ nb00,
3251
+ /*.nb01 =*/ nb01,
3252
+ /*.nb02 =*/ nb02,
3253
+ /*.nb03 =*/ nb03,
3254
+ /*.ne0 =*/ ne0,
3255
+ /*.ne1 =*/ ne1,
3256
+ /*.ne2 =*/ ne2,
3257
+ /*.ne3 =*/ ne3,
3258
+ /*.nb0 =*/ nb0,
3259
+ /*.nb1 =*/ nb1,
3260
+ /*.nb2 =*/ nb2,
3261
+ /*.nb3 =*/ nb3
3262
+ };
3263
+
3264
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pad(lib, op);
3265
+
3266
+ const int nth = std::min(1024, ne0);
3267
+
3268
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3269
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3270
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3271
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3272
+
3273
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3274
+
3275
+ return 1;
3276
+ }
3277
+
3278
+ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
3279
+ wsp_ggml_tensor * op = ctx->node(idx);
3280
+
3281
+ wsp_ggml_metal_library_t lib = ctx->lib;
3282
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3283
+
3284
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3285
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3286
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3287
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3288
+
3289
+ wsp_ggml_metal_kargs_pad_reflect_1d args = {
3290
+ /*.ne00 =*/ ne00,
3291
+ /*.ne01 =*/ ne01,
3292
+ /*.ne02 =*/ ne02,
3293
+ /*.ne03 =*/ ne03,
3294
+ /*.nb00 =*/ nb00,
3295
+ /*.nb01 =*/ nb01,
3296
+ /*.nb02 =*/ nb02,
3297
+ /*.nb03 =*/ nb03,
3298
+ /*.ne0 =*/ ne0,
3299
+ /*.ne1 =*/ ne1,
3300
+ /*.ne2 =*/ ne2,
3301
+ /*.ne3 =*/ ne3,
3302
+ /*.nb0 =*/ nb0,
3303
+ /*.nb1 =*/ nb1,
3304
+ /*.nb2 =*/ nb2,
3305
+ /*.nb3 =*/ nb3,
3306
+ /*.p0 =*/ ((const int32_t *)(op->op_params))[0],
3307
+ /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
3308
+ };
3309
+
3310
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3311
+
3312
+ const int nth = std::min(1024, ne0);
3313
+
3314
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3315
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3316
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3317
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3318
+
3319
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3320
+
3321
+ return 1;
3322
+ }
3323
+
3324
+ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
3325
+ wsp_ggml_tensor * op = ctx->node(idx);
3326
+
3327
+ wsp_ggml_metal_library_t lib = ctx->lib;
3328
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3329
+
3330
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3331
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3332
+
3333
+ float start;
3334
+ float step;
3335
+
3336
+ memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
3337
+ memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float));
3338
+
3339
+ wsp_ggml_metal_kargs_arange args = {
3340
+ /*.ne0 =*/ ne0,
3341
+ /*.start =*/ start,
3342
+ /*.step =*/ step
3343
+ };
3344
+
3345
+ const int nth = std::min(1024, ne0);
3346
+
3347
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
3348
+
3349
+ //[encoder setComputePipelineState:pipeline];
3350
+ //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3351
+ //[encoder setBytes:&args length:sizeof(args) atIndex:1];
3352
+
3353
+ //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3354
+
3355
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3356
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3357
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
3358
+
3359
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
3360
+
3361
+ return 1;
3362
+ }
3363
+
3364
+ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
3365
+ wsp_ggml_tensor * op = ctx->node(idx);
3366
+
3367
+ wsp_ggml_metal_library_t lib = ctx->lib;
3368
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3369
+
3370
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3371
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3372
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3373
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3374
+
3375
+ const int dim = op->op_params[0];
3376
+ const int max_period = op->op_params[1];
3377
+
3378
+ wsp_ggml_metal_kargs_timestep_embedding args = {
3379
+ /*.nb1 =*/ nb1,
3380
+ /*.dim =*/ dim,
3381
+ /*.max_period =*/ max_period,
3382
+ };
3383
+
3384
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3385
+
3386
+ const int nth = std::max(1, std::min(1024, dim/2));
3387
+
3388
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3389
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3390
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3391
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3392
+
3393
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
3394
+
3395
+ return 1;
3396
+ }
3397
+
3398
+ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
3399
+ wsp_ggml_tensor * op = ctx->node(idx);
3400
+
3401
+ wsp_ggml_metal_library_t lib = ctx->lib;
3402
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3403
+
3404
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3405
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3406
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3407
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3408
+
3409
+ wsp_ggml_metal_kargs_argmax args = {
3410
+ /*.ne00 = */ ne00,
3411
+ /*.nb01 = */ nb01,
3412
+ };
3413
+
3414
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argmax(lib, op);
3415
+
3416
+ const int64_t nrows = wsp_ggml_nrows(op->src[0]);
3417
+
3418
+ int nth = 32; // SIMD width
3419
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
3420
+ nth *= 2;
3421
+ }
3422
+
3423
+ const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
3424
+
3425
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3426
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3427
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3428
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3429
+
3430
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3431
+
3432
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3433
+
3434
+ return 1;
3435
+ }
3436
+
3437
+ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3438
+ wsp_ggml_tensor * op = ctx->node(idx);
3439
+
3440
+ wsp_ggml_metal_library_t lib = ctx->lib;
3441
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3442
+
3443
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3444
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3445
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3446
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3447
+
3448
+ // bitonic sort requires the number of elements to be power of 2
3449
+ int64_t ne00_padded = 1;
3450
+ while (ne00_padded < ne00) {
3451
+ ne00_padded *= 2;
3452
+ }
3453
+
3454
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3455
+
3456
+ const int64_t nrows = wsp_ggml_nrows(op->src[0]);
3457
+
3458
+ // Metal kernels require the buffer size to be multiple of 16 bytes
3459
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3460
+ const size_t smem = WSP_GGML_PAD(ne00_padded*sizeof(int32_t), 16);
3461
+
3462
+ wsp_ggml_metal_kargs_argsort args = {
3463
+ /*.ncols =*/ ne00,
3464
+ /*.ncols_pad =*/ ne00_padded
3465
+ };
3466
+
3467
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3468
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3469
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3470
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3471
+
3472
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3473
+
3474
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
3475
+
3476
+ return 1;
3477
+ }
3478
+
3479
+ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3480
+ wsp_ggml_tensor * op = ctx->node(idx);
3481
+
3482
+ wsp_ggml_metal_library_t lib = ctx->lib;
3483
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3484
+
3485
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3486
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3487
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3488
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3489
+
3490
+ float slope;
3491
+ memcpy(&slope, op->op_params, sizeof(float));
3492
+
3493
+ wsp_ggml_metal_kargs_leaky_relu args = {
3494
+ /*.slope =*/ slope
3495
+ };
3496
+
3497
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
3498
+
3499
+ int64_t n = wsp_ggml_nelements(op);
3500
+
3501
+ if (n % 4 == 0) {
3502
+ n /= 4;
3503
+ }
3504
+
3505
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3506
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3507
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3508
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3509
+
3510
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3511
+
3512
+ return 1;
3513
+ }
3514
+
3515
+ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3516
+ wsp_ggml_tensor * op = ctx->node(idx);
3517
+
3518
+ wsp_ggml_metal_library_t lib = ctx->lib;
3519
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3520
+
3521
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3522
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3523
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3524
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3525
+
3526
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
3527
+
3528
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3529
+ wsp_ggml_metal_kargs_opt_step_adamw args = {
3530
+ /*.np =*/ np,
3531
+ };
3532
+
3533
+ int ida = 0;
3534
+
3535
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3536
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3537
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3538
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3539
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3540
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
3541
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
3542
+
3543
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3544
+ const int64_t n = (np + nth - 1) / nth;
3545
+
3546
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3547
+
3548
+ return 1;
3549
+ }
3550
+
3551
+ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3552
+ wsp_ggml_tensor * op = ctx->node(idx);
3553
+
3554
+ wsp_ggml_metal_library_t lib = ctx->lib;
3555
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3556
+
3557
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3558
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3559
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3560
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3561
+
3562
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3563
+
3564
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3565
+ wsp_ggml_metal_kargs_opt_step_sgd args = {
3566
+ /*.np =*/ np,
3567
+ };
3568
+
3569
+ int ida = 0;
3570
+
3571
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3572
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3573
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3574
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3575
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3576
+
3577
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3578
+ const int64_t n = (np + nth - 1) / nth;
3579
+
3580
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3581
+
3582
+ return 1;
3583
+ }