whisper.rn 0.5.0-rc.8 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. package/cpp/ggml-alloc.c +1 -15
  2. package/cpp/ggml-backend-reg.cpp +17 -8
  3. package/cpp/ggml-backend.cpp +15 -22
  4. package/cpp/ggml-common.h +17 -0
  5. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  6. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  7. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  8. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  9. package/cpp/ggml-cpu/arch-fallback.h +34 -0
  10. package/cpp/ggml-cpu/ggml-cpu.c +22 -1
  11. package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
  12. package/cpp/ggml-cpu/ops.cpp +870 -211
  13. package/cpp/ggml-cpu/ops.h +3 -8
  14. package/cpp/ggml-cpu/quants.c +35 -0
  15. package/cpp/ggml-cpu/quants.h +8 -0
  16. package/cpp/ggml-cpu/repack.cpp +458 -47
  17. package/cpp/ggml-cpu/repack.h +22 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +1 -1
  19. package/cpp/ggml-cpu/traits.cpp +2 -2
  20. package/cpp/ggml-cpu/traits.h +1 -1
  21. package/cpp/ggml-cpu/vec.cpp +12 -9
  22. package/cpp/ggml-cpu/vec.h +107 -13
  23. package/cpp/ggml-impl.h +77 -0
  24. package/cpp/ggml-metal-impl.h +51 -12
  25. package/cpp/ggml-metal.m +610 -115
  26. package/cpp/ggml-opt.cpp +97 -41
  27. package/cpp/ggml-opt.h +25 -6
  28. package/cpp/ggml-quants.c +110 -16
  29. package/cpp/ggml-quants.h +6 -0
  30. package/cpp/ggml-whisper-sim.metallib +0 -0
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +314 -88
  33. package/cpp/ggml.h +137 -11
  34. package/cpp/gguf.cpp +8 -1
  35. package/cpp/jsi/RNWhisperJSI.cpp +23 -6
  36. package/cpp/whisper.cpp +15 -6
  37. package/ios/RNWhisper.mm +6 -6
  38. package/ios/RNWhisperContext.mm +2 -0
  39. package/ios/RNWhisperVadContext.mm +2 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  72. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
  73. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  74. package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
  75. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  76. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
  77. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  78. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  79. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  80. package/package.json +1 -1
  81. package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
  82. package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-opt.cpp CHANGED
@@ -64,9 +64,11 @@ struct wsp_ggml_opt_context {
64
64
  int32_t opt_i = 0;
65
65
  bool loss_per_datapoint = false;
66
66
 
67
- wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
68
- void * get_opt_pars_ud = nullptr;
69
- struct wsp_ggml_tensor * adamw_params = nullptr;
67
+ wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
68
+ void * get_opt_pars_ud = nullptr;
69
+ struct wsp_ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
70
+
71
+ enum wsp_ggml_opt_optimizer_type optimizer = WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
70
72
  };
71
73
 
72
74
  struct wsp_ggml_opt_result {
@@ -229,9 +231,13 @@ struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(v
229
231
  result.adamw.eps = 1e-8f;
230
232
  result.adamw.wd = 0.0f;
231
233
 
234
+ result.sgd.alpha = 1e-3f;
235
+ result.sgd.wd = 0.0f;
236
+
232
237
  return result;
233
238
  }
234
239
 
240
+
235
241
  struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_constant_optimizer_params(void * userdata) {
236
242
  return *((struct wsp_ggml_opt_optimizer_params *) userdata);
237
243
  }
@@ -249,6 +255,7 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
249
255
  /*opt_period =*/ 1,
250
256
  /*get_opt_pars =*/ wsp_ggml_opt_get_default_optimizer_params,
251
257
  /*get_opt_pars_ud =*/ nullptr,
258
+ /*optimizer =*/ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
252
259
  };
253
260
  }
254
261
 
@@ -316,9 +323,14 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
316
323
  WSP_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with wsp_ggml_opt_prepare_alloc");
317
324
  WSP_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
318
325
 
326
+ const enum wsp_ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
327
+
319
328
  const bool accumulate = opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_GRAD &&
320
329
  !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
321
330
 
331
+ const bool need_momenta = opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT &&
332
+ opt_ctx->optimizer == WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
333
+
322
334
  wsp_ggml_set_input(opt_ctx->inputs);
323
335
  wsp_ggml_set_output(opt_ctx->outputs);
324
336
 
@@ -340,8 +352,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
340
352
  // - pred (if using static graphs)
341
353
  // - ncorrect (if using static graphs, 2 tensors).
342
354
  constexpr size_t n_loss = 1;
343
- const size_t tensors_per_param = (accumulate ? 1 : 0) +
344
- (opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
355
+ const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
345
356
  const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
346
357
  const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * wsp_ggml_tensor_overhead();
347
358
  struct wsp_ggml_init_params params = {
@@ -458,7 +469,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
458
469
  }
459
470
  }
460
471
 
461
- if (opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
472
+ if (need_momenta && opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
462
473
  opt_ctx->grad_m.resize(n_nodes);
463
474
  opt_ctx->grad_v.resize(n_nodes);
464
475
  for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
492
503
  // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
493
504
  opt_ctx->gb_opt = wsp_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
494
505
 
495
- opt_ctx->adamw_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, 7);
496
- wsp_ggml_set_input(opt_ctx->adamw_params);
497
- wsp_ggml_set_name(opt_ctx->adamw_params, "adamw_params");
498
-
506
+ opt_ctx->opt_step_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, need_momenta ? 7 : 2);
507
+ wsp_ggml_tensor * adamw_params = opt_ctx->opt_step_params;
508
+ wsp_ggml_set_input(adamw_params);
509
+ const char * optimizer_name = wsp_ggml_opt_optimizer_name(opt_ctx->optimizer);
510
+ wsp_ggml_format_name(adamw_params, "%s_params", optimizer_name);
499
511
  for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
500
512
  struct wsp_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
501
513
  struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(opt_ctx->gb_opt, node);
502
514
 
503
515
  if (grad && (node->flags & WSP_GGML_TENSOR_FLAG_PARAM)) {
504
- struct wsp_ggml_tensor * m = opt_ctx->grad_m[i];
505
- struct wsp_ggml_tensor * v = opt_ctx->grad_v[i];
506
- struct wsp_ggml_tensor * opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
507
-
508
- wsp_ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
509
- wsp_ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
510
- wsp_ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
511
-
516
+ struct wsp_ggml_tensor * m = nullptr;
517
+ struct wsp_ggml_tensor * v = nullptr;
518
+ if (need_momenta) {
519
+ m = opt_ctx->grad_m[i];
520
+ v = opt_ctx->grad_v[i];
521
+ wsp_ggml_format_name(m, "AdamW m for %s", node->name);
522
+ wsp_ggml_format_name(v, "AdamW v for %s", node->name);
523
+ }
524
+ struct wsp_ggml_tensor * opt_step;
525
+ switch (optimizer) {
526
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
527
+ opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
528
+ break;
529
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
530
+ opt_step = wsp_ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
531
+ break;
532
+ default:
533
+ WSP_GGML_ABORT("fatal error");
534
+ }
535
+ wsp_ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
512
536
  wsp_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
513
537
  }
514
538
  }
@@ -534,6 +558,7 @@ wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params) {
534
558
  result->opt_period = params.opt_period;
535
559
  result->get_opt_pars = params.get_opt_pars;
536
560
  result->get_opt_pars_ud = params.get_opt_pars_ud;
561
+ result->optimizer = params.optimizer;
537
562
 
538
563
  WSP_GGML_ASSERT(result->opt_period >= 1);
539
564
 
@@ -756,29 +781,43 @@ void wsp_ggml_opt_alloc(wsp_ggml_opt_context_t opt_ctx, bool backward) {
756
781
  void wsp_ggml_opt_eval(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result_t result) {
757
782
  WSP_GGML_ASSERT(opt_ctx->eval_ready);
758
783
  if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
759
- struct wsp_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
760
-
761
- WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
762
- WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
763
- WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
764
- WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
765
- WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
766
- WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
767
- WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
768
- WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
769
-
770
- // beta1, beta2 after applying warmup
771
- const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
772
- const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
773
-
774
- float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->adamw_params);
775
- adamw_par_data[0] = opt_pars.adamw.alpha;
776
- adamw_par_data[1] = opt_pars.adamw.beta1;
777
- adamw_par_data[2] = opt_pars.adamw.beta2;
778
- adamw_par_data[3] = opt_pars.adamw.eps;
779
- adamw_par_data[4] = opt_pars.adamw.wd;
780
- adamw_par_data[5] = beta1h;
781
- adamw_par_data[6] = beta2h;
784
+ const wsp_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
785
+
786
+ switch (opt_ctx->optimizer) {
787
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
788
+ WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
789
+ WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
790
+ WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
791
+ WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
792
+ WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
793
+ WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
794
+ WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
795
+ WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
796
+
797
+ // beta1, beta2 after applying warmup
798
+ const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
799
+ const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
800
+
801
+ float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
802
+ adamw_par_data[0] = opt_pars.adamw.alpha;
803
+ adamw_par_data[1] = opt_pars.adamw.beta1;
804
+ adamw_par_data[2] = opt_pars.adamw.beta2;
805
+ adamw_par_data[3] = opt_pars.adamw.eps;
806
+ adamw_par_data[4] = opt_pars.adamw.wd;
807
+ adamw_par_data[5] = beta1h;
808
+ adamw_par_data[6] = beta2h;
809
+ } break;
810
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD: {
811
+ WSP_GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
812
+ WSP_GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
813
+ WSP_GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
814
+ float * sgd = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
815
+ sgd[0] = opt_pars.sgd.alpha;
816
+ sgd[1] = opt_pars.sgd.wd;
817
+ } break;
818
+ default:
819
+ WSP_GGML_ABORT("fatal error");
820
+ }
782
821
  }
783
822
 
784
823
  wsp_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void wsp_ggml_opt_fit(
963
1002
  wsp_ggml_tensor * outputs,
964
1003
  wsp_ggml_opt_dataset_t dataset,
965
1004
  enum wsp_ggml_opt_loss_type loss_type,
1005
+ enum wsp_ggml_opt_optimizer_type optimizer,
966
1006
  wsp_ggml_opt_get_optimizer_params get_opt_pars,
967
1007
  int64_t nepoch,
968
1008
  int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void wsp_ggml_opt_fit(
993
1033
  params.opt_period = opt_period;
994
1034
  params.get_opt_pars = get_opt_pars;
995
1035
  params.get_opt_pars_ud = &epoch;
1036
+ params.optimizer = optimizer;
996
1037
  wsp_ggml_opt_context_t opt_ctx = wsp_ggml_opt_init(params);
997
1038
 
998
1039
  // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void wsp_ggml_opt_fit(
1035
1076
  wsp_ggml_opt_result_free(result_train);
1036
1077
  wsp_ggml_opt_result_free(result_val);
1037
1078
  }
1079
+
1080
+ enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t c) {
1081
+ return c->optimizer;
1082
+ }
1083
+
1084
+ WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type o) {
1085
+ switch (o) {
1086
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
1087
+ return "adamw";
1088
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
1089
+ return "sgd";
1090
+ default:
1091
+ return "undefined";
1092
+ };
1093
+ }
package/cpp/ggml-opt.h CHANGED
@@ -74,16 +74,26 @@ extern "C" {
74
74
  WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
75
75
  };
76
76
 
77
+ enum wsp_ggml_opt_optimizer_type {
78
+ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
79
+ WSP_GGML_OPT_OPTIMIZER_TYPE_SGD,
80
+
81
+ WSP_GGML_OPT_OPTIMIZER_TYPE_COUNT
82
+ };
83
+
77
84
  // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
78
85
  struct wsp_ggml_opt_optimizer_params {
79
- // AdamW optimizer parameters
80
86
  struct {
81
87
  float alpha; // learning rate
82
- float beta1;
83
- float beta2;
88
+ float beta1; // first AdamW momentum
89
+ float beta2; // second AdamW momentum
84
90
  float eps; // epsilon for numerical stability
85
- float wd; // weight decay for AdamW, use 0.0f to disable
91
+ float wd; // weight decay - 0.0f to disable
86
92
  } adamw;
93
+ struct {
94
+ float alpha; // learning rate
95
+ float wd; // weight decay
96
+ } sgd;
87
97
  };
88
98
 
89
99
  // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
112
122
 
113
123
  int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
114
124
 
115
- wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
116
- void * get_opt_pars_ud; // userdata for calculating optimizer parameters
125
+ wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
126
+ void * get_opt_pars_ud; // userdata for calculating optimizer parameters
127
+
128
+ // only WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
129
+ enum wsp_ggml_opt_optimizer_type optimizer;
117
130
  };
118
131
 
119
132
  // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
142
155
  // get the gradient accumulator for a node from the forward graph
143
156
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
144
157
 
158
+ WSP_GGML_API enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t); //TODO consistent naming scheme
159
+
160
+ WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type);
161
+
145
162
  // ====== Optimization Result ======
146
163
 
147
164
  WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
226
243
  struct wsp_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
227
244
  wsp_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
228
245
  enum wsp_ggml_opt_loss_type loss_type, // loss to minimize
246
+ enum wsp_ggml_opt_optimizer_type optimizer, // sgd or adamw
229
247
  wsp_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
230
248
  int64_t nepoch, // how many times the dataset should be iterated over
231
249
  int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
232
250
  float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
233
251
  bool silent); // whether or not info prints to stderr should be suppressed
234
252
 
253
+
235
254
  #ifdef __cplusplus
236
255
  }
237
256
  #endif
package/cpp/ggml-quants.c CHANGED
@@ -21,6 +21,17 @@
21
21
 
22
22
  #define UNUSED WSP_GGML_UNUSED
23
23
 
24
+ static inline int best_index_int8(int n, const int8_t * val, float x) {
25
+ if (x <= val[0]) return 0;
26
+ if (x >= val[n-1]) return n-1;
27
+ int ml = 0, mu = n-1;
28
+ while (mu-ml > 1) {
29
+ int mav = (ml+mu)/2;
30
+ if (x < val[mav]) mu = mav; else ml = mav;
31
+ }
32
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
33
+ }
34
+
24
35
  // reference implementation for deterministic creation of model files
25
36
  void wsp_quantize_row_q4_0_ref(const float * WSP_GGML_RESTRICT x, block_q4_0 * WSP_GGML_RESTRICT y, int64_t k) {
26
37
  static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * W
246
257
  }
247
258
  }
248
259
 
260
+ static inline int best_index_mxfp4(float x, float e) {
261
+ int best_index = 0;
262
+ float best_err = fabsf(kvalues_mxfp4[0]*e - x);
263
+ for (int i = 1; i < 16; i++) {
264
+ float err = fabsf(kvalues_mxfp4[i]*e - x);
265
+ if (err < best_err) {
266
+ best_index = i;
267
+ best_err = err;
268
+ }
269
+ }
270
+ return best_index;
271
+ }
272
+
273
+ void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k) {
274
+ static const int qk = QK_MXFP4;
275
+
276
+ assert(k % qk == 0);
277
+
278
+ const int nb = k / qk;
279
+
280
+ for (int i = 0; i < nb; i++) {
281
+ float amax = 0.0f; // absolute max
282
+
283
+ for (int j = 0; j < qk; j++) {
284
+ const float v = x[i*qk + j];
285
+
286
+ if (amax < fabsf(v)) {
287
+ amax = fabsf(v);
288
+ }
289
+ }
290
+
291
+ const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
292
+
293
+ const float d = WSP_GGML_E8M0_TO_FP32_HALF(e);
294
+
295
+ y[i].e = e;
296
+
297
+ for (int j = 0; j < qk/2; ++j) {
298
+ const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
299
+ const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
300
+
301
+ y[i].qs[j] = x0;
302
+ y[i].qs[j] |= x1 << 4;
303
+ }
304
+ }
305
+ }
306
+
249
307
  void wsp_dewsp_quantize_row_q4_0(const block_q4_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k) {
250
308
  static const int qk = QK4_0;
251
309
 
@@ -356,6 +414,26 @@ void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float *
356
414
  }
357
415
  }
358
416
 
417
+ void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k) {
418
+ static const int qk = QK_MXFP4;
419
+
420
+ assert(k % qk == 0);
421
+
422
+ const int nb = k / qk;
423
+
424
+ for (int i = 0; i < nb; i++) {
425
+ const float d = WSP_GGML_E8M0_TO_FP32_HALF(x[i].e);
426
+
427
+ for (int j = 0; j < qk/2; ++j) {
428
+ const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
429
+ const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
430
+
431
+ y[i*qk + j + 0 ] = x0*d;
432
+ y[i*qk + j + qk/2] = x1*d;
433
+ }
434
+ }
435
+ }
436
+
359
437
  //
360
438
  // 2-6 bit quantization in super-blocks
361
439
  //
@@ -488,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
488
566
  for (int i = 0; i < n; ++i) {
489
567
  L[i] += nmax;
490
568
  }
491
- return sumlx / suml2;
569
+ return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
492
570
  }
493
571
  for (int i = 0; i < n; ++i) {
494
572
  int l = nearest_int(iscale * x[i]);
@@ -823,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
823
901
  for (int i = 0; i < n; ++i) {
824
902
  max = MAX(max, x[i]);
825
903
  }
826
- if (!max) { // all zero
904
+ if (max < GROUP_MAX_EPS) { // all zero
827
905
  for (int i = 0; i < n; ++i) { L[i] = 0; }
828
906
  return 0.f;
829
907
  }
@@ -888,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x,
888
966
  break;
889
967
  }
890
968
  }
891
- return sumlx/suml2;
969
+ return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
892
970
  }
893
971
 
894
972
  static void wsp_quantize_row_q2_K_impl(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int k, const float * WSP_GGML_RESTRICT quant_weights) {
@@ -2014,6 +2092,12 @@ size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RE
2014
2092
  return nrow * row_size;
2015
2093
  }
2016
2094
 
2095
+ size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2096
+ WSP_GGML_UNUSED(quant_weights);
2097
+ wsp_quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
2098
+ return nrow * wsp_ggml_row_size(WSP_GGML_TYPE_MXFP4, n_per_row);
2099
+ }
2100
+
2017
2101
  // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
2018
2102
 
2019
2103
  void wsp_quantize_row_tq1_0_ref(const float * WSP_GGML_RESTRICT x, block_tq1_0 * WSP_GGML_RESTRICT y, int64_t k) {
@@ -4182,7 +4266,7 @@ static void wsp_quantize_row_iq1_s_impl(const float * WSP_GGML_RESTRICT x, void
4182
4266
  sumw[j+1] = sumw[j] + weight[i];
4183
4267
  }
4184
4268
  }
4185
- float best_score = -FLT_MIN, scale = max;
4269
+ float best_score = -FLT_MAX, scale = max;
4186
4270
  int besti1 = -1, besti2 = -1, best_shift = 0;
4187
4271
  for (int i1 = 0; i1 <= block_size; ++i1) {
4188
4272
  for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -4358,7 +4442,7 @@ static void wsp_quantize_row_iq1_m_impl(const float * WSP_GGML_RESTRICT x, void
4358
4442
  idx[2*j] = j;
4359
4443
  }
4360
4444
  qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4361
- float best_score = -FLT_MIN, scale = max;
4445
+ float best_score = -FLT_MAX, scale = max;
4362
4446
  int besti1 = -1, besti2 = -1, best_k = -1;
4363
4447
  // 0: +, +
4364
4448
  // 1: +, -
@@ -4551,17 +4635,6 @@ size_t wsp_quantize_iq1_m(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_R
4551
4635
 
4552
4636
  // ============================ 4-bit non-linear quants
4553
4637
 
4554
- static inline int best_index_int8(int n, const int8_t * val, float x) {
4555
- if (x <= val[0]) return 0;
4556
- if (x >= val[n-1]) return n-1;
4557
- int ml = 0, mu = n-1;
4558
- while (mu-ml > 1) {
4559
- int mav = (ml+mu)/2;
4560
- if (x < val[mav]) mu = mav; else ml = mav;
4561
- }
4562
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4563
- }
4564
-
4565
4638
  static void wsp_quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * WSP_GGML_RESTRICT x,
4566
4639
  wsp_ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
4567
4640
  float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(wsp_ggml_fp16_t f, size_t i) {
4961
5034
  return true;
4962
5035
  }
4963
5036
 
5037
+ static bool validate_e_e8m0(uint8_t e, size_t i) {
5038
+ if (e == 0xff) {
5039
+ fprintf(stderr, "wsp_ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
5040
+ return false;
5041
+ }
5042
+
5043
+ return true;
5044
+ }
5045
+
4964
5046
  #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
4965
5047
  const type * q = (const type *) (data); \
4966
5048
  for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(wsp_ggml_fp16_t f, size_t i) {
4977
5059
  } \
4978
5060
  }
4979
5061
 
5062
+ #define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
5063
+ const type * q = (const type *) (data); \
5064
+ for (size_t i = 0; i < (nb); ++i) { \
5065
+ if (!validate_e_e8m0(q[i].e, i)) { \
5066
+ return false; \
5067
+ } \
5068
+ }
5069
+
4980
5070
  #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
4981
5071
  const type * q = (const type *) (data); \
4982
5072
  for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size
5130
5220
  {
5131
5221
  VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
5132
5222
  } break;
5223
+ case WSP_GGML_TYPE_MXFP4:
5224
+ {
5225
+ VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
5226
+ } break;
5133
5227
  case WSP_GGML_TYPE_Q2_K:
5134
5228
  {
5135
5229
  VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
package/cpp/ggml-quants.h CHANGED
@@ -21,6 +21,8 @@ WSP_GGML_API void wsp_quantize_row_q5_1_ref(const float * WSP_GGML_RESTRICT x, b
21
21
  WSP_GGML_API void wsp_quantize_row_q8_0_ref(const float * WSP_GGML_RESTRICT x, block_q8_0 * WSP_GGML_RESTRICT y, int64_t k);
22
22
  WSP_GGML_API void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * WSP_GGML_RESTRICT y, int64_t k);
23
23
 
24
+ WSP_GGML_API void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k);
25
+
24
26
  WSP_GGML_API void wsp_quantize_row_q2_K_ref(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int64_t k);
25
27
  WSP_GGML_API void wsp_quantize_row_q3_K_ref(const float * WSP_GGML_RESTRICT x, block_q3_K * WSP_GGML_RESTRICT y, int64_t k);
26
28
  WSP_GGML_API void wsp_quantize_row_q4_K_ref(const float * WSP_GGML_RESTRICT x, block_q4_K * WSP_GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ WSP_GGML_API void wsp_dewsp_quantize_row_q5_1(const block_q5_1 * WSP_GGML_RESTRI
45
47
  WSP_GGML_API void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
46
48
  //WSP_GGML_API void wsp_dewsp_quantize_row_q8_1(const block_q8_1 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
47
49
 
50
+ WSP_GGML_API void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
51
+
48
52
  WSP_GGML_API void wsp_dewsp_quantize_row_q2_K(const block_q2_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
49
53
  WSP_GGML_API void wsp_dewsp_quantize_row_q3_K(const block_q3_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
50
54
  WSP_GGML_API void wsp_dewsp_quantize_row_q4_K(const block_q4_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ WSP_GGML_API size_t wsp_quantize_q5_0(const float * WSP_GGML_RESTRICT src, void
90
94
  WSP_GGML_API size_t wsp_quantize_q5_1(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
91
95
  WSP_GGML_API size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
92
96
 
97
+ WSP_GGML_API size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
98
+
93
99
  WSP_GGML_API void wsp_iq2xs_init_impl(enum wsp_ggml_type type);
94
100
  WSP_GGML_API void wsp_iq2xs_free_impl(enum wsp_ggml_type type);
95
101
  WSP_GGML_API void wsp_iq3xs_init_impl(int grid_size);
Binary file
Binary file