whisper.rn 0.5.4 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  2. package/android/src/main/jni.cpp +13 -0
  3. package/cpp/ggml-alloc.c +78 -26
  4. package/cpp/ggml-alloc.h +9 -0
  5. package/cpp/ggml-backend-impl.h +1 -1
  6. package/cpp/ggml-backend-reg.cpp +19 -3
  7. package/cpp/ggml-backend.cpp +72 -20
  8. package/cpp/ggml-backend.h +2 -1
  9. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  10. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  11. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  12. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  14. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  16. package/cpp/ggml-cpu/ops.cpp +170 -18
  17. package/cpp/ggml-cpu/ops.h +1 -0
  18. package/cpp/ggml-cpu/repack.cpp +531 -5
  19. package/cpp/ggml-cpu/repack.h +14 -0
  20. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  21. package/cpp/ggml-cpu/vec.cpp +41 -1
  22. package/cpp/ggml-cpu/vec.h +241 -138
  23. package/cpp/ggml-cpu.h +1 -0
  24. package/cpp/ggml-impl.h +0 -4
  25. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  27. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  28. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  29. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  31. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  33. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  34. package/cpp/ggml.c +110 -31
  35. package/cpp/ggml.h +51 -12
  36. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  37. package/cpp/whisper.cpp +16 -3
  38. package/ios/CMakeLists.txt +21 -1
  39. package/ios/RNWhisperContext.mm +5 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  78. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  79. package/lib/commonjs/jest-mock.js +2 -0
  80. package/lib/commonjs/jest-mock.js.map +1 -1
  81. package/lib/commonjs/version.json +1 -1
  82. package/lib/module/NativeRNWhisper.js.map +1 -1
  83. package/lib/module/jest-mock.js +2 -0
  84. package/lib/module/jest-mock.js.map +1 -1
  85. package/lib/module/version.json +1 -1
  86. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  87. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  88. package/package.json +1 -1
  89. package/src/NativeRNWhisper.ts +1 -0
  90. package/src/jest-mock.ts +2 -0
  91. package/src/version.json +1 -1
@@ -50,14 +50,14 @@ void wsp_ggml_metal_pipelines_add(wsp_ggml_metal_pipelines_t ppls, const char *
50
50
  }
51
51
 
52
52
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_pipelines_get(wsp_ggml_metal_pipelines_t ppls, const char * name) {
53
- if (ppls->data.find(name) == ppls->data.end()) {
53
+ if (ppls->data.find(name) == ppls->data.end()) {
54
54
  return nullptr;
55
55
  }
56
56
 
57
57
  return ppls->data[name];
58
58
  }
59
59
 
60
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_metal_library_t lib, wsp_ggml_op op) {
60
+ struct wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_metal_library_t lib, wsp_ggml_op op) {
61
61
  char base[256];
62
62
  char name[256];
63
63
 
@@ -71,34 +71,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_meta
71
71
  snprintf(base, 256, "kernel_%s", op_str);
72
72
  snprintf(name, 256, "%s", base);
73
73
 
74
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
75
- if (res) {
76
- return res;
74
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
75
+ if (!res.pipeline) {
76
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
77
77
  }
78
78
 
79
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
80
-
81
79
  return res;
82
80
  }
83
81
 
84
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cpy(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc, wsp_ggml_type tdst) {
82
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cpy(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc, wsp_ggml_type tdst) {
85
83
  char base[256];
86
84
  char name[256];
87
85
 
88
86
  snprintf(base, 256, "kernel_cpy_%s_%s", wsp_ggml_type_name(tsrc), wsp_ggml_type_name(tdst));
89
87
  snprintf(name, 256, "%s", base);
90
88
 
91
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
92
- if (res) {
93
- return res;
89
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
90
+ if (!res.pipeline) {
91
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
94
92
  }
95
93
 
96
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
97
-
98
94
  return res;
99
95
  }
100
96
 
101
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, wsp_ggml_op_pool op_pool) {
97
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, wsp_ggml_op_pool op_pool) {
102
98
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
103
99
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32 && op->src[0]->type == op->type);
104
100
 
@@ -115,68 +111,60 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_m
115
111
  snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, wsp_ggml_type_name(op->src[0]->type));
116
112
  snprintf(name, 256, "%s", base);
117
113
 
118
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
119
- if (res) {
120
- return res;
114
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
115
+ if (!res.pipeline) {
116
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
121
117
  }
122
118
 
123
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
124
-
125
119
  return res;
126
120
  }
127
121
 
128
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_get_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
122
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_get_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
129
123
  char base[256];
130
124
  char name[256];
131
125
 
132
126
  snprintf(base, 256, "kernel_get_rows_%s", wsp_ggml_type_name(tsrc));
133
127
  snprintf(name, 256, "%s", base);
134
128
 
135
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
136
- if (res) {
137
- return res;
129
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
130
+ if (!res.pipeline) {
131
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
138
132
  }
139
133
 
140
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
141
-
142
134
  return res;
143
135
  }
144
136
 
145
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tidx, wsp_ggml_type tdst) {
137
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_set_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tidx, wsp_ggml_type tdst) {
146
138
  char base[256];
147
139
  char name[256];
148
140
 
149
141
  snprintf(base, 256, "kernel_set_rows_%s_%s", wsp_ggml_type_name(tdst), wsp_ggml_type_name(tidx));
150
142
  snprintf(name, 256, "%s", base);
151
143
 
152
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
153
- if (res) {
154
- return res;
144
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
145
+ if (!res.pipeline) {
146
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
155
147
  }
156
148
 
157
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
158
-
159
149
  return res;
160
150
  }
161
151
 
162
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
152
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_repeat(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
163
153
  char base[256];
164
154
  char name[256];
165
155
 
166
156
  snprintf(base, 256, "kernel_repeat_%s", wsp_ggml_type_name(tsrc));
167
157
  snprintf(name, 256, "%s", base);
168
158
 
169
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
170
- if (res) {
171
- return res;
159
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
160
+ if (!res.pipeline) {
161
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
172
162
  }
173
163
 
174
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
175
-
176
164
  return res;
177
165
  }
178
166
 
179
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
167
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
180
168
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
181
169
 
182
170
  char base[256];
@@ -187,6 +175,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
187
175
  const char * op_str = "undefined";
188
176
  switch (op->op) {
189
177
  case WSP_GGML_OP_SCALE: op_str = "scale"; break;
178
+ case WSP_GGML_OP_FILL: op_str = "fill"; break;
190
179
  case WSP_GGML_OP_CLAMP: op_str = "clamp"; break;
191
180
  case WSP_GGML_OP_SQR: op_str = "sqr"; break;
192
181
  case WSP_GGML_OP_SQRT: op_str = "sqrt"; break;
@@ -211,6 +200,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
211
200
  case WSP_GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
212
201
  case WSP_GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
213
202
  case WSP_GGML_UNARY_OP_EXP: op_str = "exp"; break;
203
+ case WSP_GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
204
+ case WSP_GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
214
205
  default: WSP_GGML_ABORT("fatal error");
215
206
  } break;
216
207
  default: WSP_GGML_ABORT("fatal error");
@@ -224,17 +215,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
224
215
  snprintf(base, 256, "kernel_%s_%s%s", op_str, wsp_ggml_type_name(op->src[0]->type), suffix);
225
216
  snprintf(name, 256, "%s", base);
226
217
 
227
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
228
- if (res) {
229
- return res;
218
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
219
+ if (!res.pipeline) {
220
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
230
221
  }
231
222
 
232
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
233
-
234
223
  return res;
235
224
  }
236
225
 
237
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
226
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
238
227
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));
239
228
 
240
229
  char base[256];
@@ -258,17 +247,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal
258
247
  snprintf(base, 256, "kernel_%s_%s", op_str, wsp_ggml_type_name(op->src[0]->type));
259
248
  snprintf(name, 256, "%s", base);
260
249
 
261
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
262
- if (res) {
263
- return res;
250
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
251
+ if (!res.pipeline) {
252
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
264
253
  }
265
254
 
266
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
267
-
268
255
  return res;
269
256
  }
270
257
 
271
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
258
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
272
259
  assert(op->op == WSP_GGML_OP_SUM);
273
260
 
274
261
  char base[256];
@@ -277,17 +264,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal
277
264
  snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
278
265
  snprintf(name, 256, "%s", base);
279
266
 
280
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
281
- if (res) {
282
- return res;
267
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
268
+ if (!res.pipeline) {
269
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
283
270
  }
284
271
 
285
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
286
-
287
272
  return res;
288
273
  }
289
274
 
290
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
275
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
291
276
  WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
292
277
 
293
278
  char base[256];
@@ -306,19 +291,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
306
291
 
307
292
  snprintf(name, 256, "%s", base);
308
293
 
309
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
310
- if (res) {
311
- return res;
294
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
295
+ if (!res.pipeline) {
296
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
312
297
  }
313
298
 
314
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
315
-
316
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
299
+ res.smem = 32*sizeof(float);
317
300
 
318
301
  return res;
319
302
  }
320
303
 
321
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
304
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
322
305
  WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
323
306
 
324
307
  char base[256];
@@ -327,17 +310,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggm
327
310
  snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
328
311
  snprintf(name, 256, "%s", base);
329
312
 
330
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
331
- if (res) {
332
- return res;
313
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
314
+ if (!res.pipeline) {
315
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
333
316
  }
334
317
 
335
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
336
-
337
318
  return res;
338
319
  }
339
320
 
340
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
321
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
341
322
  WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
342
323
 
343
324
  char base[256];
@@ -346,17 +327,37 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggm
346
327
  snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
347
328
  snprintf(name, 256, "%s", base);
348
329
 
349
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
350
- if (res) {
351
- return res;
330
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
331
+ if (!res.pipeline) {
332
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
352
333
  }
353
334
 
354
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
335
+ return res;
336
+ }
337
+
338
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_tri(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
339
+ WSP_GGML_ASSERT(op->op == WSP_GGML_OP_TRI);
340
+ WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
341
+
342
+ char base[256];
343
+ char name[256];
344
+
345
+ const char * op_str = "tri";
346
+ const int ttype = op->op_params[0];
347
+
348
+ snprintf(base, 256, "kernel_%s_%s_%d", op_str, wsp_ggml_type_name(op->src[0]->type), ttype);
349
+
350
+ snprintf(name, 256, "%s", base);
351
+
352
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
353
+ if (!res.pipeline) {
354
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
355
+ }
355
356
 
356
357
  return res;
357
358
  }
358
359
 
359
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
360
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
360
361
  WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
361
362
 
362
363
  char base[256];
@@ -373,19 +374,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_
373
374
  snprintf(base, 256, "kernel_soft_max_%s%s", wsp_ggml_type_name(tsrc1), suffix);
374
375
  snprintf(name, 256, "%s", base);
375
376
 
376
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
377
- if (res) {
378
- return res;
377
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
378
+ if (!res.pipeline) {
379
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
379
380
  }
380
381
 
381
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
382
-
383
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
382
+ res.smem = 32*sizeof(float);
384
383
 
385
384
  return res;
386
385
  }
387
386
 
388
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
387
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
389
388
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
390
389
  WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
391
390
 
@@ -404,17 +403,47 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
404
403
  snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
405
404
  snprintf(name, 256, "%s", base);
406
405
 
407
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
408
- if (res) {
409
- return res;
406
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
407
+ if (!res.pipeline) {
408
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
410
409
  }
411
410
 
412
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
411
+ return res;
412
+ }
413
+
414
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_conv_batched(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int ssm_conv_bs) {
415
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
416
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
417
+
418
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
419
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
420
+
421
+ char base[256];
422
+ char name[256];
423
+
424
+ const char * suffix = "";
425
+ if (op->src[1]->ne[0] % 4 == 0) {
426
+ suffix = "_4";
427
+ }
428
+
429
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
430
+ snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
431
+
432
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
433
+ if (!res.pipeline) {
434
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
435
+
436
+ wsp_ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
437
+
438
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
439
+
440
+ wsp_ggml_metal_cv_free(cv);
441
+ }
413
442
 
414
443
  return res;
415
444
  }
416
445
 
417
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
446
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
418
447
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
419
448
 
420
449
  char base[256];
@@ -425,19 +454,22 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_
425
454
  snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
426
455
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
427
456
 
428
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
429
- if (res) {
430
- return res;
457
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
458
+ if (!res.pipeline) {
459
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
431
460
  }
432
461
 
433
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
434
-
435
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
462
+ // Shared memory layout:
463
+ // - sgptg * NW floats for partial sums (nsg * 32)
464
+ // - sgptg floats for shared_x_dt (nsg)
465
+ // - sgptg floats for shared_dA (nsg)
466
+ // Total: nsg * (32 + 2) floats
467
+ res.smem = (32 + 2)*sizeof(float)*nsg;
436
468
 
437
469
  return res;
438
470
  }
439
471
 
440
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
472
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
441
473
  char base[256];
442
474
  char name[256];
443
475
 
@@ -467,41 +499,37 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_meta
467
499
 
468
500
  snprintf(name, 256, "%s", base);
469
501
 
470
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
471
- if (res) {
472
- return res;
502
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
503
+ if (!res.pipeline) {
504
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
473
505
  }
474
506
 
475
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
476
-
477
507
  return res;
478
508
  }
479
509
 
480
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_ext(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc0, wsp_ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
510
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv_ext(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc0, wsp_ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
481
511
  char base[256];
482
512
  char name[256];
483
513
 
484
514
  snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), r1ptg);
485
515
  snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
486
516
 
487
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
488
- if (res) {
489
- return res;
490
- }
491
-
492
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
517
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
518
+ if (!res.pipeline) {
519
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
493
520
 
494
- wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
495
- wsp_ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
521
+ wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
522
+ wsp_ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
496
523
 
497
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
524
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
498
525
 
499
- wsp_ggml_metal_cv_free(cv);
526
+ wsp_ggml_metal_cv_free(cv);
527
+ }
500
528
 
501
529
  return res;
502
530
  }
503
531
 
504
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
532
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
505
533
  char base[256];
506
534
  char name[256];
507
535
 
@@ -514,27 +542,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_me
514
542
  snprintf(base, 256, "kernel_mul_mm_%s_%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1));
515
543
  snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
516
544
 
517
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
518
- if (res) {
519
- return res;
520
- }
521
-
522
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
545
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
546
+ if (!res.pipeline) {
547
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
523
548
 
524
- wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
525
- wsp_ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
549
+ wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
550
+ wsp_ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
526
551
 
527
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
552
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
528
553
 
529
- wsp_ggml_metal_cv_free(cv);
554
+ wsp_ggml_metal_cv_free(cv);
555
+ }
530
556
 
531
557
  // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
532
- wsp_ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
558
+ res.smem = bc_out ? 8192 : 4096 + 2048;
533
559
 
534
560
  return res;
535
561
  }
536
562
 
537
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
563
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
538
564
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
539
565
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
540
566
 
@@ -689,49 +715,43 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_me
689
715
  snprintf(base, 256, "kernel_mul_mv_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
690
716
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
691
717
 
692
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
693
- if (res) {
694
- return res;
695
- }
696
-
697
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
718
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
719
+ if (!res.pipeline) {
720
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
698
721
 
699
- wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
722
+ wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
700
723
 
701
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
724
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
702
725
 
703
- wsp_ggml_metal_cv_free(cv);
726
+ wsp_ggml_metal_cv_free(cv);
727
+ }
704
728
 
705
- wsp_ggml_metal_pipeline_set_nr0 (res, nr0);
706
- wsp_ggml_metal_pipeline_set_nr1 (res, nr1);
707
- wsp_ggml_metal_pipeline_set_nsg (res, nsg);
708
- wsp_ggml_metal_pipeline_set_smem(res, smem);
729
+ res.nr0 = nr0;
730
+ res.nr1 = nr1;
731
+ res.nsg = nsg;
732
+ res.smem = smem;
709
733
 
710
734
  return res;
711
735
  }
712
736
 
713
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp_ggml_metal_library_t lib, int ne02, int ne20) {
737
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp_ggml_metal_library_t lib, int ne02, int ne20) {
714
738
  char base[256];
715
739
  char name[256];
716
740
 
717
741
  snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
718
742
  snprintf(name, 256, "%s_ne02=%d", base, ne02);
719
743
 
720
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
721
- if (res) {
722
- return res;
744
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
745
+ if (!res.pipeline) {
746
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
723
747
  }
724
748
 
725
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
726
-
727
- const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
728
-
729
- wsp_ggml_metal_pipeline_set_smem(res, smem);
749
+ res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
730
750
 
731
751
  return res;
732
752
  }
733
753
 
734
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
754
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
735
755
  char base[256];
736
756
  char name[256];
737
757
 
@@ -743,25 +763,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml
743
763
  snprintf(base, 256, "kernel_mul_mm_id_%s_%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1));
744
764
  snprintf(name, 256, "%s_bci=%d", base, bc_inp);
745
765
 
746
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
747
- if (res) {
748
- return res;
749
- }
766
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
767
+ if (!res.pipeline) {
768
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
750
769
 
751
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
770
+ wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
752
771
 
753
- wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
772
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
754
773
 
755
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
756
-
757
- wsp_ggml_metal_cv_free(cv);
774
+ wsp_ggml_metal_cv_free(cv);
775
+ }
758
776
 
759
- wsp_ggml_metal_pipeline_set_smem(res, 8192);
777
+ res.smem = 8192;
760
778
 
761
779
  return res;
762
780
  }
763
781
 
764
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
782
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
765
783
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
766
784
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
767
785
 
@@ -909,28 +927,26 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml
909
927
  snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
910
928
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
911
929
 
912
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
913
- if (res) {
914
- return res;
915
- }
930
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
931
+ if (!res.pipeline) {
932
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
916
933
 
917
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
934
+ wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
918
935
 
919
- wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
936
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
920
937
 
921
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
922
-
923
- wsp_ggml_metal_cv_free(cv);
938
+ wsp_ggml_metal_cv_free(cv);
939
+ }
924
940
 
925
- wsp_ggml_metal_pipeline_set_nr0 (res, nr0);
926
- wsp_ggml_metal_pipeline_set_nr1 (res, nr1);
927
- wsp_ggml_metal_pipeline_set_nsg (res, nsg);
928
- wsp_ggml_metal_pipeline_set_smem(res, smem);
941
+ res.nr0 = nr0;
942
+ res.nr1 = nr1;
943
+ res.nsg = nsg;
944
+ res.smem = smem;
929
945
 
930
946
  return res;
931
947
  }
932
948
 
933
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
949
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
934
950
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
935
951
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));
936
952
  WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
@@ -941,19 +957,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_me
941
957
  snprintf(base, 256, "kernel_argmax_%s", wsp_ggml_type_name(op->src[0]->type));
942
958
  snprintf(name, 256, "%s", base);
943
959
 
944
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
945
- if (res) {
946
- return res;
960
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
961
+ if (!res.pipeline) {
962
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
947
963
  }
948
964
 
949
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
950
-
951
- wsp_ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t)));
965
+ res.smem = 32*(sizeof(float) + sizeof(int32_t));
952
966
 
953
967
  return res;
954
968
  }
955
969
 
956
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
970
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
957
971
  assert(op->op == WSP_GGML_OP_ARGSORT);
958
972
 
959
973
  char base[256];
@@ -971,17 +985,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
971
985
  snprintf(base, 256, "kernel_argsort_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
972
986
  snprintf(name, 256, "%s", base);
973
987
 
974
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
975
- if (res) {
976
- return res;
988
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
989
+ if (!res.pipeline) {
990
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
977
991
  }
978
992
 
979
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
980
-
981
993
  return res;
982
994
  }
983
995
 
984
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
996
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
985
997
  assert(op->op == WSP_GGML_OP_ARGSORT);
986
998
 
987
999
  char base[256];
@@ -999,17 +1011,69 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_
999
1011
  snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
1000
1012
  snprintf(name, 256, "%s", base);
1001
1013
 
1002
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1003
- if (res) {
1004
- return res;
1014
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1015
+ if (!res.pipeline) {
1016
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1005
1017
  }
1006
1018
 
1007
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1019
+ return res;
1020
+ }
1021
+
1022
+ // note: reuse the argsort kernel for top_k
1023
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_top_k(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1024
+ assert(op->op == WSP_GGML_OP_TOP_K);
1025
+
1026
+ char base[256];
1027
+ char name[256];
1028
+
1029
+ // note: the top_k kernel is always descending order
1030
+ wsp_ggml_sort_order order = WSP_GGML_SORT_ORDER_DESC;
1031
+
1032
+ const char * order_str = "undefined";
1033
+ switch (order) {
1034
+ case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1035
+ case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1036
+ default: WSP_GGML_ABORT("fatal error");
1037
+ };
1038
+
1039
+ snprintf(base, 256, "kernel_argsort_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
1040
+ snprintf(name, 256, "%s", base);
1041
+
1042
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1043
+ if (!res.pipeline) {
1044
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1045
+ }
1008
1046
 
1009
1047
  return res;
1010
1048
  }
1011
1049
 
1012
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1050
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_top_k_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1051
+ assert(op->op == WSP_GGML_OP_TOP_K);
1052
+
1053
+ char base[256];
1054
+ char name[256];
1055
+
1056
+ wsp_ggml_sort_order order = WSP_GGML_SORT_ORDER_DESC;
1057
+
1058
+ const char * order_str = "undefined";
1059
+ switch (order) {
1060
+ case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1061
+ case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1062
+ default: WSP_GGML_ABORT("fatal error");
1063
+ };
1064
+
1065
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
1066
+ snprintf(name, 256, "%s", base);
1067
+
1068
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1069
+ if (!res.pipeline) {
1070
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1071
+ }
1072
+
1073
+ return res;
1074
+ }
1075
+
1076
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1013
1077
  wsp_ggml_metal_library_t lib,
1014
1078
  const struct wsp_ggml_tensor * op,
1015
1079
  bool has_mask,
@@ -1028,33 +1092,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad
1028
1092
  has_mask,
1029
1093
  ncpsg);
1030
1094
 
1031
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1032
- if (res) {
1033
- return res;
1034
- }
1035
-
1036
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1095
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1096
+ if (!res.pipeline) {
1097
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1037
1098
 
1038
- wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1039
- //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1040
- //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1041
- //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1099
+ wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1100
+ //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1101
+ //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1102
+ //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1042
1103
 
1043
- //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1044
- //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1045
- //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1046
- //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1047
- //wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1048
- wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1104
+ //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1105
+ //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1106
+ //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1107
+ //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1108
+ //wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1109
+ wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1049
1110
 
1050
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1111
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1051
1112
 
1052
- wsp_ggml_metal_cv_free(cv);
1113
+ wsp_ggml_metal_cv_free(cv);
1114
+ }
1053
1115
 
1054
1116
  return res;
1055
1117
  }
1056
1118
 
1057
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1119
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1058
1120
  wsp_ggml_metal_library_t lib,
1059
1121
  const struct wsp_ggml_tensor * op,
1060
1122
  int32_t nqptg,
@@ -1073,33 +1135,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk
1073
1135
  nqptg,
1074
1136
  ncpsg);
1075
1137
 
1076
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1077
- if (res) {
1078
- return res;
1079
- }
1080
-
1081
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1138
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1139
+ if (!res.pipeline) {
1140
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1082
1141
 
1083
- //wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1084
- //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1085
- //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1086
- //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1142
+ //wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1143
+ //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1144
+ //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1145
+ //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1087
1146
 
1088
- //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1089
- //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1090
- //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1091
- //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1092
- wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1093
- wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1147
+ //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1148
+ //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1149
+ //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1150
+ //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1151
+ wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1152
+ wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1094
1153
 
1095
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1154
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1096
1155
 
1097
- wsp_ggml_metal_cv_free(cv);
1156
+ wsp_ggml_metal_cv_free(cv);
1157
+ }
1098
1158
 
1099
1159
  return res;
1100
1160
  }
1101
1161
 
1102
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
1162
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
1103
1163
  wsp_ggml_metal_library_t lib,
1104
1164
  const wsp_ggml_tensor * op,
1105
1165
  bool has_mask,
@@ -1140,33 +1200,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
1140
1200
  ns20,
1141
1201
  nsg);
1142
1202
 
1143
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1144
- if (res) {
1145
- return res;
1146
- }
1147
-
1148
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1203
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1204
+ if (!res.pipeline) {
1205
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1149
1206
 
1150
- wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1151
- wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1152
- wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1153
- wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1154
- wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1207
+ wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1208
+ wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1209
+ wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1210
+ wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1211
+ wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1155
1212
 
1156
- wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1213
+ wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1157
1214
 
1158
- wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1159
- wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1160
- wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1215
+ wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1216
+ wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1217
+ wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1161
1218
 
1162
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1219
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1163
1220
 
1164
- wsp_ggml_metal_cv_free(cv);
1221
+ wsp_ggml_metal_cv_free(cv);
1222
+ }
1165
1223
 
1166
1224
  return res;
1167
1225
  }
1168
1226
 
1169
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1227
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1170
1228
  wsp_ggml_metal_library_t lib,
1171
1229
  const wsp_ggml_tensor * op,
1172
1230
  bool has_mask,
@@ -1204,32 +1262,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
1204
1262
  ns20,
1205
1263
  nsg, nwg);
1206
1264
 
1207
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1208
- if (res) {
1209
- return res;
1210
- }
1211
-
1212
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1265
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1266
+ if (!res.pipeline) {
1267
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1213
1268
 
1214
- wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1215
- wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1216
- wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1217
- wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1218
- wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1269
+ wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1270
+ wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1271
+ wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1272
+ wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1273
+ wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1219
1274
 
1220
- wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1221
- wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1222
- wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1223
- wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1275
+ wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1276
+ wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1277
+ wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1278
+ wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1224
1279
 
1225
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1280
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1226
1281
 
1227
- wsp_ggml_metal_cv_free(cv);
1282
+ wsp_ggml_metal_cv_free(cv);
1283
+ }
1228
1284
 
1229
1285
  return res;
1230
1286
  }
1231
1287
 
1232
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1288
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1233
1289
  wsp_ggml_metal_library_t lib,
1234
1290
  const wsp_ggml_tensor * op,
1235
1291
  int32_t dv,
@@ -1242,26 +1298,24 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
1242
1298
  snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1243
1299
  snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1244
1300
 
1245
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1246
- if (res) {
1247
- return res;
1248
- }
1249
-
1250
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1301
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1302
+ if (!res.pipeline) {
1303
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1251
1304
 
1252
- wsp_ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1253
- wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1305
+ wsp_ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1306
+ wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1254
1307
 
1255
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1308
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1256
1309
 
1257
- wsp_ggml_metal_cv_free(cv);
1310
+ wsp_ggml_metal_cv_free(cv);
1311
+ }
1258
1312
 
1259
1313
  return res;
1260
1314
 
1261
1315
  WSP_GGML_UNUSED(op);
1262
1316
  }
1263
1317
 
1264
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin(
1318
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_bin(
1265
1319
  wsp_ggml_metal_library_t lib,
1266
1320
  wsp_ggml_op op,
1267
1321
  int32_t n_fuse,
@@ -1286,17 +1340,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin(
1286
1340
 
1287
1341
  snprintf(name, 256, "%s", base);
1288
1342
 
1289
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1290
- if (res) {
1291
- return res;
1343
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1344
+ if (!res.pipeline) {
1345
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1292
1346
  }
1293
1347
 
1294
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1295
-
1296
1348
  return res;
1297
1349
  }
1298
1350
 
1299
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1351
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1300
1352
  assert(op->op == WSP_GGML_OP_L2_NORM);
1301
1353
 
1302
1354
  WSP_GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@@ -1308,19 +1360,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_m
1308
1360
  snprintf(base, 256, "kernel_l2_norm_f32");
1309
1361
  snprintf(name, 256, "%s", base);
1310
1362
 
1311
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1312
- if (res) {
1313
- return res;
1363
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1364
+ if (!res.pipeline) {
1365
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1314
1366
  }
1315
1367
 
1316
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1317
-
1318
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1368
+ res.smem = 32*sizeof(float);
1319
1369
 
1320
1370
  return res;
1321
1371
  }
1322
1372
 
1323
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1373
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1324
1374
  assert(op->op == WSP_GGML_OP_GROUP_NORM);
1325
1375
 
1326
1376
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
@@ -1331,19 +1381,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggm
1331
1381
  snprintf(base, 256, "kernel_group_norm_f32");
1332
1382
  snprintf(name, 256, "%s", base);
1333
1383
 
1334
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1335
- if (res) {
1336
- return res;
1384
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1385
+ if (!res.pipeline) {
1386
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1337
1387
  }
1338
1388
 
1339
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1340
-
1341
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1389
+ res.smem = 32*sizeof(float);
1342
1390
 
1343
1391
  return res;
1344
1392
  }
1345
1393
 
1346
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int n_fuse) {
1394
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int n_fuse) {
1347
1395
  assert(op->op == WSP_GGML_OP_NORM || op->op == WSP_GGML_OP_RMS_NORM);
1348
1396
 
1349
1397
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
@@ -1376,19 +1424,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_meta
1376
1424
 
1377
1425
  snprintf(name, 256, "%s", base);
1378
1426
 
1379
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1380
- if (res) {
1381
- return res;
1427
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1428
+ if (!res.pipeline) {
1429
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1382
1430
  }
1383
1431
 
1384
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1385
-
1386
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1432
+ res.smem = 32*sizeof(float);
1387
1433
 
1388
1434
  return res;
1389
1435
  }
1390
1436
 
1391
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1437
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1392
1438
  assert(op->op == WSP_GGML_OP_ROPE);
1393
1439
 
1394
1440
  char base[256];
@@ -1415,23 +1461,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
1415
1461
 
1416
1462
  snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1417
1463
 
1418
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1419
- if (res) {
1420
- return res;
1421
- }
1464
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1465
+ if (!res.pipeline) {
1466
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1422
1467
 
1423
- wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1468
+ wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1424
1469
 
1425
- wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1470
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1426
1471
 
1427
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1428
-
1429
- wsp_ggml_metal_cv_free(cv);
1472
+ wsp_ggml_metal_cv_free(cv);
1473
+ }
1430
1474
 
1431
1475
  return res;
1432
1476
  }
1433
1477
 
1434
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1478
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1435
1479
  assert(op->op == WSP_GGML_OP_IM2COL);
1436
1480
 
1437
1481
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
@@ -1444,17 +1488,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_me
1444
1488
  snprintf(base, 256, "kernel_im2col_%s", wsp_ggml_type_name(op->type));
1445
1489
  snprintf(name, 256, "%s", base);
1446
1490
 
1447
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1448
- if (res) {
1449
- return res;
1491
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1492
+ if (!res.pipeline) {
1493
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1450
1494
  }
1451
1495
 
1452
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1453
-
1454
1496
  return res;
1455
1497
  }
1456
1498
 
1457
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1499
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1458
1500
  assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_1D);
1459
1501
 
1460
1502
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
@@ -1469,17 +1511,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(
1469
1511
  snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1470
1512
  snprintf(name, 256, "%s", base);
1471
1513
 
1472
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1473
- if (res) {
1474
- return res;
1514
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1515
+ if (!res.pipeline) {
1516
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1475
1517
  }
1476
1518
 
1477
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1478
-
1479
1519
  return res;
1480
1520
  }
1481
1521
 
1482
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1522
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1483
1523
  assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
1484
1524
 
1485
1525
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
@@ -1494,17 +1534,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(
1494
1534
  snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1495
1535
  snprintf(name, 256, "%s", base);
1496
1536
 
1497
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1498
- if (res) {
1499
- return res;
1537
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1538
+ if (!res.pipeline) {
1539
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1500
1540
  }
1501
1541
 
1502
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1503
-
1504
1542
  return res;
1505
1543
  }
1506
1544
 
1507
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1545
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1508
1546
  assert(op->op == WSP_GGML_OP_CONV_2D);
1509
1547
 
1510
1548
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
@@ -1518,17 +1556,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_m
1518
1556
  snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1519
1557
  snprintf(name, 256, "%s", base);
1520
1558
 
1521
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1522
- if (res) {
1523
- return res;
1559
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1560
+ if (!res.pipeline) {
1561
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1524
1562
  }
1525
1563
 
1526
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1527
-
1528
1564
  return res;
1529
1565
  }
1530
1566
 
1531
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1567
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1532
1568
  assert(op->op == WSP_GGML_OP_UPSCALE);
1533
1569
 
1534
1570
  char base[256];
@@ -1537,17 +1573,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_m
1537
1573
  snprintf(base, 256, "kernel_upscale_%s", wsp_ggml_type_name(op->src[0]->type));
1538
1574
  snprintf(name, 256, "%s", base);
1539
1575
 
1540
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1541
- if (res) {
1542
- return res;
1576
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1577
+ if (!res.pipeline) {
1578
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1543
1579
  }
1544
1580
 
1545
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1546
-
1547
1581
  return res;
1548
1582
  }
1549
1583
 
1550
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1584
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1551
1585
  assert(op->op == WSP_GGML_OP_PAD);
1552
1586
 
1553
1587
  char base[256];
@@ -1556,8 +1590,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal
1556
1590
  snprintf(base, 256, "kernel_pad_%s", wsp_ggml_type_name(op->src[0]->type));
1557
1591
  snprintf(name, 256, "%s", base);
1558
1592
 
1559
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1560
- if (res) {
1593
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1594
+ if (res.pipeline) {
1561
1595
  return res;
1562
1596
  }
1563
1597
 
@@ -1566,7 +1600,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal
1566
1600
  return res;
1567
1601
  }
1568
1602
 
1569
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1603
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1570
1604
  assert(op->op == WSP_GGML_OP_PAD_REFLECT_1D);
1571
1605
 
1572
1606
  char base[256];
@@ -1575,17 +1609,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp
1575
1609
  snprintf(base, 256, "kernel_pad_reflect_1d_%s", wsp_ggml_type_name(op->src[0]->type));
1576
1610
  snprintf(name, 256, "%s", base);
1577
1611
 
1578
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1579
- if (res) {
1580
- return res;
1612
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1613
+ if (!res.pipeline) {
1614
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1581
1615
  }
1582
1616
 
1583
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1584
-
1585
1617
  return res;
1586
1618
  }
1587
1619
 
1588
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1620
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1589
1621
  assert(op->op == WSP_GGML_OP_ARANGE);
1590
1622
 
1591
1623
  char base[256];
@@ -1594,17 +1626,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_me
1594
1626
  snprintf(base, 256, "kernel_arange_%s", wsp_ggml_type_name(op->type));
1595
1627
  snprintf(name, 256, "%s", base);
1596
1628
 
1597
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1598
- if (res) {
1599
- return res;
1629
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1630
+ if (!res.pipeline) {
1631
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1600
1632
  }
1601
1633
 
1602
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1603
-
1604
1634
  return res;
1605
1635
  }
1606
1636
 
1607
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1637
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1608
1638
  assert(op->op == WSP_GGML_OP_TIMESTEP_EMBEDDING);
1609
1639
 
1610
1640
  char base[256];
@@ -1613,17 +1643,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding
1613
1643
  snprintf(base, 256, "kernel_timestep_embedding_%s", wsp_ggml_type_name(op->src[0]->type));
1614
1644
  snprintf(name, 256, "%s", base);
1615
1645
 
1616
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1617
- if (res) {
1618
- return res;
1646
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1647
+ if (!res.pipeline) {
1648
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1619
1649
  }
1620
1650
 
1621
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1622
-
1623
1651
  return res;
1624
1652
  }
1625
1653
 
1626
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1654
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1627
1655
  assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
1628
1656
 
1629
1657
  char base[256];
@@ -1632,17 +1660,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp
1632
1660
  snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
1633
1661
  snprintf(name, 256, "%s", base);
1634
1662
 
1635
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1636
- if (res) {
1637
- return res;
1663
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1664
+ if (!res.pipeline) {
1665
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1638
1666
  }
1639
1667
 
1640
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1641
-
1642
1668
  return res;
1643
1669
  }
1644
1670
 
1645
- wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1671
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1646
1672
  assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
1647
1673
 
1648
1674
  char base[256];
@@ -1651,12 +1677,67 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_g
1651
1677
  snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
1652
1678
  snprintf(name, 256, "%s", base);
1653
1679
 
1654
- wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1655
- if (res) {
1656
- return res;
1680
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1681
+ if (!res.pipeline) {
1682
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1657
1683
  }
1658
1684
 
1659
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1685
+ return res;
1686
+ }
1687
+
1688
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_memset(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1689
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_I64);
1690
+
1691
+ char base[256];
1692
+ char name[256];
1693
+
1694
+ snprintf(base, 256, "kernel_memset_%s", wsp_ggml_type_name(op->type));
1695
+ snprintf(name, 256, "%s", base);
1696
+
1697
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1698
+ if (!res.pipeline) {
1699
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1700
+ }
1701
+
1702
+ return res;
1703
+ }
1704
+
1705
+ wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_count_equal(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1706
+ assert(op->op == WSP_GGML_OP_COUNT_EQUAL);
1707
+
1708
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1709
+
1710
+ WSP_GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1711
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_I32);
1712
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_I64);
1713
+
1714
+ // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
1715
+ WSP_GGML_ASSERT(wsp_ggml_nelements(op->src[0]) < (1LL << 31));
1716
+
1717
+ char base[256];
1718
+ char name[256];
1719
+
1720
+ int nsg = 1;
1721
+ while (32*nsg < ne00 && nsg < 32) {
1722
+ nsg *= 2;
1723
+ }
1724
+
1725
+ snprintf(base, 256, "kernel_count_equal_%s", wsp_ggml_type_name(op->src[0]->type));
1726
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
1727
+
1728
+ wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
1729
+ if (!res.pipeline) {
1730
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1731
+
1732
+ wsp_ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1733
+
1734
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1735
+
1736
+ wsp_ggml_metal_cv_free(cv);
1737
+ }
1738
+
1739
+ res.smem = 32 * sizeof(int32_t);
1740
+ res.nsg = nsg;
1660
1741
 
1661
1742
  return res;
1662
1743
  }