whisper.rn 0.5.0-rc.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-opt.cpp
CHANGED
@@ -64,9 +64,11 @@ struct wsp_ggml_opt_context {
     int32_t opt_i = 0;
     bool loss_per_datapoint = false;
 
-    wsp_ggml_opt_get_optimizer_params get_opt_pars;
-    void * get_opt_pars_ud;
-    struct wsp_ggml_tensor * adamw_params;
+    wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud = nullptr;
+    struct wsp_ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum wsp_ggml_opt_optimizer_type optimizer = WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct wsp_ggml_opt_result {
@@ -229,9 +231,13 @@ struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(v
     result.adamw.eps = 1e-8f;
     result.adamw.wd = 0.0f;
 
+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd = 0.0f;
+
     return result;
 }
 
+
 struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct wsp_ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ wsp_ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
     WSP_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with wsp_ggml_opt_prepare_alloc");
     WSP_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum wsp_ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT &&
+                              opt_ctx->optimizer == WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     wsp_ggml_set_input(opt_ctx->inputs);
     wsp_ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
     // - pred (if using static graphs)
     // - ncorrect (if using static graphs, 2 tensors).
    constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * wsp_ggml_tensor_overhead();
     struct wsp_ggml_init_params params = {
@@ -458,7 +469,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
         }
     }
 
-    if (opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
         // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
         opt_ctx->gb_opt = wsp_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-        opt_ctx->adamw_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, 7);
-        wsp_ggml_set_input(opt_ctx->adamw_params);
-        wsp_ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+        opt_ctx->opt_step_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, need_momenta ? 7 : 2);
+        wsp_ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+        wsp_ggml_set_input(adamw_params);
+        const char * optimizer_name = wsp_ggml_opt_optimizer_name(opt_ctx->optimizer);
+        wsp_ggml_format_name(adamw_params, "%s_params", optimizer_name);
         for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
             struct wsp_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
             struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
             if (grad && (node->flags & WSP_GGML_TENSOR_FLAG_PARAM)) {
-                struct wsp_ggml_tensor * m = opt_ctx->grad_m[i];
-                struct wsp_ggml_tensor * v = opt_ctx->grad_v[i];
-                struct wsp_ggml_tensor * opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-                wsp_ggml_format_name(m, "AdamW m for %s", node->name);
-                wsp_ggml_format_name(v, "AdamW v for %s", node->name);
-                wsp_ggml_format_name(opt_step, "AdamW step for %s", node->name);
-
+                struct wsp_ggml_tensor * m = nullptr;
+                struct wsp_ggml_tensor * v = nullptr;
+                if (need_momenta) {
+                    m = opt_ctx->grad_m[i];
+                    v = opt_ctx->grad_v[i];
+                    wsp_ggml_format_name(m, "AdamW m for %s", node->name);
+                    wsp_ggml_format_name(v, "AdamW v for %s", node->name);
+                }
+                struct wsp_ggml_tensor * opt_step;
+                switch (optimizer) {
+                    case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                        opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                        break;
+                    case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
+                        opt_step = wsp_ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                        break;
+                    default:
+                        WSP_GGML_ABORT("fatal error");
+                }
+                wsp_ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
                 wsp_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
             }
         }
@@ -534,6 +558,7 @@ wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params) {
     result->opt_period = params.opt_period;
     result->get_opt_pars = params.get_opt_pars;
     result->get_opt_pars_ud = params.get_opt_pars_ud;
+    result->optimizer = params.optimizer;
 
     WSP_GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void wsp_ggml_opt_alloc(wsp_ggml_opt_context_t opt_ctx, bool backward) {
 void wsp_ggml_opt_eval(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result_t result) {
     WSP_GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        const wsp_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
-
-        WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
-        WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        const wsp_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                WSP_GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                WSP_GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                WSP_GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                WSP_GGML_ABORT("fatal error");
+        }
     }
 
     wsp_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void wsp_ggml_opt_fit(
         wsp_ggml_tensor * outputs,
         wsp_ggml_opt_dataset_t dataset,
         enum wsp_ggml_opt_loss_type loss_type,
+        enum wsp_ggml_opt_optimizer_type optimizer,
         wsp_ggml_opt_get_optimizer_params get_opt_pars,
         int64_t nepoch,
         int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void wsp_ggml_opt_fit(
     params.opt_period = opt_period;
     params.get_opt_pars = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer = optimizer;
     wsp_ggml_opt_context_t opt_ctx = wsp_ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void wsp_ggml_opt_fit(
     wsp_ggml_opt_result_free(result_train);
     wsp_ggml_opt_result_free(result_val);
 }
+
+enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type o) {
+    switch (o) {
+        case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
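
Taken together, the ggml-opt.cpp changes make the optimizer a per-context choice: SGD stages only 2 scalars (alpha, wd) in opt_step_params versus AdamW's 7, and the m/v momenta tensors are built only when need_momenta is true. A minimal sketch of a caller-side get_opt_pars callback using the new sgd fields; the epoch-counter userdata convention is the one wsp_ggml_opt_fit itself uses, while the decay schedule is illustrative, not from this diff:

    // Hypothetical schedule: decay the SGD learning rate each epoch.
    // Assumes userdata points at the int64_t epoch counter, as in wsp_ggml_opt_fit.
    static struct wsp_ggml_opt_optimizer_params lr_by_epoch(void * userdata) {
        const int64_t epoch = *(const int64_t *) userdata;
        struct wsp_ggml_opt_optimizer_params p = wsp_ggml_opt_get_default_optimizer_params(NULL);
        p.sgd.alpha = 1e-3f / (1.0f + 0.1f * (float) epoch); // shrink the 1e-3 default
        p.sgd.wd    = 0.0f;                                  // weight decay disabled
        return p;
    }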
package/cpp/ggml-opt.h
CHANGED
@@ -74,16 +74,26 @@ extern "C" {
         WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum wsp_ggml_opt_optimizer_type {
+        WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        WSP_GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        WSP_GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct wsp_ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        wsp_ggml_opt_get_optimizer_params get_opt_pars;
-        void * get_opt_pars_ud;
+        wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                         // userdata for calculating optimizer parameters
+
+        // only WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum wsp_ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
 
+    WSP_GGML_API enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t); //TODO consistent naming scheme
+
+    WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct wsp_ggml_tensor * outputs,               // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             wsp_ggml_opt_dataset_t dataset,                 // dataset with data and optionally also labels
             enum wsp_ggml_opt_loss_type loss_type,          // loss to minimize
+            enum wsp_ggml_opt_optimizer_type optimizer,     // sgd or adamw
            wsp_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t nepoch,                                 // how many times the dataset should be iterated over
             int64_t nbatch_logical,                         // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float val_split,                                // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool silent);                                   // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
package/cpp/ggml-quants.c
CHANGED
@@ -21,6 +21,17 @@
 
 #define UNUSED WSP_GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void wsp_quantize_row_q4_0_ref(const float * WSP_GGML_RESTRICT x, block_q4_0 * WSP_GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * W
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
+
+        const float d = WSP_GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void wsp_dewsp_quantize_row_q4_0(const block_q4_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -356,6 +414,26 @@ void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float *
     }
 }
 
+void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = WSP_GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -488,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
         for (int i = 0; i < n; ++i) {
             L[i] += nmax;
         }
-        return sumlx / suml2;
+        return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
     }
     for (int i = 0; i < n; ++i) {
         int l = nearest_int(iscale * x[i]);
@@ -823,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
-    if (!max) { // all zero
+    if (max < GROUP_MAX_EPS) { // all zero
        for (int i = 0; i < n; ++i) { L[i] = 0; }
        return 0.f;
    }
@@ -888,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
            break;
        }
    }
-    return sumlx/suml2;
+    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
 }
 
 static void wsp_quantize_row_q2_K_impl(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int k, const float * WSP_GGML_RESTRICT quant_weights) {
@@ -2014,6 +2092,12 @@ size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RE
     return nrow * row_size;
 }
 
+size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    WSP_GGML_UNUSED(quant_weights);
+    wsp_quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * wsp_ggml_row_size(WSP_GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void wsp_quantize_row_tq1_0_ref(const float * WSP_GGML_RESTRICT x, block_tq1_0 * WSP_GGML_RESTRICT y, int64_t k) {
@@ -4182,7 +4266,7 @@ static void wsp_quantize_row_iq1_s_impl(const float * WSP_GGML_RESTRICT x, void
                 sumw[j+1] = sumw[j] + weight[i];
             }
         }
-        float best_score = -FLT_MIN, scale = max;
+        float best_score = -FLT_MAX, scale = max;
         int besti1 = -1, besti2 = -1, best_shift = 0;
         for (int i1 = 0; i1 <= block_size; ++i1) {
             for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -4358,7 +4442,7 @@ static void wsp_quantize_row_iq1_m_impl(const float * WSP_GGML_RESTRICT x, void
             idx[2*j] = j;
         }
         qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
-        float best_score = -FLT_MIN, scale = max;
+        float best_score = -FLT_MAX, scale = max;
         int besti1 = -1, besti2 = -1, best_k = -1;
         // 0: +, +
         // 1: +, -
@@ -4551,17 +4635,6 @@ size_t wsp_quantize_iq1_m(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_R
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void wsp_quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * WSP_GGML_RESTRICT x,
         wsp_ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(wsp_ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "wsp_ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(wsp_ggml_fp16_t f, size_t i) {
     } \
 }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
+        case WSP_GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case WSP_GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
package/cpp/ggml-quants.h
CHANGED
@@ -21,6 +21,8 @@ WSP_GGML_API void wsp_quantize_row_q5_1_ref(const float * WSP_GGML_RESTRICT x, b
 WSP_GGML_API void wsp_quantize_row_q8_0_ref(const float * WSP_GGML_RESTRICT x, block_q8_0 * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * WSP_GGML_RESTRICT y, int64_t k);
 
+WSP_GGML_API void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k);
+
 WSP_GGML_API void wsp_quantize_row_q2_K_ref(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q3_K_ref(const float * WSP_GGML_RESTRICT x, block_q3_K * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q4_K_ref(const float * WSP_GGML_RESTRICT x, block_q4_K * WSP_GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ WSP_GGML_API void wsp_dewsp_quantize_row_q5_1(const block_q5_1 * WSP_GGML_RESTRI
 WSP_GGML_API void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 //WSP_GGML_API void wsp_dewsp_quantize_row_q8_1(const block_q8_1 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 
+WSP_GGML_API void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+
 WSP_GGML_API void wsp_dewsp_quantize_row_q2_K(const block_q2_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_dewsp_quantize_row_q3_K(const block_q3_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_dewsp_quantize_row_q4_K(const block_q4_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ WSP_GGML_API size_t wsp_quantize_q5_0(const float * WSP_GGML_RESTRICT src, void
 WSP_GGML_API size_t wsp_quantize_q5_1(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 WSP_GGML_API size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+WSP_GGML_API size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 WSP_GGML_API void wsp_iq2xs_init_impl(enum wsp_ggml_type type);
 WSP_GGML_API void wsp_iq2xs_free_impl(enum wsp_ggml_type type);
 WSP_GGML_API void wsp_iq3xs_init_impl(int grid_size);
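
Since wsp_quantize_mxfp4 ignores the importance matrix (see the ggml-quants.c change above) and returns the number of bytes written, destination sizing reduces to rows times row size. A small sketch; wsp_ggml_row_size is the ggml.h helper the implementation itself uses for its return value:

    // Assumes <stdlib.h> and <assert.h> are included, and src/nrows/n_per_row are in scope.
    const size_t row_bytes = wsp_ggml_row_size(WSP_GGML_TYPE_MXFP4, n_per_row);
    void * dst = malloc((size_t) nrows * row_bytes);
    const size_t written = wsp_quantize_mxfp4(src, dst, nrows, n_per_row, /*imatrix=*/ NULL);
    assert(written == (size_t) nrows * row_bytes);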
Binary file (contents not shown)
Binary file (contents not shown)