whisper.rn 0.5.0-rc.9 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/cpp/ggml-alloc.c +265 -141
- package/cpp/ggml-backend-impl.h +4 -1
- package/cpp/ggml-backend-reg.cpp +30 -13
- package/cpp/ggml-backend.cpp +221 -38
- package/cpp/ggml-backend.h +17 -1
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/amx/amx.cpp +4 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +32 -2
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
- package/cpp/ggml-cpu/ggml-cpu.c +70 -42
- package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
- package/cpp/ggml-cpu/ops.cpp +1587 -1177
- package/cpp/ggml-cpu/ops.h +5 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +89 -60
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +170 -26
- package/cpp/ggml-cpu/vec.h +506 -63
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +119 -9
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
- package/cpp/ggml-metal/ggml-metal-device.h +226 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +90 -51
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +111 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml.c +486 -98
- package/cpp/ggml.h +221 -16
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +25 -6
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/whisper.cpp +100 -76
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +16 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6284
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
|
Binary file
|
|
Binary file
|
package/cpp/ggml-metal-impl.h
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
#ifndef
|
|
2
|
-
#define
|
|
1
|
+
#ifndef WSP_WSP_WSP_GGML_METAL_IMPL
|
|
2
|
+
#define WSP_WSP_WSP_GGML_METAL_IMPL
|
|
3
3
|
|
|
4
4
|
// kernel parameters for mat-vec threadgroups
|
|
5
5
|
//
|
|
@@ -23,6 +23,9 @@
|
|
|
23
23
|
#define N_R0_Q8_0 4
|
|
24
24
|
#define N_SG_Q8_0 2
|
|
25
25
|
|
|
26
|
+
#define N_R0_MXFP4 2
|
|
27
|
+
#define N_SG_MXFP4 2
|
|
28
|
+
|
|
26
29
|
#define N_R0_Q2_K 4
|
|
27
30
|
#define N_SG_Q2_K 2
|
|
28
31
|
|
|
@@ -98,7 +101,7 @@ typedef struct {
|
|
|
98
101
|
uint64_t nb2;
|
|
99
102
|
uint64_t nb3;
|
|
100
103
|
int32_t dim;
|
|
101
|
-
}
|
|
104
|
+
} wsp_wsp_wsp_ggml_metal_kargs_concat;
|
|
102
105
|
|
|
103
106
|
typedef struct {
|
|
104
107
|
int32_t ne00;
|
|
@@ -126,7 +129,17 @@ typedef struct {
|
|
|
126
129
|
uint64_t nb2;
|
|
127
130
|
uint64_t nb3;
|
|
128
131
|
uint64_t offs;
|
|
129
|
-
|
|
132
|
+
uint64_t o1[8];
|
|
133
|
+
} wsp_wsp_wsp_ggml_metal_kargs_bin;
|
|
134
|
+
|
|
135
|
+
typedef struct {
|
|
136
|
+
int64_t ne0;
|
|
137
|
+
int64_t ne1;
|
|
138
|
+
size_t nb01;
|
|
139
|
+
size_t nb02;
|
|
140
|
+
size_t nb11;
|
|
141
|
+
size_t nb21;
|
|
142
|
+
} wsp_wsp_wsp_ggml_metal_kargs_add_id;
|
|
130
143
|
|
|
131
144
|
typedef struct {
|
|
132
145
|
int32_t ne00;
|
|
@@ -145,7 +158,7 @@ typedef struct {
|
|
|
145
158
|
uint64_t nb1;
|
|
146
159
|
uint64_t nb2;
|
|
147
160
|
uint64_t nb3;
|
|
148
|
-
}
|
|
161
|
+
} wsp_wsp_wsp_ggml_metal_kargs_repeat;
|
|
149
162
|
|
|
150
163
|
typedef struct {
|
|
151
164
|
int64_t ne00;
|
|
@@ -164,7 +177,7 @@ typedef struct {
|
|
|
164
177
|
uint64_t nb1;
|
|
165
178
|
uint64_t nb2;
|
|
166
179
|
uint64_t nb3;
|
|
167
|
-
}
|
|
180
|
+
} wsp_wsp_wsp_ggml_metal_kargs_cpy;
|
|
168
181
|
|
|
169
182
|
typedef struct {
|
|
170
183
|
int64_t ne10;
|
|
@@ -179,7 +192,7 @@ typedef struct {
|
|
|
179
192
|
uint64_t nb3;
|
|
180
193
|
uint64_t offs;
|
|
181
194
|
bool inplace;
|
|
182
|
-
}
|
|
195
|
+
} wsp_wsp_wsp_ggml_metal_kargs_set;
|
|
183
196
|
|
|
184
197
|
typedef struct {
|
|
185
198
|
int32_t ne00;
|
|
@@ -211,7 +224,7 @@ typedef struct {
|
|
|
211
224
|
int32_t sect_1;
|
|
212
225
|
int32_t sect_2;
|
|
213
226
|
int32_t sect_3;
|
|
214
|
-
}
|
|
227
|
+
} wsp_wsp_wsp_ggml_metal_kargs_rope;
|
|
215
228
|
|
|
216
229
|
typedef struct {
|
|
217
230
|
int32_t ne01;
|
|
@@ -229,16 +242,20 @@ typedef struct {
|
|
|
229
242
|
uint64_t nb21;
|
|
230
243
|
uint64_t nb22;
|
|
231
244
|
uint64_t nb23;
|
|
245
|
+
int32_t ne32;
|
|
246
|
+
int32_t ne33;
|
|
232
247
|
uint64_t nb31;
|
|
248
|
+
uint64_t nb32;
|
|
249
|
+
uint64_t nb33;
|
|
233
250
|
int32_t ne1;
|
|
234
251
|
int32_t ne2;
|
|
235
252
|
float scale;
|
|
236
253
|
float max_bias;
|
|
237
254
|
float m0;
|
|
238
255
|
float m1;
|
|
239
|
-
|
|
256
|
+
int32_t n_head_log2;
|
|
240
257
|
float logit_softcap;
|
|
241
|
-
}
|
|
258
|
+
} wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
|
|
242
259
|
|
|
243
260
|
typedef struct {
|
|
244
261
|
int32_t ne00;
|
|
@@ -255,7 +272,7 @@ typedef struct {
|
|
|
255
272
|
int32_t ne1;
|
|
256
273
|
int16_t r2;
|
|
257
274
|
int16_t r3;
|
|
258
|
-
}
|
|
275
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
|
|
259
276
|
|
|
260
277
|
typedef struct {
|
|
261
278
|
int32_t ne00;
|
|
@@ -276,7 +293,7 @@ typedef struct {
|
|
|
276
293
|
int32_t ne1;
|
|
277
294
|
int16_t r2;
|
|
278
295
|
int16_t r3;
|
|
279
|
-
}
|
|
296
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
|
|
280
297
|
|
|
281
298
|
typedef struct {
|
|
282
299
|
int32_t ne00;
|
|
@@ -300,7 +317,7 @@ typedef struct {
|
|
|
300
317
|
int16_t nsg;
|
|
301
318
|
int16_t nxpsg;
|
|
302
319
|
int16_t r1ptg;
|
|
303
|
-
}
|
|
320
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
|
|
304
321
|
|
|
305
322
|
typedef struct {
|
|
306
323
|
int32_t ne10;
|
|
@@ -311,7 +328,7 @@ typedef struct {
|
|
|
311
328
|
uint64_t nbh11;
|
|
312
329
|
int32_t ne20; // n_expert_used
|
|
313
330
|
uint64_t nb21;
|
|
314
|
-
}
|
|
331
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
|
|
315
332
|
|
|
316
333
|
typedef struct {
|
|
317
334
|
int32_t ne20; // n_expert_used
|
|
@@ -322,7 +339,7 @@ typedef struct {
|
|
|
322
339
|
int32_t ne0;
|
|
323
340
|
uint64_t nb1;
|
|
324
341
|
uint64_t nb2;
|
|
325
|
-
}
|
|
342
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
|
|
326
343
|
|
|
327
344
|
typedef struct {
|
|
328
345
|
int32_t ne00;
|
|
@@ -339,7 +356,7 @@ typedef struct {
|
|
|
339
356
|
int32_t neh1;
|
|
340
357
|
int16_t r2;
|
|
341
358
|
int16_t r3;
|
|
342
|
-
}
|
|
359
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
|
|
343
360
|
|
|
344
361
|
typedef struct {
|
|
345
362
|
int32_t nei0;
|
|
@@ -361,28 +378,36 @@ typedef struct {
|
|
|
361
378
|
int32_t ne0;
|
|
362
379
|
int32_t ne1;
|
|
363
380
|
uint64_t nb1;
|
|
364
|
-
}
|
|
381
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
|
|
365
382
|
|
|
366
383
|
typedef struct {
|
|
367
384
|
int32_t ne00;
|
|
368
385
|
int32_t ne00_4;
|
|
369
386
|
uint64_t nb01;
|
|
370
387
|
float eps;
|
|
371
|
-
}
|
|
388
|
+
} wsp_wsp_wsp_ggml_metal_kargs_norm;
|
|
372
389
|
|
|
373
390
|
typedef struct {
|
|
374
391
|
int32_t ne00;
|
|
375
392
|
int32_t ne00_4;
|
|
376
|
-
uint64_t
|
|
393
|
+
uint64_t nb1;
|
|
394
|
+
uint64_t nb2;
|
|
395
|
+
uint64_t nb3;
|
|
377
396
|
float eps;
|
|
378
|
-
|
|
397
|
+
int32_t nef1[3];
|
|
398
|
+
int32_t nef2[3];
|
|
399
|
+
int32_t nef3[3];
|
|
400
|
+
uint64_t nbf1[3];
|
|
401
|
+
uint64_t nbf2[3];
|
|
402
|
+
uint64_t nbf3[3];
|
|
403
|
+
} wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
|
|
379
404
|
|
|
380
405
|
typedef struct {
|
|
381
406
|
int32_t ne00;
|
|
382
407
|
int32_t ne00_4;
|
|
383
408
|
uint64_t nb01;
|
|
384
409
|
float eps;
|
|
385
|
-
}
|
|
410
|
+
} wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
|
|
386
411
|
|
|
387
412
|
typedef struct {
|
|
388
413
|
int64_t ne00;
|
|
@@ -393,7 +418,7 @@ typedef struct {
|
|
|
393
418
|
uint64_t nb02;
|
|
394
419
|
int32_t n_groups;
|
|
395
420
|
float eps;
|
|
396
|
-
}
|
|
421
|
+
} wsp_wsp_wsp_ggml_metal_kargs_group_norm;
|
|
397
422
|
|
|
398
423
|
typedef struct {
|
|
399
424
|
int32_t IC;
|
|
@@ -402,7 +427,7 @@ typedef struct {
|
|
|
402
427
|
int32_t s0;
|
|
403
428
|
uint64_t nb0;
|
|
404
429
|
uint64_t nb1;
|
|
405
|
-
}
|
|
430
|
+
} wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
|
|
406
431
|
|
|
407
432
|
typedef struct {
|
|
408
433
|
uint64_t ofs0;
|
|
@@ -420,7 +445,7 @@ typedef struct {
|
|
|
420
445
|
int32_t KH;
|
|
421
446
|
int32_t KW;
|
|
422
447
|
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
|
423
|
-
}
|
|
448
|
+
} wsp_wsp_wsp_ggml_metal_kargs_im2col;
|
|
424
449
|
|
|
425
450
|
typedef struct{
|
|
426
451
|
int32_t ne00;
|
|
@@ -431,7 +456,9 @@ typedef struct{
|
|
|
431
456
|
uint64_t nb1;
|
|
432
457
|
int32_t i00;
|
|
433
458
|
int32_t i10;
|
|
434
|
-
|
|
459
|
+
float alpha;
|
|
460
|
+
float limit;
|
|
461
|
+
} wsp_wsp_wsp_ggml_metal_kargs_glu;
|
|
435
462
|
|
|
436
463
|
typedef struct {
|
|
437
464
|
int64_t ne00;
|
|
@@ -458,24 +485,36 @@ typedef struct {
|
|
|
458
485
|
uint64_t nb1;
|
|
459
486
|
uint64_t nb2;
|
|
460
487
|
uint64_t nb3;
|
|
461
|
-
}
|
|
488
|
+
} wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
|
|
462
489
|
|
|
463
490
|
typedef struct {
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
491
|
+
int32_t ne00;
|
|
492
|
+
int32_t ne01;
|
|
493
|
+
int32_t ne02;
|
|
494
|
+
uint64_t nb01;
|
|
495
|
+
uint64_t nb02;
|
|
496
|
+
uint64_t nb03;
|
|
497
|
+
int32_t ne11;
|
|
498
|
+
int32_t ne12;
|
|
499
|
+
int32_t ne13;
|
|
500
|
+
uint64_t nb11;
|
|
501
|
+
uint64_t nb12;
|
|
502
|
+
uint64_t nb13;
|
|
503
|
+
uint64_t nb1;
|
|
504
|
+
uint64_t nb2;
|
|
505
|
+
uint64_t nb3;
|
|
467
506
|
float scale;
|
|
468
507
|
float max_bias;
|
|
469
508
|
float m0;
|
|
470
509
|
float m1;
|
|
471
|
-
|
|
472
|
-
}
|
|
510
|
+
int32_t n_head_log2;
|
|
511
|
+
} wsp_wsp_wsp_ggml_metal_kargs_soft_max;
|
|
473
512
|
|
|
474
513
|
typedef struct {
|
|
475
514
|
int64_t ne00;
|
|
476
515
|
int64_t ne01;
|
|
477
516
|
int n_past;
|
|
478
|
-
}
|
|
517
|
+
} wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
|
|
479
518
|
|
|
480
519
|
typedef struct {
|
|
481
520
|
int64_t ne00;
|
|
@@ -494,32 +533,32 @@ typedef struct {
|
|
|
494
533
|
uint64_t nb0;
|
|
495
534
|
uint64_t nb1;
|
|
496
535
|
uint64_t nb2;
|
|
497
|
-
}
|
|
536
|
+
} wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
|
|
498
537
|
|
|
499
538
|
typedef struct {
|
|
500
539
|
int64_t d_state;
|
|
501
540
|
int64_t d_inner;
|
|
541
|
+
int64_t n_head;
|
|
542
|
+
int64_t n_group;
|
|
502
543
|
int64_t n_seq_tokens;
|
|
503
544
|
int64_t n_seqs;
|
|
504
|
-
|
|
545
|
+
int64_t s_off;
|
|
505
546
|
uint64_t nb01;
|
|
506
547
|
uint64_t nb02;
|
|
507
|
-
uint64_t
|
|
548
|
+
uint64_t nb03;
|
|
508
549
|
uint64_t nb11;
|
|
509
550
|
uint64_t nb12;
|
|
510
551
|
uint64_t nb13;
|
|
511
|
-
uint64_t nb20;
|
|
512
552
|
uint64_t nb21;
|
|
513
553
|
uint64_t nb22;
|
|
514
|
-
uint64_t nb30;
|
|
515
554
|
uint64_t nb31;
|
|
516
|
-
uint64_t nb40;
|
|
517
555
|
uint64_t nb41;
|
|
518
556
|
uint64_t nb42;
|
|
519
|
-
uint64_t
|
|
557
|
+
uint64_t nb43;
|
|
520
558
|
uint64_t nb51;
|
|
521
559
|
uint64_t nb52;
|
|
522
|
-
|
|
560
|
+
uint64_t nb53;
|
|
561
|
+
} wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
|
|
523
562
|
|
|
524
563
|
typedef struct {
|
|
525
564
|
int64_t ne00;
|
|
@@ -530,7 +569,7 @@ typedef struct {
|
|
|
530
569
|
uint64_t nb11;
|
|
531
570
|
uint64_t nb1;
|
|
532
571
|
uint64_t nb2;
|
|
533
|
-
}
|
|
572
|
+
} wsp_wsp_wsp_ggml_metal_kargs_get_rows;
|
|
534
573
|
|
|
535
574
|
typedef struct {
|
|
536
575
|
int32_t nk0;
|
|
@@ -546,7 +585,7 @@ typedef struct {
|
|
|
546
585
|
uint64_t nb1;
|
|
547
586
|
uint64_t nb2;
|
|
548
587
|
uint64_t nb3;
|
|
549
|
-
}
|
|
588
|
+
} wsp_wsp_wsp_ggml_metal_kargs_set_rows;
|
|
550
589
|
|
|
551
590
|
typedef struct {
|
|
552
591
|
int64_t ne00;
|
|
@@ -569,7 +608,7 @@ typedef struct {
|
|
|
569
608
|
float sf1;
|
|
570
609
|
float sf2;
|
|
571
610
|
float sf3;
|
|
572
|
-
}
|
|
611
|
+
} wsp_wsp_wsp_ggml_metal_kargs_upscale;
|
|
573
612
|
|
|
574
613
|
typedef struct {
|
|
575
614
|
int64_t ne00;
|
|
@@ -588,7 +627,7 @@ typedef struct {
|
|
|
588
627
|
uint64_t nb1;
|
|
589
628
|
uint64_t nb2;
|
|
590
629
|
uint64_t nb3;
|
|
591
|
-
}
|
|
630
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pad;
|
|
592
631
|
|
|
593
632
|
typedef struct {
|
|
594
633
|
int64_t ne00;
|
|
@@ -609,28 +648,28 @@ typedef struct {
|
|
|
609
648
|
uint64_t nb3;
|
|
610
649
|
int32_t p0;
|
|
611
650
|
int32_t p1;
|
|
612
|
-
}
|
|
651
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
|
|
613
652
|
|
|
614
653
|
typedef struct {
|
|
615
654
|
uint64_t nb1;
|
|
616
655
|
int dim;
|
|
617
656
|
int max_period;
|
|
618
|
-
}
|
|
657
|
+
} wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
|
|
619
658
|
|
|
620
659
|
typedef struct {
|
|
621
660
|
float slope;
|
|
622
|
-
}
|
|
661
|
+
} wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
|
|
623
662
|
|
|
624
663
|
typedef struct {
|
|
625
664
|
int64_t ncols;
|
|
626
665
|
int64_t ncols_pad;
|
|
627
|
-
}
|
|
666
|
+
} wsp_wsp_wsp_ggml_metal_kargs_argsort;
|
|
628
667
|
|
|
629
668
|
typedef struct {
|
|
630
669
|
int64_t ne0;
|
|
631
670
|
float start;
|
|
632
671
|
float step;
|
|
633
|
-
}
|
|
672
|
+
} wsp_wsp_wsp_ggml_metal_kargs_arange;
|
|
634
673
|
|
|
635
674
|
typedef struct {
|
|
636
675
|
int32_t k0;
|
|
@@ -644,6 +683,6 @@ typedef struct {
|
|
|
644
683
|
int64_t OH;
|
|
645
684
|
int64_t OW;
|
|
646
685
|
int64_t parallel_elements;
|
|
647
|
-
}
|
|
686
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
|
|
648
687
|
|
|
649
|
-
#endif //
|
|
688
|
+
#endif // WSP_WSP_WSP_GGML_METAL_IMPL
|
package/cpp/ggml-metal.h
CHANGED
|
@@ -39,18 +39,13 @@ extern "C" {
|
|
|
39
39
|
// user-code should use only these functions
|
|
40
40
|
//
|
|
41
41
|
|
|
42
|
+
// TODO: remove in the future
|
|
42
43
|
WSP_GGML_BACKEND_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
|
|
43
44
|
|
|
44
45
|
WSP_GGML_BACKEND_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
|
|
45
46
|
|
|
46
|
-
WSP_GGML_DEPRECATED(
|
|
47
|
-
WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
|
|
48
|
-
"obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
|
|
49
|
-
|
|
50
47
|
WSP_GGML_BACKEND_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
|
|
51
48
|
|
|
52
|
-
WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
|
|
53
|
-
|
|
54
49
|
// helper to check if the device supports a specific family
|
|
55
50
|
// ideally, the user code should be doing these checks
|
|
56
51
|
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
package/cpp/ggml-opt.cpp
CHANGED
|
@@ -64,9 +64,11 @@ struct wsp_ggml_opt_context {
|
|
|
64
64
|
int32_t opt_i = 0;
|
|
65
65
|
bool loss_per_datapoint = false;
|
|
66
66
|
|
|
67
|
-
wsp_ggml_opt_get_optimizer_params get_opt_pars
|
|
68
|
-
void *
|
|
69
|
-
struct wsp_ggml_tensor *
|
|
67
|
+
wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
|
68
|
+
void * get_opt_pars_ud = nullptr;
|
|
69
|
+
struct wsp_ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
|
|
70
|
+
|
|
71
|
+
enum wsp_ggml_opt_optimizer_type optimizer = WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
70
72
|
};
|
|
71
73
|
|
|
72
74
|
struct wsp_ggml_opt_result {
|
|
@@ -229,9 +231,13 @@ struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(v
|
|
|
229
231
|
result.adamw.eps = 1e-8f;
|
|
230
232
|
result.adamw.wd = 0.0f;
|
|
231
233
|
|
|
234
|
+
result.sgd.alpha = 1e-3f;
|
|
235
|
+
result.sgd.wd = 0.0f;
|
|
236
|
+
|
|
232
237
|
return result;
|
|
233
238
|
}
|
|
234
239
|
|
|
240
|
+
|
|
235
241
|
struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_constant_optimizer_params(void * userdata) {
|
|
236
242
|
return *((struct wsp_ggml_opt_optimizer_params *) userdata);
|
|
237
243
|
}
|
|
@@ -249,6 +255,7 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
|
|
|
249
255
|
/*opt_period =*/ 1,
|
|
250
256
|
/*get_opt_pars =*/ wsp_ggml_opt_get_default_optimizer_params,
|
|
251
257
|
/*get_opt_pars_ud =*/ nullptr,
|
|
258
|
+
/*optimizer =*/ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
|
252
259
|
};
|
|
253
260
|
}
|
|
254
261
|
|
|
@@ -316,9 +323,14 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
|
|
|
316
323
|
WSP_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with wsp_ggml_opt_prepare_alloc");
|
|
317
324
|
WSP_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
|
|
318
325
|
|
|
326
|
+
const enum wsp_ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
|
|
327
|
+
|
|
319
328
|
const bool accumulate = opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_GRAD &&
|
|
320
329
|
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
|
|
321
330
|
|
|
331
|
+
const bool need_momenta = opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT &&
|
|
332
|
+
opt_ctx->optimizer == WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
333
|
+
|
|
322
334
|
wsp_ggml_set_input(opt_ctx->inputs);
|
|
323
335
|
wsp_ggml_set_output(opt_ctx->outputs);
|
|
324
336
|
|
|
@@ -340,8 +352,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
|
|
|
340
352
|
// - pred (if using static graphs)
|
|
341
353
|
// - ncorrect (if using static graphs, 2 tensors).
|
|
342
354
|
constexpr size_t n_loss = 1;
|
|
343
|
-
const size_t tensors_per_param = (accumulate ? 1 : 0) +
|
|
344
|
-
(opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
|
355
|
+
const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
|
|
345
356
|
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
|
|
346
357
|
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * wsp_ggml_tensor_overhead();
|
|
347
358
|
struct wsp_ggml_init_params params = {
|
|
@@ -458,7 +469,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
|
|
|
458
469
|
}
|
|
459
470
|
}
|
|
460
471
|
|
|
461
|
-
if (opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
|
|
472
|
+
if (need_momenta && opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
|
|
462
473
|
opt_ctx->grad_m.resize(n_nodes);
|
|
463
474
|
opt_ctx->grad_v.resize(n_nodes);
|
|
464
475
|
for (int i = 0; i < n_nodes; ++i) {
|
|
@@ -492,23 +503,36 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
|
|
|
492
503
|
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
|
493
504
|
opt_ctx->gb_opt = wsp_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
|
|
494
505
|
|
|
495
|
-
opt_ctx->
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
506
|
+
opt_ctx->opt_step_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, need_momenta ? 7 : 2);
|
|
507
|
+
wsp_ggml_tensor * adamw_params = opt_ctx->opt_step_params;
|
|
508
|
+
wsp_ggml_set_input(adamw_params);
|
|
509
|
+
const char * optimizer_name = wsp_ggml_opt_optimizer_name(opt_ctx->optimizer);
|
|
510
|
+
wsp_ggml_format_name(adamw_params, "%s_params", optimizer_name);
|
|
499
511
|
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
|
|
500
512
|
struct wsp_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
|
|
501
513
|
struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(opt_ctx->gb_opt, node);
|
|
502
514
|
|
|
503
515
|
if (grad && (node->flags & WSP_GGML_TENSOR_FLAG_PARAM)) {
|
|
504
|
-
struct wsp_ggml_tensor * m
|
|
505
|
-
struct wsp_ggml_tensor * v
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
516
|
+
struct wsp_ggml_tensor * m = nullptr;
|
|
517
|
+
struct wsp_ggml_tensor * v = nullptr;
|
|
518
|
+
if (need_momenta) {
|
|
519
|
+
m = opt_ctx->grad_m[i];
|
|
520
|
+
v = opt_ctx->grad_v[i];
|
|
521
|
+
wsp_ggml_format_name(m, "AdamW m for %s", node->name);
|
|
522
|
+
wsp_ggml_format_name(v, "AdamW v for %s", node->name);
|
|
523
|
+
}
|
|
524
|
+
struct wsp_ggml_tensor * opt_step;
|
|
525
|
+
switch (optimizer) {
|
|
526
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
|
|
527
|
+
opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
|
|
528
|
+
break;
|
|
529
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
|
|
530
|
+
opt_step = wsp_ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
|
|
531
|
+
break;
|
|
532
|
+
default:
|
|
533
|
+
WSP_GGML_ABORT("fatal error");
|
|
534
|
+
}
|
|
535
|
+
wsp_ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
|
|
512
536
|
wsp_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
|
|
513
537
|
}
|
|
514
538
|
}
|
|
@@ -534,6 +558,7 @@ wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params) {
|
|
|
534
558
|
result->opt_period = params.opt_period;
|
|
535
559
|
result->get_opt_pars = params.get_opt_pars;
|
|
536
560
|
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
|
561
|
+
result->optimizer = params.optimizer;
|
|
537
562
|
|
|
538
563
|
WSP_GGML_ASSERT(result->opt_period >= 1);
|
|
539
564
|
|
|
@@ -756,29 +781,43 @@ void wsp_ggml_opt_alloc(wsp_ggml_opt_context_t opt_ctx, bool backward) {
|
|
|
756
781
|
void wsp_ggml_opt_eval(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result_t result) {
|
|
757
782
|
WSP_GGML_ASSERT(opt_ctx->eval_ready);
|
|
758
783
|
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
784
|
+
const wsp_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
|
785
|
+
|
|
786
|
+
switch (opt_ctx->optimizer) {
|
|
787
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
|
|
788
|
+
WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
|
789
|
+
WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
|
|
790
|
+
WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
|
|
791
|
+
WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
|
|
792
|
+
WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
|
|
793
|
+
WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
|
|
794
|
+
WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
|
|
795
|
+
WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
|
|
796
|
+
|
|
797
|
+
// beta1, beta2 after applying warmup
|
|
798
|
+
const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
|
|
799
|
+
const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
|
|
800
|
+
|
|
801
|
+
float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
|
|
802
|
+
adamw_par_data[0] = opt_pars.adamw.alpha;
|
|
803
|
+
adamw_par_data[1] = opt_pars.adamw.beta1;
|
|
804
|
+
adamw_par_data[2] = opt_pars.adamw.beta2;
|
|
805
|
+
adamw_par_data[3] = opt_pars.adamw.eps;
|
|
806
|
+
adamw_par_data[4] = opt_pars.adamw.wd;
|
|
807
|
+
adamw_par_data[5] = beta1h;
|
|
808
|
+
adamw_par_data[6] = beta2h;
|
|
809
|
+
} break;
|
|
810
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD: {
|
|
811
|
+
WSP_GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
|
|
812
|
+
WSP_GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
|
|
813
|
+
WSP_GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
|
|
814
|
+
float * sgd = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
|
|
815
|
+
sgd[0] = opt_pars.sgd.alpha;
|
|
816
|
+
sgd[1] = opt_pars.sgd.wd;
|
|
817
|
+
} break;
|
|
818
|
+
default:
|
|
819
|
+
WSP_GGML_ABORT("fatal error");
|
|
820
|
+
}
|
|
782
821
|
}
|
|
783
822
|
|
|
784
823
|
wsp_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
|
@@ -963,6 +1002,7 @@ void wsp_ggml_opt_fit(
|
|
|
963
1002
|
wsp_ggml_tensor * outputs,
|
|
964
1003
|
wsp_ggml_opt_dataset_t dataset,
|
|
965
1004
|
enum wsp_ggml_opt_loss_type loss_type,
|
|
1005
|
+
enum wsp_ggml_opt_optimizer_type optimizer,
|
|
966
1006
|
wsp_ggml_opt_get_optimizer_params get_opt_pars,
|
|
967
1007
|
int64_t nepoch,
|
|
968
1008
|
int64_t nbatch_logical,
|
|
@@ -993,6 +1033,7 @@ void wsp_ggml_opt_fit(
|
|
|
993
1033
|
params.opt_period = opt_period;
|
|
994
1034
|
params.get_opt_pars = get_opt_pars;
|
|
995
1035
|
params.get_opt_pars_ud = &epoch;
|
|
1036
|
+
params.optimizer = optimizer;
|
|
996
1037
|
wsp_ggml_opt_context_t opt_ctx = wsp_ggml_opt_init(params);
|
|
997
1038
|
|
|
998
1039
|
// Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
|
|
@@ -1035,3 +1076,18 @@ void wsp_ggml_opt_fit(
|
|
|
1035
1076
|
wsp_ggml_opt_result_free(result_train);
|
|
1036
1077
|
wsp_ggml_opt_result_free(result_val);
|
|
1037
1078
|
}
|
|
1079
|
+
|
|
1080
|
+
enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t c) {
|
|
1081
|
+
return c->optimizer;
|
|
1082
|
+
}
|
|
1083
|
+
|
|
1084
|
+
WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type o) {
|
|
1085
|
+
switch (o) {
|
|
1086
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
|
|
1087
|
+
return "adamw";
|
|
1088
|
+
case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
|
|
1089
|
+
return "sgd";
|
|
1090
|
+
default:
|
|
1091
|
+
return "undefined";
|
|
1092
|
+
};
|
|
1093
|
+
}
|