whisper.rn 0.5.0-rc.9 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/cpp/ggml-alloc.c +265 -141
- package/cpp/ggml-backend-impl.h +4 -1
- package/cpp/ggml-backend-reg.cpp +30 -13
- package/cpp/ggml-backend.cpp +221 -38
- package/cpp/ggml-backend.h +17 -1
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/amx/amx.cpp +4 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +32 -2
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
- package/cpp/ggml-cpu/ggml-cpu.c +70 -42
- package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
- package/cpp/ggml-cpu/ops.cpp +1587 -1177
- package/cpp/ggml-cpu/ops.h +5 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +89 -60
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +170 -26
- package/cpp/ggml-cpu/vec.h +506 -63
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +119 -9
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
- package/cpp/ggml-metal/ggml-metal-device.h +226 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +90 -51
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +111 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml.c +486 -98
- package/cpp/ggml.h +221 -16
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +25 -6
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/whisper.cpp +100 -76
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +16 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6284
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h
CHANGED
|
@@ -73,6 +73,35 @@ static inline int wsp_ggml_up(int n, int m) {
|
|
|
73
73
|
return (n + m - 1) & ~(m - 1);
|
|
74
74
|
}
|
|
75
75
|
|
|
76
|
+
// TODO: move to ggml.h? (won't be able to inline)
|
|
77
|
+
static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
|
|
78
|
+
if (a->type != b->type) {
|
|
79
|
+
return false;
|
|
80
|
+
}
|
|
81
|
+
for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
|
|
82
|
+
if (a->ne[i] != b->ne[i]) {
|
|
83
|
+
return false;
|
|
84
|
+
}
|
|
85
|
+
if (a->nb[i] != b->nb[i]) {
|
|
86
|
+
return false;
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
return true;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
|
|
93
|
+
switch (op) {
|
|
94
|
+
case WSP_GGML_OP_NONE:
|
|
95
|
+
case WSP_GGML_OP_RESHAPE:
|
|
96
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
97
|
+
case WSP_GGML_OP_VIEW:
|
|
98
|
+
case WSP_GGML_OP_PERMUTE:
|
|
99
|
+
return true;
|
|
100
|
+
default:
|
|
101
|
+
return false;
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
76
105
|
//
|
|
77
106
|
// logging
|
|
78
107
|
//
|
|
@@ -313,6 +342,10 @@ struct wsp_ggml_cgraph {
|
|
|
313
342
|
// if you need the gradients, get them from the original graph
|
|
314
343
|
struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
|
|
315
344
|
|
|
345
|
+
// ggml-alloc.c: true if the operation can reuse memory from its sources
|
|
346
|
+
WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
|
|
347
|
+
|
|
348
|
+
|
|
316
349
|
// Memory allocation
|
|
317
350
|
|
|
318
351
|
WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
|
|
@@ -394,6 +427,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
|
|
|
394
427
|
#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
|
|
395
428
|
#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
|
|
396
429
|
|
|
430
|
+
// Converts an E8M0 value (8-bit biased exponent, no mantissa) to float32.
//
//   x == 0 : 2^(-127) is below the smallest normalized float, so it is encoded
//            as a denormal: exponent field 0, mantissa 0x400000
//            (0.5 * 2^(-126) = 2^(-127)).
//   x >= 1 : x is placed directly in the float exponent field (bits 30-23) with
//            a zero mantissa, giving 2^(x - 127).
//
// note: NaN (x == 0xFF) is intentionally not special-cased; that input goes
//       through the normalized path and produces the bit pattern 0x7F800000.
static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
    const uint32_t raw = (x == 0) ? 0x00400000u : ((uint32_t) x << 23);

    // copy the bit pattern into a float (avoids pointer type-punning)
    float value;
    memcpy(&value, &raw, sizeof(float));
    return value;
}
|
|
465
|
+
|
|
466
|
+
// Equal to wsp_ggml_e8m0_to_fp32/2
|
|
467
|
+
// Useful with MXFP4 quantization since the E0M2 values are doubled
|
|
468
|
+
// Returns 2^(x - 128), i.e. half of the value encoded by the E8M0 byte x.
//
//   x >= 2 : the result is a normalized float with exponent field (x - 1)
//            and zero mantissa.
//   x <  2 : the result is a denormal; 0x00200000 encodes 2^(-128) and
//            shifting it left by x (0 or 1) gives 2^(-128) or 2^(-127).
//
// note: NaNs are not handled here
static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
    const uint32_t raw = (x >= 2)
        ? ((uint32_t) (x - 1) << 23)
        : (0x00200000u << x);

    float value;
    memcpy(&value, &raw, sizeof(float));
    return value;
}
|
|
487
|
+
|
|
488
|
+
#define WSP_GGML_E8M0_TO_FP32(x) wsp_ggml_e8m0_to_fp32(x)
|
|
489
|
+
#define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
|
|
490
|
+
|
|
397
491
|
/**
|
|
398
492
|
* Converts brain16 to float32.
|
|
399
493
|
*
|
|
@@ -493,27 +587,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgrap
|
|
|
493
587
|
return true;
|
|
494
588
|
}
|
|
495
589
|
|
|
496
|
-
// Returns true if nodes
|
|
590
|
+
// Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
|
|
497
591
|
// and are fusable. Nodes are considered fusable according to this function if:
|
|
498
592
|
// - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
|
|
499
593
|
// - all nodes except the last are a src of the following node.
|
|
500
594
|
// - all nodes are the same shape.
|
|
501
595
|
// TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
|
|
502
|
-
static inline bool
|
|
503
|
-
if (node_idx + num_ops > cgraph->n_nodes) {
|
|
504
|
-
return false;
|
|
505
|
-
}
|
|
506
|
-
|
|
596
|
+
static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
|
|
507
597
|
for (int i = 0; i < num_ops; ++i) {
|
|
508
|
-
|
|
598
|
+
if (node_idxs[i] >= cgraph->n_nodes) {
|
|
599
|
+
return false;
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
|
509
603
|
if (node->op != ops[i]) {
|
|
510
604
|
return false;
|
|
511
605
|
}
|
|
512
|
-
if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph,
|
|
606
|
+
if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
|
|
513
607
|
return false;
|
|
514
608
|
}
|
|
515
609
|
if (i > 0) {
|
|
516
|
-
struct wsp_ggml_tensor * prev = cgraph->nodes[
|
|
610
|
+
struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
|
|
517
611
|
if (node->src[0] != prev && node->src[1] != prev) {
|
|
518
612
|
return false;
|
|
519
613
|
}
|
|
@@ -525,6 +619,22 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
|
|
|
525
619
|
return true;
|
|
526
620
|
}
|
|
527
621
|
|
|
622
|
+
// same as above, for sequential indices starting at node_idx
|
|
623
|
+
static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
|
|
624
|
+
assert(num_ops < 32);
|
|
625
|
+
|
|
626
|
+
if (node_idx + num_ops > cgraph->n_nodes) {
|
|
627
|
+
return false;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
int idxs[32];
|
|
631
|
+
for (int i = 0; i < num_ops; ++i) {
|
|
632
|
+
idxs[i] = node_idx + i;
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
|
636
|
+
}
|
|
637
|
+
|
|
528
638
|
#ifdef __cplusplus
|
|
529
639
|
}
|
|
530
640
|
#endif
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
#ifndef
|
|
2
|
-
#define
|
|
1
|
+
#ifndef WSP_WSP_WSP_GGML_METAL_IMPL
|
|
2
|
+
#define WSP_WSP_WSP_GGML_METAL_IMPL
|
|
3
3
|
|
|
4
4
|
// kernel parameters for mat-vec threadgroups
|
|
5
5
|
//
|
|
@@ -23,6 +23,9 @@
|
|
|
23
23
|
#define N_R0_Q8_0 4
|
|
24
24
|
#define N_SG_Q8_0 2
|
|
25
25
|
|
|
26
|
+
#define N_R0_MXFP4 2
|
|
27
|
+
#define N_SG_MXFP4 2
|
|
28
|
+
|
|
26
29
|
#define N_R0_Q2_K 4
|
|
27
30
|
#define N_SG_Q2_K 2
|
|
28
31
|
|
|
@@ -98,7 +101,7 @@ typedef struct {
|
|
|
98
101
|
uint64_t nb2;
|
|
99
102
|
uint64_t nb3;
|
|
100
103
|
int32_t dim;
|
|
101
|
-
}
|
|
104
|
+
} wsp_wsp_wsp_ggml_metal_kargs_concat;
|
|
102
105
|
|
|
103
106
|
typedef struct {
|
|
104
107
|
int32_t ne00;
|
|
@@ -126,7 +129,17 @@ typedef struct {
|
|
|
126
129
|
uint64_t nb2;
|
|
127
130
|
uint64_t nb3;
|
|
128
131
|
uint64_t offs;
|
|
129
|
-
|
|
132
|
+
uint64_t o1[8];
|
|
133
|
+
} wsp_wsp_wsp_ggml_metal_kargs_bin;
|
|
134
|
+
|
|
135
|
+
typedef struct {
|
|
136
|
+
int64_t ne0;
|
|
137
|
+
int64_t ne1;
|
|
138
|
+
size_t nb01;
|
|
139
|
+
size_t nb02;
|
|
140
|
+
size_t nb11;
|
|
141
|
+
size_t nb21;
|
|
142
|
+
} wsp_wsp_wsp_ggml_metal_kargs_add_id;
|
|
130
143
|
|
|
131
144
|
typedef struct {
|
|
132
145
|
int32_t ne00;
|
|
@@ -145,7 +158,7 @@ typedef struct {
|
|
|
145
158
|
uint64_t nb1;
|
|
146
159
|
uint64_t nb2;
|
|
147
160
|
uint64_t nb3;
|
|
148
|
-
}
|
|
161
|
+
} wsp_wsp_wsp_ggml_metal_kargs_repeat;
|
|
149
162
|
|
|
150
163
|
typedef struct {
|
|
151
164
|
int64_t ne00;
|
|
@@ -164,7 +177,7 @@ typedef struct {
|
|
|
164
177
|
uint64_t nb1;
|
|
165
178
|
uint64_t nb2;
|
|
166
179
|
uint64_t nb3;
|
|
167
|
-
}
|
|
180
|
+
} wsp_wsp_wsp_ggml_metal_kargs_cpy;
|
|
168
181
|
|
|
169
182
|
typedef struct {
|
|
170
183
|
int64_t ne10;
|
|
@@ -179,7 +192,7 @@ typedef struct {
|
|
|
179
192
|
uint64_t nb3;
|
|
180
193
|
uint64_t offs;
|
|
181
194
|
bool inplace;
|
|
182
|
-
}
|
|
195
|
+
} wsp_wsp_wsp_ggml_metal_kargs_set;
|
|
183
196
|
|
|
184
197
|
typedef struct {
|
|
185
198
|
int32_t ne00;
|
|
@@ -211,7 +224,7 @@ typedef struct {
|
|
|
211
224
|
int32_t sect_1;
|
|
212
225
|
int32_t sect_2;
|
|
213
226
|
int32_t sect_3;
|
|
214
|
-
}
|
|
227
|
+
} wsp_wsp_wsp_ggml_metal_kargs_rope;
|
|
215
228
|
|
|
216
229
|
typedef struct {
|
|
217
230
|
int32_t ne01;
|
|
@@ -229,16 +242,20 @@ typedef struct {
|
|
|
229
242
|
uint64_t nb21;
|
|
230
243
|
uint64_t nb22;
|
|
231
244
|
uint64_t nb23;
|
|
245
|
+
int32_t ne32;
|
|
246
|
+
int32_t ne33;
|
|
232
247
|
uint64_t nb31;
|
|
248
|
+
uint64_t nb32;
|
|
249
|
+
uint64_t nb33;
|
|
233
250
|
int32_t ne1;
|
|
234
251
|
int32_t ne2;
|
|
235
252
|
float scale;
|
|
236
253
|
float max_bias;
|
|
237
254
|
float m0;
|
|
238
255
|
float m1;
|
|
239
|
-
|
|
256
|
+
int32_t n_head_log2;
|
|
240
257
|
float logit_softcap;
|
|
241
|
-
}
|
|
258
|
+
} wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
|
|
242
259
|
|
|
243
260
|
typedef struct {
|
|
244
261
|
int32_t ne00;
|
|
@@ -255,7 +272,7 @@ typedef struct {
|
|
|
255
272
|
int32_t ne1;
|
|
256
273
|
int16_t r2;
|
|
257
274
|
int16_t r3;
|
|
258
|
-
}
|
|
275
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
|
|
259
276
|
|
|
260
277
|
typedef struct {
|
|
261
278
|
int32_t ne00;
|
|
@@ -276,7 +293,7 @@ typedef struct {
|
|
|
276
293
|
int32_t ne1;
|
|
277
294
|
int16_t r2;
|
|
278
295
|
int16_t r3;
|
|
279
|
-
}
|
|
296
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
|
|
280
297
|
|
|
281
298
|
typedef struct {
|
|
282
299
|
int32_t ne00;
|
|
@@ -300,7 +317,7 @@ typedef struct {
|
|
|
300
317
|
int16_t nsg;
|
|
301
318
|
int16_t nxpsg;
|
|
302
319
|
int16_t r1ptg;
|
|
303
|
-
}
|
|
320
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
|
|
304
321
|
|
|
305
322
|
typedef struct {
|
|
306
323
|
int32_t ne10;
|
|
@@ -311,7 +328,7 @@ typedef struct {
|
|
|
311
328
|
uint64_t nbh11;
|
|
312
329
|
int32_t ne20; // n_expert_used
|
|
313
330
|
uint64_t nb21;
|
|
314
|
-
}
|
|
331
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
|
|
315
332
|
|
|
316
333
|
typedef struct {
|
|
317
334
|
int32_t ne20; // n_expert_used
|
|
@@ -322,7 +339,7 @@ typedef struct {
|
|
|
322
339
|
int32_t ne0;
|
|
323
340
|
uint64_t nb1;
|
|
324
341
|
uint64_t nb2;
|
|
325
|
-
}
|
|
342
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
|
|
326
343
|
|
|
327
344
|
typedef struct {
|
|
328
345
|
int32_t ne00;
|
|
@@ -339,7 +356,7 @@ typedef struct {
|
|
|
339
356
|
int32_t neh1;
|
|
340
357
|
int16_t r2;
|
|
341
358
|
int16_t r3;
|
|
342
|
-
}
|
|
359
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
|
|
343
360
|
|
|
344
361
|
typedef struct {
|
|
345
362
|
int32_t nei0;
|
|
@@ -361,28 +378,36 @@ typedef struct {
|
|
|
361
378
|
int32_t ne0;
|
|
362
379
|
int32_t ne1;
|
|
363
380
|
uint64_t nb1;
|
|
364
|
-
}
|
|
381
|
+
} wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
|
|
365
382
|
|
|
366
383
|
typedef struct {
|
|
367
384
|
int32_t ne00;
|
|
368
385
|
int32_t ne00_4;
|
|
369
386
|
uint64_t nb01;
|
|
370
387
|
float eps;
|
|
371
|
-
}
|
|
388
|
+
} wsp_wsp_wsp_ggml_metal_kargs_norm;
|
|
372
389
|
|
|
373
390
|
typedef struct {
|
|
374
391
|
int32_t ne00;
|
|
375
392
|
int32_t ne00_4;
|
|
376
|
-
uint64_t
|
|
393
|
+
uint64_t nb1;
|
|
394
|
+
uint64_t nb2;
|
|
395
|
+
uint64_t nb3;
|
|
377
396
|
float eps;
|
|
378
|
-
|
|
397
|
+
int32_t nef1[3];
|
|
398
|
+
int32_t nef2[3];
|
|
399
|
+
int32_t nef3[3];
|
|
400
|
+
uint64_t nbf1[3];
|
|
401
|
+
uint64_t nbf2[3];
|
|
402
|
+
uint64_t nbf3[3];
|
|
403
|
+
} wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
|
|
379
404
|
|
|
380
405
|
typedef struct {
|
|
381
406
|
int32_t ne00;
|
|
382
407
|
int32_t ne00_4;
|
|
383
408
|
uint64_t nb01;
|
|
384
409
|
float eps;
|
|
385
|
-
}
|
|
410
|
+
} wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
|
|
386
411
|
|
|
387
412
|
typedef struct {
|
|
388
413
|
int64_t ne00;
|
|
@@ -393,7 +418,7 @@ typedef struct {
|
|
|
393
418
|
uint64_t nb02;
|
|
394
419
|
int32_t n_groups;
|
|
395
420
|
float eps;
|
|
396
|
-
}
|
|
421
|
+
} wsp_wsp_wsp_ggml_metal_kargs_group_norm;
|
|
397
422
|
|
|
398
423
|
typedef struct {
|
|
399
424
|
int32_t IC;
|
|
@@ -402,7 +427,7 @@ typedef struct {
|
|
|
402
427
|
int32_t s0;
|
|
403
428
|
uint64_t nb0;
|
|
404
429
|
uint64_t nb1;
|
|
405
|
-
}
|
|
430
|
+
} wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
|
|
406
431
|
|
|
407
432
|
typedef struct {
|
|
408
433
|
uint64_t ofs0;
|
|
@@ -420,7 +445,7 @@ typedef struct {
|
|
|
420
445
|
int32_t KH;
|
|
421
446
|
int32_t KW;
|
|
422
447
|
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
|
423
|
-
}
|
|
448
|
+
} wsp_wsp_wsp_ggml_metal_kargs_im2col;
|
|
424
449
|
|
|
425
450
|
typedef struct{
|
|
426
451
|
int32_t ne00;
|
|
@@ -431,7 +456,9 @@ typedef struct{
|
|
|
431
456
|
uint64_t nb1;
|
|
432
457
|
int32_t i00;
|
|
433
458
|
int32_t i10;
|
|
434
|
-
|
|
459
|
+
float alpha;
|
|
460
|
+
float limit;
|
|
461
|
+
} wsp_wsp_wsp_ggml_metal_kargs_glu;
|
|
435
462
|
|
|
436
463
|
typedef struct {
|
|
437
464
|
int64_t ne00;
|
|
@@ -458,24 +485,36 @@ typedef struct {
|
|
|
458
485
|
uint64_t nb1;
|
|
459
486
|
uint64_t nb2;
|
|
460
487
|
uint64_t nb3;
|
|
461
|
-
}
|
|
488
|
+
} wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
|
|
462
489
|
|
|
463
490
|
typedef struct {
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
491
|
+
int32_t ne00;
|
|
492
|
+
int32_t ne01;
|
|
493
|
+
int32_t ne02;
|
|
494
|
+
uint64_t nb01;
|
|
495
|
+
uint64_t nb02;
|
|
496
|
+
uint64_t nb03;
|
|
497
|
+
int32_t ne11;
|
|
498
|
+
int32_t ne12;
|
|
499
|
+
int32_t ne13;
|
|
500
|
+
uint64_t nb11;
|
|
501
|
+
uint64_t nb12;
|
|
502
|
+
uint64_t nb13;
|
|
503
|
+
uint64_t nb1;
|
|
504
|
+
uint64_t nb2;
|
|
505
|
+
uint64_t nb3;
|
|
467
506
|
float scale;
|
|
468
507
|
float max_bias;
|
|
469
508
|
float m0;
|
|
470
509
|
float m1;
|
|
471
|
-
|
|
472
|
-
}
|
|
510
|
+
int32_t n_head_log2;
|
|
511
|
+
} wsp_wsp_wsp_ggml_metal_kargs_soft_max;
|
|
473
512
|
|
|
474
513
|
typedef struct {
|
|
475
514
|
int64_t ne00;
|
|
476
515
|
int64_t ne01;
|
|
477
516
|
int n_past;
|
|
478
|
-
}
|
|
517
|
+
} wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
|
|
479
518
|
|
|
480
519
|
typedef struct {
|
|
481
520
|
int64_t ne00;
|
|
@@ -494,32 +533,32 @@ typedef struct {
|
|
|
494
533
|
uint64_t nb0;
|
|
495
534
|
uint64_t nb1;
|
|
496
535
|
uint64_t nb2;
|
|
497
|
-
}
|
|
536
|
+
} wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
|
|
498
537
|
|
|
499
538
|
typedef struct {
|
|
500
539
|
int64_t d_state;
|
|
501
540
|
int64_t d_inner;
|
|
541
|
+
int64_t n_head;
|
|
542
|
+
int64_t n_group;
|
|
502
543
|
int64_t n_seq_tokens;
|
|
503
544
|
int64_t n_seqs;
|
|
504
|
-
|
|
545
|
+
int64_t s_off;
|
|
505
546
|
uint64_t nb01;
|
|
506
547
|
uint64_t nb02;
|
|
507
|
-
uint64_t
|
|
548
|
+
uint64_t nb03;
|
|
508
549
|
uint64_t nb11;
|
|
509
550
|
uint64_t nb12;
|
|
510
551
|
uint64_t nb13;
|
|
511
|
-
uint64_t nb20;
|
|
512
552
|
uint64_t nb21;
|
|
513
553
|
uint64_t nb22;
|
|
514
|
-
uint64_t nb30;
|
|
515
554
|
uint64_t nb31;
|
|
516
|
-
uint64_t nb40;
|
|
517
555
|
uint64_t nb41;
|
|
518
556
|
uint64_t nb42;
|
|
519
|
-
uint64_t
|
|
557
|
+
uint64_t nb43;
|
|
520
558
|
uint64_t nb51;
|
|
521
559
|
uint64_t nb52;
|
|
522
|
-
|
|
560
|
+
uint64_t nb53;
|
|
561
|
+
} wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
|
|
523
562
|
|
|
524
563
|
typedef struct {
|
|
525
564
|
int64_t ne00;
|
|
@@ -530,7 +569,7 @@ typedef struct {
|
|
|
530
569
|
uint64_t nb11;
|
|
531
570
|
uint64_t nb1;
|
|
532
571
|
uint64_t nb2;
|
|
533
|
-
}
|
|
572
|
+
} wsp_wsp_wsp_ggml_metal_kargs_get_rows;
|
|
534
573
|
|
|
535
574
|
typedef struct {
|
|
536
575
|
int32_t nk0;
|
|
@@ -546,7 +585,7 @@ typedef struct {
|
|
|
546
585
|
uint64_t nb1;
|
|
547
586
|
uint64_t nb2;
|
|
548
587
|
uint64_t nb3;
|
|
549
|
-
}
|
|
588
|
+
} wsp_wsp_wsp_ggml_metal_kargs_set_rows;
|
|
550
589
|
|
|
551
590
|
typedef struct {
|
|
552
591
|
int64_t ne00;
|
|
@@ -569,7 +608,7 @@ typedef struct {
|
|
|
569
608
|
float sf1;
|
|
570
609
|
float sf2;
|
|
571
610
|
float sf3;
|
|
572
|
-
}
|
|
611
|
+
} wsp_wsp_wsp_ggml_metal_kargs_upscale;
|
|
573
612
|
|
|
574
613
|
typedef struct {
|
|
575
614
|
int64_t ne00;
|
|
@@ -588,7 +627,7 @@ typedef struct {
|
|
|
588
627
|
uint64_t nb1;
|
|
589
628
|
uint64_t nb2;
|
|
590
629
|
uint64_t nb3;
|
|
591
|
-
}
|
|
630
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pad;
|
|
592
631
|
|
|
593
632
|
typedef struct {
|
|
594
633
|
int64_t ne00;
|
|
@@ -609,28 +648,28 @@ typedef struct {
|
|
|
609
648
|
uint64_t nb3;
|
|
610
649
|
int32_t p0;
|
|
611
650
|
int32_t p1;
|
|
612
|
-
}
|
|
651
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
|
|
613
652
|
|
|
614
653
|
typedef struct {
|
|
615
654
|
uint64_t nb1;
|
|
616
655
|
int dim;
|
|
617
656
|
int max_period;
|
|
618
|
-
}
|
|
657
|
+
} wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
|
|
619
658
|
|
|
620
659
|
typedef struct {
|
|
621
660
|
float slope;
|
|
622
|
-
}
|
|
661
|
+
} wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
|
|
623
662
|
|
|
624
663
|
typedef struct {
|
|
625
664
|
int64_t ncols;
|
|
626
665
|
int64_t ncols_pad;
|
|
627
|
-
}
|
|
666
|
+
} wsp_wsp_wsp_ggml_metal_kargs_argsort;
|
|
628
667
|
|
|
629
668
|
typedef struct {
|
|
630
669
|
int64_t ne0;
|
|
631
670
|
float start;
|
|
632
671
|
float step;
|
|
633
|
-
}
|
|
672
|
+
} wsp_wsp_wsp_ggml_metal_kargs_arange;
|
|
634
673
|
|
|
635
674
|
typedef struct {
|
|
636
675
|
int32_t k0;
|
|
@@ -644,6 +683,6 @@ typedef struct {
|
|
|
644
683
|
int64_t OH;
|
|
645
684
|
int64_t OW;
|
|
646
685
|
int64_t parallel_elements;
|
|
647
|
-
}
|
|
686
|
+
} wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
|
|
648
687
|
|
|
649
|
-
#endif //
|
|
688
|
+
#endif // WSP_WSP_WSP_GGML_METAL_IMPL
|
|
@@ -39,18 +39,13 @@ extern "C" {
|
|
|
39
39
|
// user-code should use only these functions
|
|
40
40
|
//
|
|
41
41
|
|
|
42
|
+
// TODO: remove in the future
|
|
42
43
|
WSP_GGML_BACKEND_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
|
|
43
44
|
|
|
44
45
|
WSP_GGML_BACKEND_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
|
|
45
46
|
|
|
46
|
-
WSP_GGML_DEPRECATED(
|
|
47
|
-
WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
|
|
48
|
-
"obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
|
|
49
|
-
|
|
50
47
|
WSP_GGML_BACKEND_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
|
|
51
48
|
|
|
52
|
-
WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
|
|
53
|
-
|
|
54
49
|
// helper to check if the device supports a specific family
|
|
55
50
|
// ideally, the user code should be doing these checks
|
|
56
51
|
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h
CHANGED
|
@@ -74,16 +74,26 @@ extern "C" {
|
|
|
74
74
|
WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
|
|
75
75
|
};
|
|
76
76
|
|
|
77
|
+
enum wsp_ggml_opt_optimizer_type {
|
|
78
|
+
WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
|
79
|
+
WSP_GGML_OPT_OPTIMIZER_TYPE_SGD,
|
|
80
|
+
|
|
81
|
+
WSP_GGML_OPT_OPTIMIZER_TYPE_COUNT
|
|
82
|
+
};
|
|
83
|
+
|
|
77
84
|
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
|
78
85
|
struct wsp_ggml_opt_optimizer_params {
|
|
79
|
-
// AdamW optimizer parameters
|
|
80
86
|
struct {
|
|
81
87
|
float alpha; // learning rate
|
|
82
|
-
float beta1;
|
|
83
|
-
float beta2;
|
|
88
|
+
float beta1; // first AdamW momentum
|
|
89
|
+
float beta2; // second AdamW momentum
|
|
84
90
|
float eps; // epsilon for numerical stability
|
|
85
|
-
float wd; // weight decay
|
|
91
|
+
float wd; // weight decay - 0.0f to disable
|
|
86
92
|
} adamw;
|
|
93
|
+
struct {
|
|
94
|
+
float alpha; // learning rate
|
|
95
|
+
float wd; // weight decay
|
|
96
|
+
} sgd;
|
|
87
97
|
};
|
|
88
98
|
|
|
89
99
|
// callback to calculate optimizer parameters prior to a backward pass
|
|
@@ -112,8 +122,11 @@ extern "C" {
|
|
|
112
122
|
|
|
113
123
|
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
|
|
114
124
|
|
|
115
|
-
wsp_ggml_opt_get_optimizer_params get_opt_pars;
|
|
116
|
-
void *
|
|
125
|
+
wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
|
126
|
+
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
|
127
|
+
|
|
128
|
+
// only WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
|
|
129
|
+
enum wsp_ggml_opt_optimizer_type optimizer;
|
|
117
130
|
};
|
|
118
131
|
|
|
119
132
|
// get parameters for an optimization context with defaults set where possible
|
|
@@ -142,6 +155,10 @@ extern "C" {
|
|
|
142
155
|
// get the gradient accumulator for a node from the forward graph
|
|
143
156
|
WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
|
|
144
157
|
|
|
158
|
+
WSP_GGML_API enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t); //TODO consistent naming scheme
|
|
159
|
+
|
|
160
|
+
WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type);
|
|
161
|
+
|
|
145
162
|
// ====== Optimization Result ======
|
|
146
163
|
|
|
147
164
|
WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
|
|
@@ -226,12 +243,14 @@ extern "C" {
|
|
|
226
243
|
struct wsp_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
|
227
244
|
wsp_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
|
228
245
|
enum wsp_ggml_opt_loss_type loss_type, // loss to minimize
|
|
246
|
+
enum wsp_ggml_opt_optimizer_type optimizer, // sgd or adamw
|
|
229
247
|
wsp_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
|
230
248
|
int64_t nepoch, // how many times the dataset should be iterated over
|
|
231
249
|
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
|
|
232
250
|
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
|
|
233
251
|
bool silent); // whether or not info prints to stderr should be suppressed
|
|
234
252
|
|
|
253
|
+
|
|
235
254
|
#ifdef __cplusplus
|
|
236
255
|
}
|
|
237
256
|
#endif
|
|
@@ -21,6 +21,8 @@ WSP_GGML_API void wsp_quantize_row_q5_1_ref(const float * WSP_GGML_RESTRICT x, b
|
|
|
21
21
|
WSP_GGML_API void wsp_quantize_row_q8_0_ref(const float * WSP_GGML_RESTRICT x, block_q8_0 * WSP_GGML_RESTRICT y, int64_t k);
|
|
22
22
|
WSP_GGML_API void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * WSP_GGML_RESTRICT y, int64_t k);
|
|
23
23
|
|
|
24
|
+
WSP_GGML_API void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k);
|
|
25
|
+
|
|
24
26
|
WSP_GGML_API void wsp_quantize_row_q2_K_ref(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int64_t k);
|
|
25
27
|
WSP_GGML_API void wsp_quantize_row_q3_K_ref(const float * WSP_GGML_RESTRICT x, block_q3_K * WSP_GGML_RESTRICT y, int64_t k);
|
|
26
28
|
WSP_GGML_API void wsp_quantize_row_q4_K_ref(const float * WSP_GGML_RESTRICT x, block_q4_K * WSP_GGML_RESTRICT y, int64_t k);
|
|
@@ -45,6 +47,8 @@ WSP_GGML_API void wsp_dewsp_quantize_row_q5_1(const block_q5_1 * WSP_GGML_RESTRI
|
|
|
45
47
|
WSP_GGML_API void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
46
48
|
//WSP_GGML_API void wsp_dewsp_quantize_row_q8_1(const block_q8_1 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
47
49
|
|
|
50
|
+
WSP_GGML_API void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
51
|
+
|
|
48
52
|
WSP_GGML_API void wsp_dewsp_quantize_row_q2_K(const block_q2_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
49
53
|
WSP_GGML_API void wsp_dewsp_quantize_row_q3_K(const block_q3_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
50
54
|
WSP_GGML_API void wsp_dewsp_quantize_row_q4_K(const block_q4_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
|
|
@@ -90,6 +94,8 @@ WSP_GGML_API size_t wsp_quantize_q5_0(const float * WSP_GGML_RESTRICT src, void
|
|
|
90
94
|
WSP_GGML_API size_t wsp_quantize_q5_1(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
|
91
95
|
WSP_GGML_API size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
|
92
96
|
|
|
97
|
+
WSP_GGML_API size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
|
98
|
+
|
|
93
99
|
WSP_GGML_API void wsp_iq2xs_init_impl(enum wsp_ggml_type type);
|
|
94
100
|
WSP_GGML_API void wsp_iq2xs_free_impl(enum wsp_ggml_type type);
|
|
95
101
|
WSP_GGML_API void wsp_iq3xs_init_impl(int grid_size);
|