whisper.rn 0.5.0-rc.9 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/cpp/ggml-alloc.c +265 -141
- package/cpp/ggml-backend-impl.h +4 -1
- package/cpp/ggml-backend-reg.cpp +30 -13
- package/cpp/ggml-backend.cpp +221 -38
- package/cpp/ggml-backend.h +17 -1
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/amx/amx.cpp +4 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +32 -2
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
- package/cpp/ggml-cpu/ggml-cpu.c +70 -42
- package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
- package/cpp/ggml-cpu/ops.cpp +1587 -1177
- package/cpp/ggml-cpu/ops.h +5 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +89 -60
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +170 -26
- package/cpp/ggml-cpu/vec.h +506 -63
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +119 -9
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
- package/cpp/ggml-metal/ggml-metal-device.h +226 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +90 -51
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +111 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml.c +486 -98
- package/cpp/ggml.h +221 -16
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +25 -6
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/whisper.cpp +100 -76
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +16 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6284
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-cpu.h
CHANGED
|
@@ -101,7 +101,6 @@ extern "C" {
|
|
|
101
101
|
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v (void);
|
|
102
102
|
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx (void);
|
|
103
103
|
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe (void);
|
|
104
|
-
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_nnpa (void);
|
|
105
104
|
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
|
|
106
105
|
WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_llamafile (void);
|
|
107
106
|
|
|
@@ -135,6 +134,7 @@ extern "C" {
|
|
|
135
134
|
WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_cpu_reg(void);
|
|
136
135
|
|
|
137
136
|
WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
|
137
|
+
WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
|
|
138
138
|
WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp16(const float *, wsp_ggml_fp16_t *, int64_t);
|
|
139
139
|
WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t *, float *, int64_t);
|
|
140
140
|
WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_bf16(const float *, wsp_ggml_bf16_t *, int64_t);
|
package/cpp/ggml-impl.h
CHANGED
|
@@ -73,6 +73,35 @@ static inline int wsp_ggml_up(int n, int m) {
|
|
|
73
73
|
return (n + m - 1) & ~(m - 1);
|
|
74
74
|
}
|
|
75
75
|
|
|
76
|
+
// TODO: move to ggml.h? (won't be able to inline)
|
|
77
|
+
static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
|
|
78
|
+
if (a->type != b->type) {
|
|
79
|
+
return false;
|
|
80
|
+
}
|
|
81
|
+
for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
|
|
82
|
+
if (a->ne[i] != b->ne[i]) {
|
|
83
|
+
return false;
|
|
84
|
+
}
|
|
85
|
+
if (a->nb[i] != b->nb[i]) {
|
|
86
|
+
return false;
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
return true;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
|
|
93
|
+
switch (op) {
|
|
94
|
+
case WSP_GGML_OP_NONE:
|
|
95
|
+
case WSP_GGML_OP_RESHAPE:
|
|
96
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
97
|
+
case WSP_GGML_OP_VIEW:
|
|
98
|
+
case WSP_GGML_OP_PERMUTE:
|
|
99
|
+
return true;
|
|
100
|
+
default:
|
|
101
|
+
return false;
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
76
105
|
//
|
|
77
106
|
// logging
|
|
78
107
|
//
|
|
@@ -313,6 +342,10 @@ struct wsp_ggml_cgraph {
|
|
|
313
342
|
// if you need the gradients, get them from the original graph
|
|
314
343
|
struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
|
|
315
344
|
|
|
345
|
+
// ggml-alloc.c: true if the operation can reuse memory from its sources
|
|
346
|
+
WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
|
|
347
|
+
|
|
348
|
+
|
|
316
349
|
// Memory allocation
|
|
317
350
|
|
|
318
351
|
WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
|
|
@@ -394,6 +427,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
|
|
|
394
427
|
#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
|
|
395
428
|
#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
|
|
396
429
|
|
|
430
|
+
// Decodes an E8M0 value (an 8-bit biased exponent, no sign and no mantissa bits) to float: 2^(x - 127).
//   - x == 0    -> subnormal 2^-127 (exponent field 0, mantissa 0x400000 = 0.5, so 0.5 * 2^-126)
//   - 1..254    -> normal float with exponent field x and a zero mantissa
//   - x == 0xFF -> not special-cased as NaN (NaNs are not needed); the bit pattern is +infinity
static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
    const uint32_t bits = (x == 0) ? UINT32_C(0x00400000) : ((uint32_t) x << 23);

    // bit-cast through memcpy to avoid strict-aliasing violations
    float value;
    memcpy(&value, &bits, sizeof(float));
    return value;
}
|
|
465
|
+
|
|
466
|
+
// Same value as wsp_ggml_e8m0_to_fp32(x) / 2, i.e. 2^(x - 128), built directly from the bit pattern.
// Intended for MXFP4 quantization, where the element values are stored doubled, so the halved scale
// compensates. NaNs are not handled.
static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
    uint32_t bits;

    if (x >= 2) {
        // 2^(x - 128) is a normal float whose biased exponent field is x - 1 (mantissa 0)
        bits = (uint32_t) (x - 1) << 23;
    } else {
        // subnormal results: x == 0 -> 0x00200000 = 2^-128, x == 1 -> 0x00400000 = 2^-127
        bits = UINT32_C(0x00200000) << x;
    }

    float value;
    memcpy(&value, &bits, sizeof(float));
    return value;
}

#define WSP_GGML_E8M0_TO_FP32(x) wsp_ggml_e8m0_to_fp32(x)
#define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
|
|
490
|
+
|
|
397
491
|
/**
|
|
398
492
|
* Converts brain16 to float32.
|
|
399
493
|
*
|
|
@@ -493,27 +587,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgrap
|
|
|
493
587
|
return true;
|
|
494
588
|
}
|
|
495
589
|
|
|
496
|
-
// Returns true if nodes
|
|
590
|
+
// Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
|
|
497
591
|
// and are fusable. Nodes are considered fusable according to this function if:
|
|
498
592
|
// - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
|
|
499
593
|
// - all nodes except the last are a src of the following node.
|
|
500
594
|
// - all nodes are the same shape.
|
|
501
595
|
// TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
|
|
502
|
-
static inline bool
|
|
503
|
-
if (node_idx + num_ops > cgraph->n_nodes) {
|
|
504
|
-
return false;
|
|
505
|
-
}
|
|
506
|
-
|
|
596
|
+
static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
|
|
507
597
|
for (int i = 0; i < num_ops; ++i) {
|
|
508
|
-
|
|
598
|
+
if (node_idxs[i] >= cgraph->n_nodes) {
|
|
599
|
+
return false;
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
|
509
603
|
if (node->op != ops[i]) {
|
|
510
604
|
return false;
|
|
511
605
|
}
|
|
512
|
-
if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph,
|
|
606
|
+
if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
|
|
513
607
|
return false;
|
|
514
608
|
}
|
|
515
609
|
if (i > 0) {
|
|
516
|
-
struct wsp_ggml_tensor * prev = cgraph->nodes[
|
|
610
|
+
struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
|
|
517
611
|
if (node->src[0] != prev && node->src[1] != prev) {
|
|
518
612
|
return false;
|
|
519
613
|
}
|
|
@@ -525,6 +619,22 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
|
|
|
525
619
|
return true;
|
|
526
620
|
}
|
|
527
621
|
|
|
622
|
+
// same as above, for sequential indices starting at node_idx
|
|
623
|
+
static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
|
|
624
|
+
assert(num_ops < 32);
|
|
625
|
+
|
|
626
|
+
if (node_idx + num_ops > cgraph->n_nodes) {
|
|
627
|
+
return false;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
int idxs[32];
|
|
631
|
+
for (int i = 0; i < num_ops; ++i) {
|
|
632
|
+
idxs[i] = node_idx + i;
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
|
636
|
+
}
|
|
637
|
+
|
|
528
638
|
#ifdef __cplusplus
|
|
529
639
|
}
|
|
530
640
|
#endif
|
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
#include "ggml-metal-common.h"
|
|
2
|
+
|
|
3
|
+
#include "ggml-impl.h"
|
|
4
|
+
#include "ggml-backend-impl.h"
|
|
5
|
+
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
|
|
9
|
+
// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
|
|
10
|
+
struct wsp_ggml_mem_range {
|
|
11
|
+
uint64_t pb; // buffer id
|
|
12
|
+
|
|
13
|
+
uint64_t p0; // begin
|
|
14
|
+
uint64_t p1; // end
|
|
15
|
+
|
|
16
|
+
wsp_ggml_mem_range_type pt;
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
struct wsp_ggml_mem_ranges {
|
|
20
|
+
std::vector<wsp_ggml_mem_range> ranges;
|
|
21
|
+
|
|
22
|
+
int debug = 0;
|
|
23
|
+
};
|
|
24
|
+
|
|
25
|
+
wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug) {
|
|
26
|
+
auto * res = new wsp_ggml_mem_ranges;
|
|
27
|
+
|
|
28
|
+
res->ranges.reserve(256);
|
|
29
|
+
res->debug = debug;
|
|
30
|
+
|
|
31
|
+
return res;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs) {
|
|
35
|
+
delete mrs;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs) {
|
|
39
|
+
mrs->ranges.clear();
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
static bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
|
|
43
|
+
mrs->ranges.push_back(mr);
|
|
44
|
+
|
|
45
|
+
return true;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
// Builds the memory range of `tensor`, always describing its base tensor when `tensor` is a view.
//
// - Allocated base tensor: the real address interval inside its buffer. The length is taken from
//   wsp_ggml_backend_buft_get_alloc_size(), which can be larger than the tensor size when the buffer type
//   allocates extra memory (ref: https://github.com/ggml-org/llama.cpp/pull/15966).
// - Not yet allocated: the tensor's own address serves as a unique buffer id, paired with the dummy
//   interval [0, 1024). Ranges of the same unallocated tensor therefore overlap each other, and never
//   overlap ranges of other tensors.
static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor(const wsp_ggml_tensor * tensor, wsp_ggml_mem_range_type pt) {
    const wsp_ggml_tensor * base = tensor->view_src ? tensor->view_src : tensor;

    // the base tensor must not itself be a view
    WSP_GGML_ASSERT(!base->view_src);

    if (!base->buffer) {
        return wsp_ggml_mem_range {
            /*.pb =*/ (uint64_t) base,
            /*.p0 =*/ 0,
            /*.p1 =*/ 1024, // dummy range, not used
            /*.pt =*/ pt,
        };
    }

    const uint64_t begin = (uint64_t) base->data;
    const uint64_t end   = begin + wsp_ggml_backend_buft_get_alloc_size(base->buffer->buft, base);

    return wsp_ggml_mem_range {
        /*.pb =*/ (uint64_t) base->buffer,
        /*.p0 =*/ begin,
        /*.p1 =*/ end,
        /*.pt =*/ pt,
    };
}
|
|
81
|
+
|
|
82
|
+
// Range that an op reads when it consumes `tensor`.
static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_src(const wsp_ggml_tensor * tensor) {
    const wsp_ggml_mem_range_type kind = MEM_RANGE_TYPE_SRC;
    return wsp_ggml_mem_range_from_tensor(tensor, kind);
}

// Range that an op writes when it produces `tensor`.
static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_dst(const wsp_ggml_tensor * tensor) {
    const wsp_ggml_mem_range_type kind = MEM_RANGE_TYPE_DST;
    return wsp_ggml_mem_range_from_tensor(tensor, kind);
}
|
|
89
|
+
|
|
90
|
+
static bool wsp_ggml_mem_ranges_add_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
|
|
91
|
+
WSP_GGML_ASSERT(tensor);
|
|
92
|
+
|
|
93
|
+
wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
|
|
94
|
+
|
|
95
|
+
if (mrs->debug > 2) {
|
|
96
|
+
WSP_GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
return wsp_ggml_mem_ranges_add(mrs, mr);
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
|
|
103
|
+
WSP_GGML_ASSERT(tensor);
|
|
104
|
+
|
|
105
|
+
wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
|
|
106
|
+
|
|
107
|
+
if (mrs->debug > 2) {
|
|
108
|
+
WSP_GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
return wsp_ggml_mem_ranges_add(mrs, mr);
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
// Records the source ranges of every non-null input of `tensor`, then the destination range of `tensor`.
// Returns false as soon as any range fails to be recorded.
bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
    // src[] must be scanned over all WSP_GGML_MAX_SRC slots (as done elsewhere in this file);
    // bounding it by WSP_GGML_MAX_DIMS would skip the inputs of ops with more sources
    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
        if (tensor->src[i] && !wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i])) {
            return false;
        }
    }

    return wsp_ggml_mem_ranges_add_dst(mrs, tensor);
}
|
|
123
|
+
|
|
124
|
+
// Returns true if `mr` does not conflict with any range already in the set.
//
// Two ranges conflict when all of the following hold:
//   - they are in the same buffer (pb);
//   - at least one of them is a destination range (concurrent reads are fine);
//   - their intervals intersect.
// Note: the intersection test uses `mr.p1 >= cmp.p0`, so intervals that only touch (mr.p1 == cmp.p0) are
// also reported as conflicting. This is conservative - it can only cause extra serialization.
static bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
    for (const auto & cmp : mrs->ranges) {
        // two memory ranges cannot intersect if they are in different buffers
        if (mr.pb != cmp.pb) {
            continue;
        }

        // intersecting source ranges are allowed
        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
            continue;
        }

        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
            if (mrs->debug > 2) {
                // the fields are uint64_t: cast so the varargs match the %lld specifiers
                WSP_GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                        __func__,
                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        (long long) mr.pb, (long long) mr.p0, (long long) mr.p1,
                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        (long long) cmp.pb, (long long) cmp.p0, (long long) cmp.p1);
            }

            return false;
        }
    }

    return true;
}
|
|
154
|
+
|
|
155
|
+
static bool wsp_ggml_mem_ranges_check_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
|
|
156
|
+
WSP_GGML_ASSERT(tensor);
|
|
157
|
+
|
|
158
|
+
wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
|
|
159
|
+
|
|
160
|
+
const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
|
|
161
|
+
|
|
162
|
+
return res;
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
|
|
166
|
+
WSP_GGML_ASSERT(tensor);
|
|
167
|
+
|
|
168
|
+
wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
|
|
169
|
+
|
|
170
|
+
const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
|
|
171
|
+
|
|
172
|
+
return res;
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
// Returns true if `tensor` (reading all of its non-null inputs and writing its own data) does not
// conflict with any range already in the set.
bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
    // scan all WSP_GGML_MAX_SRC source slots; bounding the loop by WSP_GGML_MAX_DIMS would leave
    // higher-numbered inputs unchecked and could let conflicting ops run concurrently
    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
        if (tensor->src[i] && !wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
            return false;
        }
    }

    return wsp_ggml_mem_ranges_check_dst(mrs, tensor);
}
|
|
186
|
+
|
|
187
|
+
struct node_info {
|
|
188
|
+
wsp_ggml_tensor * node;
|
|
189
|
+
|
|
190
|
+
std::vector<wsp_ggml_tensor *> fused;
|
|
191
|
+
|
|
192
|
+
wsp_ggml_op op() const {
|
|
193
|
+
return node->op;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
const wsp_ggml_tensor * dst() const {
|
|
197
|
+
return fused.empty() ? node : fused.back();
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
bool is_empty() const {
|
|
201
|
+
return wsp_ggml_op_is_empty(node->op);
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
void add_fused(wsp_ggml_tensor * t) {
|
|
205
|
+
fused.push_back(t);
|
|
206
|
+
}
|
|
207
|
+
};
|
|
208
|
+
|
|
209
|
+
static std::vector<int> wsp_ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
|
|
210
|
+
// helper to add node src and dst ranges
|
|
211
|
+
const auto & h_add = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
|
|
212
|
+
for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
|
|
213
|
+
if (node.node->src[i]) {
|
|
214
|
+
if (!wsp_ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
|
|
215
|
+
return false;
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
// keep track of the sources of the fused nodes as well
|
|
221
|
+
for (const auto * fused : node.fused) {
|
|
222
|
+
for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
|
|
223
|
+
if (fused->src[i]) {
|
|
224
|
+
if (!wsp_ggml_mem_ranges_add_src(mrs, fused->src[i])) {
|
|
225
|
+
return false;
|
|
226
|
+
}
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
return wsp_ggml_mem_ranges_add_dst(mrs, node.dst());
|
|
232
|
+
};
|
|
233
|
+
|
|
234
|
+
// helper to check if a node can run concurrently with the existing set of nodes
|
|
235
|
+
const auto & h_check = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
|
|
236
|
+
for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
|
|
237
|
+
if (node.node->src[i]) {
|
|
238
|
+
if (!wsp_ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
|
|
239
|
+
return false;
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
for (const auto * fused : node.fused) {
|
|
245
|
+
for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
|
|
246
|
+
if (fused->src[i]) {
|
|
247
|
+
if (!wsp_ggml_mem_ranges_check_src(mrs, fused->src[i])) {
|
|
248
|
+
return false;
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
return wsp_ggml_mem_ranges_check_dst(mrs, node.dst());
|
|
255
|
+
};
|
|
256
|
+
|
|
257
|
+
// perform reorders only across these types of ops
|
|
258
|
+
// can be expanded when needed
|
|
259
|
+
const auto & h_safe = [](wsp_ggml_op op) {
|
|
260
|
+
switch (op) {
|
|
261
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
262
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
263
|
+
case WSP_GGML_OP_ROPE:
|
|
264
|
+
case WSP_GGML_OP_NORM:
|
|
265
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
266
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
267
|
+
case WSP_GGML_OP_SUM_ROWS:
|
|
268
|
+
case WSP_GGML_OP_MUL:
|
|
269
|
+
case WSP_GGML_OP_ADD:
|
|
270
|
+
case WSP_GGML_OP_DIV:
|
|
271
|
+
case WSP_GGML_OP_GLU:
|
|
272
|
+
case WSP_GGML_OP_SCALE:
|
|
273
|
+
case WSP_GGML_OP_GET_ROWS:
|
|
274
|
+
case WSP_GGML_OP_CPY:
|
|
275
|
+
case WSP_GGML_OP_SET_ROWS:
|
|
276
|
+
return true;
|
|
277
|
+
default:
|
|
278
|
+
return wsp_ggml_op_is_empty(op);
|
|
279
|
+
}
|
|
280
|
+
};
|
|
281
|
+
|
|
282
|
+
const int n = nodes.size();
|
|
283
|
+
|
|
284
|
+
std::vector<int> res;
|
|
285
|
+
res.reserve(n);
|
|
286
|
+
|
|
287
|
+
std::vector<bool> used(n, false);
|
|
288
|
+
|
|
289
|
+
// the memory ranges for the set of currently concurrent nodes
|
|
290
|
+
wsp_ggml_mem_ranges_t mrs0 = wsp_ggml_mem_ranges_init(0);
|
|
291
|
+
|
|
292
|
+
// the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
|
|
293
|
+
wsp_ggml_mem_ranges_t mrs1 = wsp_ggml_mem_ranges_init(0);
|
|
294
|
+
|
|
295
|
+
for (int i0 = 0; i0 < n; i0++) {
|
|
296
|
+
if (used[i0]) {
|
|
297
|
+
continue;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
const auto & node0 = nodes[i0];
|
|
301
|
+
|
|
302
|
+
// the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
|
|
303
|
+
// but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
|
|
304
|
+
//
|
|
305
|
+
// note: we can always add empty nodes to the concurrent set as they don't read nor write anything
|
|
306
|
+
if (!node0.is_empty() && !h_check(mrs0, node0)) {
|
|
307
|
+
// this will hold the set of memory ranges from the nodes that haven't been processed yet
|
|
308
|
+
// if a node is not concurrent with this set, we cannot reorder it
|
|
309
|
+
wsp_ggml_mem_ranges_reset(mrs1);
|
|
310
|
+
|
|
311
|
+
// initialize it with the current node
|
|
312
|
+
h_add(mrs1, node0);
|
|
313
|
+
|
|
314
|
+
// that many nodes forward to search for a concurrent node
|
|
315
|
+
constexpr int N_FORWARD = 8;
|
|
316
|
+
|
|
317
|
+
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
|
318
|
+
if (used[i1]) {
|
|
319
|
+
continue;
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
const auto & node1 = nodes[i1];
|
|
323
|
+
|
|
324
|
+
// disallow reordering of certain ops
|
|
325
|
+
if (!h_safe(node1.op())) {
|
|
326
|
+
break;
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
const bool is_empty = node1.is_empty();
|
|
330
|
+
|
|
331
|
+
// to reorder a node and add it to the concurrent set, it has to be:
|
|
332
|
+
// + empty or concurrent with all nodes in the existing concurrent set (mrs0)
|
|
333
|
+
// + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
|
|
334
|
+
if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
|
|
335
|
+
// add the node to the existing concurrent set (i.e. reorder it for early execution)
|
|
336
|
+
h_add(mrs0, node1);
|
|
337
|
+
res.push_back(i1);
|
|
338
|
+
|
|
339
|
+
// mark as used, so we skip re-processing it later
|
|
340
|
+
used[i1] = true;
|
|
341
|
+
} else {
|
|
342
|
+
// expand the set of nodes that haven't been processed yet
|
|
343
|
+
h_add(mrs1, node1);
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
// finalize the concurrent set and begin a new one
|
|
348
|
+
wsp_ggml_mem_ranges_reset(mrs0);
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
// expand the concurrent set with the current node
|
|
352
|
+
{
|
|
353
|
+
h_add(mrs0, node0);
|
|
354
|
+
res.push_back(i0);
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
wsp_ggml_mem_ranges_free(mrs0);
|
|
359
|
+
wsp_ggml_mem_ranges_free(mrs1);
|
|
360
|
+
|
|
361
|
+
return res;
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
// Reorder gf->nodes in place to improve concurrency, without breaking up fusable op chains.
//
// The pass has three stages:
//   1. pack:    walk the graph and collapse every fusable chain into a single node_info, so that the
//               reorder moves the whole chain as one unit. A chain starts with ADD/NORM/RMS_NORM, may
//               continue with ADD/MUL/NORM/RMS_NORM, spans at most MAX_FUSE nodes, and is trimmed from
//               the end until wsp_ggml_can_fuse() accepts it (a chain of length 1 is not fused).
//   2. reorder: compute a new order of the packed entries via wsp_ggml_metal_graph_optimize_reorder().
//   3. unpack:  write the nodes back into gf->nodes following that order; each packed entry is emitted
//               as its head tensor followed by the tensors that were fused into it.
//
// gf->n_nodes itself is not modified; only the contents of gf->nodes are rewritten.
// NOTE(review): correctness of stage 3 assumes the reorder yields each packed index exactly once.
void wsp_ggml_graph_optimize(wsp_ggml_cgraph * gf) {
    constexpr int MAX_FUSE = 16;

    const int n_nodes = gf->n_nodes;

    // ops allowed to begin a fused chain (can be expanded when needed)
    const auto is_fuse_head = [](enum wsp_ggml_op op) {
        return op == WSP_GGML_OP_ADD  ||
               op == WSP_GGML_OP_NORM ||
               op == WSP_GGML_OP_RMS_NORM;
    };

    // ops conservatively allowed to continue a fused chain (can be expanded when needed)
    const auto is_fuse_tail = [](enum wsp_ggml_op op) {
        return op == WSP_GGML_OP_ADD  ||
               op == WSP_GGML_OP_MUL  ||
               op == WSP_GGML_OP_NORM ||
               op == WSP_GGML_OP_RMS_NORM;
    };

    std::vector<node_info> packed;
    packed.reserve(n_nodes);

    // stage 1: pack fusable chains so the reorder cannot split them
    int cur = 0;
    while (cur < n_nodes) {
        node_info info = {
            /*.node  =*/ gf->nodes[cur],
            /*.fused =*/ {},
        };

        // number of graph nodes consumed by this packed entry (1 == no fusion)
        int n_fuse = 1;

        if (is_fuse_head(info.op())) {
            enum wsp_ggml_op ops[MAX_FUSE];
            ops[0] = info.op();

            // gather the longest run of candidate ops following the head
            int len = 1;
            while (cur + len < n_nodes && len < MAX_FUSE && is_fuse_tail(gf->nodes[cur + len]->op)) {
                ops[len] = gf->nodes[cur + len]->op;
                len++;
            }

            // trim the run from the end until the fusion check accepts it (or only the head is left)
            while (len > 1 && !wsp_ggml_can_fuse(gf, cur, ops, len)) {
                len--;
            }

            n_fuse = len;

            // remember the fused tensors so they can be unpacked later; the .dst() becomes the last one
            for (int k = 1; k < n_fuse; k++) {
                info.add_fused(gf->nodes[cur + k]);
            }
        }

        packed.push_back(std::move(info));
        cur += n_fuse;
    }

    // stage 2: compute the new order of the packed entries
#if 1
    const auto order = wsp_ggml_metal_graph_optimize_reorder(packed);
#else
    // identity order - flip the toggle above to disable the reorder (e.g. when debugging)
    std::vector<int> order(packed.size());
    for (size_t idx = 0; idx < packed.size(); idx++) {
        order[idx] = idx;
    }
#endif

    // stage 3: unpack back into the graph, each head followed by its fused tensors
    int out = 0;
    for (const auto idx : order) {
        const node_info & entry = packed[idx];

        gf->nodes[out++] = entry.node;

        for (auto * t : entry.fused) {
            gf->nodes[out++] = t;
        }
    }
}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
// helper functions for ggml-metal that are too difficult to implement in Objective-C
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <stdbool.h>
|
|
6
|
+
|
|
7
|
+
#ifdef __cplusplus
extern "C" {
#endif

struct wsp_ggml_cgraph;
struct wsp_ggml_tensor;

// role of a tracked memory range: a range a task reads from (SRC) or a range it writes to (DST)
enum wsp_ggml_mem_range_type {
    MEM_RANGE_TYPE_SRC = 0,
    MEM_RANGE_TYPE_DST = 1,
};

// wsp_ggml_mem_ranges_t - helper for reordering operations to improve concurrency
//
// Tasks (ggml ops or anything else) may run concurrently as long as none of them writes to memory
// that another task in the same set reads from or writes to.
//
// The object records the memory ranges of the tasks added to the set. Before adding a new task,
// the caller can check whether doing so would violate those constraints, i.e. whether the task can
// execute concurrently with the tasks already in the set.
typedef struct wsp_ggml_mem_ranges * wsp_ggml_mem_ranges_t;

// create a new range set; release it with wsp_ggml_mem_ranges_free()
// NOTE(review): the effect of `debug` is not visible here - presumably enables diagnostic output, confirm
wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug);
void                  wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs);

// drop every tracked range, leaving the set empty but reusable
void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs);

// start tracking the src and dst ranges of `tensor`
// NOTE(review): the meaning of the bool result is not documented in this header - confirm against the implementation
bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);

// check whether `tensor` can join the set without a conflict; returns false when:
//   - one of its src ranges overlaps an existing dst range, or
//   - its dst range overlaps any existing range (src or dst)
bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);

// reorder the nodes of `gf` in place to improve concurrency, while keeping fusable op sequences together
//
// note: the implementation is generic and not tied to metal; if it proves to work well,
//       other backends may start using it in the future
void wsp_ggml_graph_optimize(struct wsp_ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
|