whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
|
|
48
48
|
int mtl_device_ref_count;
|
49
49
|
id<MTLLibrary> mtl_library;
|
50
50
|
|
51
|
+
NSLock * mtl_lock;
|
52
|
+
|
51
53
|
bool has_simdgroup_reduction;
|
52
54
|
bool has_simdgroup_mm;
|
53
55
|
bool has_residency_sets;
|
54
56
|
bool has_bfloat;
|
55
57
|
bool use_bfloat;
|
56
58
|
|
59
|
+
size_t max_size;
|
60
|
+
|
57
61
|
char name[128];
|
58
62
|
} g_ggml_ctx_dev_main = {
|
59
63
|
/*.mtl_device =*/ nil,
|
60
64
|
/*.mtl_device_ref_count =*/ 0,
|
61
65
|
/*.mtl_library =*/ nil,
|
66
|
+
/*.mtl_lock =*/ nil,
|
62
67
|
/*.has_simdgroup_reduction =*/ false,
|
63
68
|
/*.has_simdgroup_mm =*/ false,
|
64
69
|
/*.has_residency_sets =*/ false,
|
65
70
|
/*.has_bfloat =*/ false,
|
66
71
|
/*.use_bfloat =*/ false,
|
72
|
+
/*.max_size =*/ 0,
|
67
73
|
/*.name =*/ "",
|
68
74
|
};
|
69
75
|
|
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
|
|
71
77
|
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
72
78
|
assert(ctx != NULL);
|
73
79
|
|
80
|
+
if (ctx->mtl_lock == nil) {
|
81
|
+
ctx->mtl_lock = [[NSLock alloc] init];
|
82
|
+
}
|
83
|
+
|
74
84
|
if (ctx->mtl_device == nil) {
|
75
85
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
76
86
|
}
|
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
94
104
|
ctx->use_bfloat = false;
|
95
105
|
#endif
|
96
106
|
|
107
|
+
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
108
|
+
|
97
109
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
98
110
|
}
|
99
111
|
|
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
110
122
|
ctx->mtl_device_ref_count--;
|
111
123
|
|
112
124
|
if (ctx->mtl_device_ref_count == 0) {
|
125
|
+
if (ctx->mtl_lock) {
|
126
|
+
[ctx->mtl_lock release];
|
127
|
+
ctx->mtl_lock = nil;
|
128
|
+
}
|
129
|
+
|
113
130
|
if (ctx->mtl_library) {
|
114
131
|
[ctx->mtl_library release];
|
115
132
|
ctx->mtl_library = nil;
|
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
|
|
185
202
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
186
203
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
187
204
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
205
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
206
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
207
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
208
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
209
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
210
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
211
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
212
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
213
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
188
214
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
189
215
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
190
216
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
|
|
194
220
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
195
221
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
196
222
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
223
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
197
224
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
225
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
198
226
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
199
227
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
200
228
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
201
229
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
230
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
202
231
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
203
232
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
204
233
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
@@ -497,7 +526,11 @@ enum ggml_metal_kernel_type {
|
|
497
526
|
GGML_METAL_KERNEL_TYPE_SIN,
|
498
527
|
GGML_METAL_KERNEL_TYPE_COS,
|
499
528
|
GGML_METAL_KERNEL_TYPE_NEG,
|
529
|
+
GGML_METAL_KERNEL_TYPE_REGLU,
|
530
|
+
GGML_METAL_KERNEL_TYPE_GEGLU,
|
531
|
+
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
500
532
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
533
|
+
GGML_METAL_KERNEL_TYPE_MEAN,
|
501
534
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
502
535
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
503
536
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
@@ -976,7 +1009,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
976
1009
|
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
977
1010
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
978
1011
|
|
979
|
-
id<MTLDevice> device =
|
1012
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
980
1013
|
|
981
1014
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
982
1015
|
|
@@ -990,9 +1023,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
990
1023
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
991
1024
|
|
992
1025
|
// load library
|
993
|
-
|
994
|
-
ctx_dev->
|
1026
|
+
{
|
1027
|
+
[ctx_dev->mtl_lock lock];
|
1028
|
+
|
1029
|
+
if (ctx_dev->mtl_library == nil) {
|
1030
|
+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
1031
|
+
}
|
1032
|
+
|
1033
|
+
[ctx_dev->mtl_lock unlock];
|
995
1034
|
}
|
1035
|
+
|
996
1036
|
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
997
1037
|
if (metal_library == nil) {
|
998
1038
|
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
@@ -1141,6 +1181,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
1141
1181
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
1142
1182
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
1143
1183
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
1190
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
1191
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
1192
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
1144
1193
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
1145
1194
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
1146
1195
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
@@ -1150,11 +1199,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
1150
1199
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
1151
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
1152
1201
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
1202
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
1153
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
1204
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
1154
1205
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
1155
1206
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
1156
1207
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
1157
1208
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
1209
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
1158
1210
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
1159
1211
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
1160
1212
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
@@ -1453,7 +1505,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
1453
1505
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
1454
1506
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
1455
1507
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
1508
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
1509
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
1510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
1456
1511
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
1512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
1457
1513
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
1458
1514
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
1459
1515
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
@@ -1603,6 +1659,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1603
1659
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
1604
1660
|
|
1605
1661
|
if (!use_bfloat) {
|
1662
|
+
if (op->type == GGML_TYPE_BF16) {
|
1663
|
+
return false;
|
1664
|
+
}
|
1665
|
+
|
1606
1666
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
1607
1667
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
1608
1668
|
return false;
|
@@ -1626,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1626
1686
|
default:
|
1627
1687
|
return false;
|
1628
1688
|
}
|
1689
|
+
case GGML_OP_GLU:
|
1690
|
+
switch (ggml_get_glu_op(op)) {
|
1691
|
+
case GGML_GLU_OP_REGLU:
|
1692
|
+
case GGML_GLU_OP_GEGLU:
|
1693
|
+
case GGML_GLU_OP_SWIGLU:
|
1694
|
+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
1695
|
+
default:
|
1696
|
+
return false;
|
1697
|
+
}
|
1629
1698
|
case GGML_OP_NONE:
|
1630
1699
|
case GGML_OP_RESHAPE:
|
1631
1700
|
case GGML_OP_VIEW:
|
@@ -1653,6 +1722,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1653
1722
|
case GGML_OP_LOG:
|
1654
1723
|
return false; // TODO: implement
|
1655
1724
|
case GGML_OP_SUM_ROWS:
|
1725
|
+
case GGML_OP_MEAN:
|
1656
1726
|
case GGML_OP_SOFT_MAX:
|
1657
1727
|
case GGML_OP_GROUP_NORM:
|
1658
1728
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
@@ -1771,6 +1841,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1771
1841
|
{
|
1772
1842
|
return op->ne[3] == 1;
|
1773
1843
|
}
|
1844
|
+
case GGML_OP_SET_ROWS:
|
1845
|
+
{
|
1846
|
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
1847
|
+
return false;
|
1848
|
+
}
|
1849
|
+
|
1850
|
+
switch (op->type) {
|
1851
|
+
case GGML_TYPE_F32:
|
1852
|
+
case GGML_TYPE_F16:
|
1853
|
+
case GGML_TYPE_BF16:
|
1854
|
+
case GGML_TYPE_Q8_0:
|
1855
|
+
case GGML_TYPE_Q4_0:
|
1856
|
+
case GGML_TYPE_Q4_1:
|
1857
|
+
case GGML_TYPE_Q5_0:
|
1858
|
+
case GGML_TYPE_Q5_1:
|
1859
|
+
case GGML_TYPE_IQ4_NL:
|
1860
|
+
return true;
|
1861
|
+
default:
|
1862
|
+
return false;
|
1863
|
+
};
|
1864
|
+
}
|
1774
1865
|
default:
|
1775
1866
|
return false;
|
1776
1867
|
}
|
@@ -2343,6 +2434,62 @@ static bool ggml_metal_encode_node(
|
|
2343
2434
|
GGML_ABORT("fatal error");
|
2344
2435
|
}
|
2345
2436
|
} break;
|
2437
|
+
case GGML_OP_GLU:
|
2438
|
+
{
|
2439
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2440
|
+
|
2441
|
+
if (src1) {
|
2442
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
2443
|
+
}
|
2444
|
+
|
2445
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2446
|
+
|
2447
|
+
switch (ggml_get_glu_op(node)) {
|
2448
|
+
case GGML_GLU_OP_REGLU:
|
2449
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
2450
|
+
break;
|
2451
|
+
case GGML_GLU_OP_GEGLU:
|
2452
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
2453
|
+
break;
|
2454
|
+
case GGML_GLU_OP_SWIGLU:
|
2455
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
2456
|
+
break;
|
2457
|
+
default:
|
2458
|
+
GGML_ABORT("fatal error");
|
2459
|
+
}
|
2460
|
+
|
2461
|
+
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
2462
|
+
|
2463
|
+
const int32_t i00 = swp ? ne0 : 0;
|
2464
|
+
const int32_t i10 = swp ? 0 : ne0;
|
2465
|
+
|
2466
|
+
ggml_metal_kargs_glu args = {
|
2467
|
+
/*.ne00 =*/ ne00,
|
2468
|
+
/*.nb01 =*/ nb01,
|
2469
|
+
/*.ne10 =*/ src1 ? ne10 : ne00,
|
2470
|
+
/*.nb11 =*/ src1 ? nb11 : nb01,
|
2471
|
+
/*.ne0 =*/ ne0,
|
2472
|
+
/*.nb1 =*/ nb1,
|
2473
|
+
/*.i00 =*/ src1 ? 0 : i00,
|
2474
|
+
/*.i10 =*/ src1 ? 0 : i10,
|
2475
|
+
};
|
2476
|
+
|
2477
|
+
[encoder setComputePipelineState:pipeline];
|
2478
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2479
|
+
if (src1) {
|
2480
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2481
|
+
} else {
|
2482
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2483
|
+
}
|
2484
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2485
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
2486
|
+
|
2487
|
+
const int64_t nrows = ggml_nrows(src0);
|
2488
|
+
|
2489
|
+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
2490
|
+
|
2491
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2492
|
+
} break;
|
2346
2493
|
case GGML_OP_SQR:
|
2347
2494
|
{
|
2348
2495
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
@@ -2400,11 +2547,31 @@ static bool ggml_metal_encode_node(
|
|
2400
2547
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2401
2548
|
} break;
|
2402
2549
|
case GGML_OP_SUM_ROWS:
|
2550
|
+
case GGML_OP_MEAN:
|
2403
2551
|
{
|
2404
2552
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
2405
2553
|
|
2406
|
-
id<MTLComputePipelineState> pipeline =
|
2554
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2555
|
+
|
2556
|
+
switch (dst->op) {
|
2557
|
+
case GGML_OP_SUM_ROWS:
|
2558
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
2559
|
+
break;
|
2560
|
+
case GGML_OP_MEAN:
|
2561
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
2562
|
+
break;
|
2563
|
+
default:
|
2564
|
+
GGML_ABORT("fatal error");
|
2565
|
+
}
|
2407
2566
|
|
2567
|
+
int nth = 32; // SIMD width
|
2568
|
+
|
2569
|
+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
2570
|
+
nth *= 2;
|
2571
|
+
}
|
2572
|
+
|
2573
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
2574
|
+
nth = MIN(nth, ne00);
|
2408
2575
|
|
2409
2576
|
ggml_metal_kargs_sum_rows args = {
|
2410
2577
|
/*.ne00 =*/ ne00,
|
@@ -2434,11 +2601,12 @@ static bool ggml_metal_encode_node(
|
|
2434
2601
|
};
|
2435
2602
|
|
2436
2603
|
[encoder setComputePipelineState:pipeline];
|
2437
|
-
[encoder
|
2438
|
-
[encoder setBuffer:
|
2439
|
-
[encoder
|
2604
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
2605
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2606
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2607
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2440
2608
|
|
2441
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
2609
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2442
2610
|
} break;
|
2443
2611
|
case GGML_OP_SOFT_MAX:
|
2444
2612
|
{
|
@@ -3063,14 +3231,23 @@ static bool ggml_metal_encode_node(
|
|
3063
3231
|
nsg = 1;
|
3064
3232
|
nr0 = 1;
|
3065
3233
|
nr1 = 4;
|
3066
|
-
|
3234
|
+
if (ne00 == 4) {
|
3235
|
+
nr0 = 32;
|
3236
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
3237
|
+
} else {
|
3238
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
3239
|
+
}
|
3067
3240
|
} break;
|
3068
3241
|
case GGML_TYPE_F16:
|
3069
3242
|
{
|
3070
3243
|
nsg = 1;
|
3071
3244
|
nr0 = 1;
|
3072
3245
|
if (src1t == GGML_TYPE_F32) {
|
3073
|
-
if (
|
3246
|
+
if (ne00 == 4) {
|
3247
|
+
nr0 = 32;
|
3248
|
+
nr1 = 4;
|
3249
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
3250
|
+
} else if (ne11 * ne12 < 4) {
|
3074
3251
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
3075
3252
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
3076
3253
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
@@ -3089,7 +3266,11 @@ static bool ggml_metal_encode_node(
|
|
3089
3266
|
nsg = 1;
|
3090
3267
|
nr0 = 1;
|
3091
3268
|
if (src1t == GGML_TYPE_F32) {
|
3092
|
-
if (
|
3269
|
+
if (ne00 == 4) {
|
3270
|
+
nr0 = 32;
|
3271
|
+
nr1 = 4;
|
3272
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
3273
|
+
} else if (ne11 * ne12 < 4) {
|
3093
3274
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
3094
3275
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
3095
3276
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
@@ -3710,13 +3891,74 @@ static bool ggml_metal_encode_node(
|
|
3710
3891
|
};
|
3711
3892
|
|
3712
3893
|
[encoder setComputePipelineState:pipeline];
|
3713
|
-
[encoder
|
3714
|
-
[encoder setBuffer:
|
3715
|
-
[encoder setBuffer:
|
3716
|
-
[encoder
|
3894
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3895
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3896
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
3897
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
3717
3898
|
|
3718
3899
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
3719
3900
|
} break;
|
3901
|
+
case GGML_OP_SET_ROWS:
|
3902
|
+
{
|
3903
|
+
id<MTLComputePipelineState> pipeline = nil;
|
3904
|
+
|
3905
|
+
switch (dst->type) {
|
3906
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
3907
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
3908
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
3909
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
3910
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
3911
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
3912
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
3913
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
3914
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
3915
|
+
default: GGML_ABORT("not implemented");
|
3916
|
+
}
|
3917
|
+
|
3918
|
+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
3919
|
+
|
3920
|
+
int nth = 32; // SIMD width
|
3921
|
+
|
3922
|
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
3923
|
+
nth *= 2;
|
3924
|
+
}
|
3925
|
+
|
3926
|
+
int nrptg = 1;
|
3927
|
+
if (nth > nk0) {
|
3928
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
3929
|
+
nth = nk0;
|
3930
|
+
|
3931
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
3932
|
+
nrptg--;
|
3933
|
+
}
|
3934
|
+
}
|
3935
|
+
|
3936
|
+
nth = MIN(nth, nk0);
|
3937
|
+
|
3938
|
+
ggml_metal_kargs_set_rows args = {
|
3939
|
+
/*.nk0 =*/ nk0,
|
3940
|
+
/*.ne01 =*/ ne01,
|
3941
|
+
/*.nb01 =*/ nb01,
|
3942
|
+
/*.nb02 =*/ nb02,
|
3943
|
+
/*.nb03 =*/ nb03,
|
3944
|
+
/*.ne11 =*/ ne11,
|
3945
|
+
/*.ne12 =*/ ne12,
|
3946
|
+
/*.nb10 =*/ nb10,
|
3947
|
+
/*.nb11 =*/ nb11,
|
3948
|
+
/*.nb12 =*/ nb12,
|
3949
|
+
/*.nb1 =*/ nb1,
|
3950
|
+
/*.nb2 =*/ nb2,
|
3951
|
+
/*.nb3 =*/ nb3,
|
3952
|
+
};
|
3953
|
+
|
3954
|
+
[encoder setComputePipelineState:pipeline];
|
3955
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3956
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3957
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
3958
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
3959
|
+
|
3960
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
3961
|
+
} break;
|
3720
3962
|
case GGML_OP_RMS_NORM:
|
3721
3963
|
{
|
3722
3964
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -3733,6 +3975,7 @@ static bool ggml_metal_encode_node(
|
|
3733
3975
|
nth *= 2;
|
3734
3976
|
}
|
3735
3977
|
|
3978
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
3736
3979
|
nth = MIN(nth, ne00/4);
|
3737
3980
|
|
3738
3981
|
ggml_metal_kargs_rms_norm args = {
|
@@ -3769,6 +4012,7 @@ static bool ggml_metal_encode_node(
|
|
3769
4012
|
nth *= 2;
|
3770
4013
|
}
|
3771
4014
|
|
4015
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
3772
4016
|
nth = MIN(nth, ne00/4);
|
3773
4017
|
|
3774
4018
|
ggml_metal_kargs_l2_norm args = {
|
@@ -3841,6 +4085,7 @@ static bool ggml_metal_encode_node(
|
|
3841
4085
|
nth *= 2;
|
3842
4086
|
}
|
3843
4087
|
|
4088
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
3844
4089
|
nth = MIN(nth, ne00/4);
|
3845
4090
|
|
3846
4091
|
ggml_metal_kargs_norm args = {
|
@@ -4766,6 +5011,8 @@ static bool ggml_metal_encode_node(
|
|
4766
5011
|
GGML_ASSERT(nqptg % 8 == 0);
|
4767
5012
|
GGML_ASSERT(ncpsg % 32 == 0);
|
4768
5013
|
|
5014
|
+
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
|
5015
|
+
|
4769
5016
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
4770
5017
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
4771
5018
|
//
|
@@ -4773,7 +5020,7 @@ static bool ggml_metal_encode_node(
|
|
4773
5020
|
// the shared memory needed for the simdgroups to load the KV cache
|
4774
5021
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
4775
5022
|
//
|
4776
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
5023
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
|
4777
5024
|
|
4778
5025
|
int64_t nsgmax = 2;
|
4779
5026
|
|
@@ -4810,9 +5057,9 @@ static bool ggml_metal_encode_node(
|
|
4810
5057
|
// and store the soft_max values and the mask
|
4811
5058
|
//
|
4812
5059
|
// ne00*(nsg)
|
4813
|
-
// each simdgroup has a full
|
5060
|
+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
4814
5061
|
//
|
4815
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
5062
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
4816
5063
|
|
4817
5064
|
int64_t nsgmax = 2;
|
4818
5065
|
while (true) {
|
@@ -4925,8 +5172,39 @@ static bool ggml_metal_encode_node(
|
|
4925
5172
|
default: GGML_ABORT("not implemented");
|
4926
5173
|
}
|
4927
5174
|
|
5175
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
5176
|
+
|
5177
|
+
// TODO: support
|
5178
|
+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
5179
|
+
const int32_t nk00 = ne00;
|
5180
|
+
|
5181
|
+
int nth = 32; // SIMD width
|
5182
|
+
|
5183
|
+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
5184
|
+
nth *= 2;
|
5185
|
+
}
|
5186
|
+
|
5187
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
5188
|
+
|
5189
|
+
// when rows are small, we can batch them together in a single threadgroup
|
5190
|
+
int nrptg = 1;
|
5191
|
+
|
5192
|
+
// TODO: relax this constraint in the future
|
5193
|
+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
5194
|
+
if (nth > nk00) {
|
5195
|
+
nrptg = (nth + nk00 - 1)/nk00;
|
5196
|
+
nth = nk00;
|
5197
|
+
|
5198
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
5199
|
+
nrptg--;
|
5200
|
+
}
|
5201
|
+
}
|
5202
|
+
}
|
5203
|
+
|
5204
|
+
nth = MIN(nth, nk00);
|
5205
|
+
|
4928
5206
|
ggml_metal_kargs_cpy args = {
|
4929
|
-
/*.ne00 =*/
|
5207
|
+
/*.ne00 =*/ nk00,
|
4930
5208
|
/*.ne01 =*/ ne01,
|
4931
5209
|
/*.ne02 =*/ ne02,
|
4932
5210
|
/*.ne03 =*/ ne03,
|
@@ -4949,11 +5227,7 @@ static bool ggml_metal_encode_node(
|
|
4949
5227
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
4950
5228
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
4951
5229
|
|
4952
|
-
|
4953
|
-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
4954
|
-
|
4955
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
4956
|
-
|
5230
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
4957
5231
|
} break;
|
4958
5232
|
case GGML_OP_SET:
|
4959
5233
|
{
|
@@ -5259,7 +5533,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
5259
5533
|
}
|
5260
5534
|
|
5261
5535
|
ggml_backend_metal_buffer_rset_free(ctx);
|
5262
|
-
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
5263
5536
|
|
5264
5537
|
if (ctx->owned) {
|
5265
5538
|
#if TARGET_OS_OSX
|
@@ -5368,7 +5641,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
5368
5641
|
}
|
5369
5642
|
|
5370
5643
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
5371
|
-
|
5644
|
+
|
5645
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
5646
|
+
|
5647
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
5372
5648
|
|
5373
5649
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
5374
5650
|
ctx->all_size = size_aligned;
|
@@ -5391,14 +5667,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
5391
5667
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
5392
5668
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
5393
5669
|
free(ctx);
|
5394
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5395
5670
|
return NULL;
|
5396
5671
|
}
|
5397
5672
|
|
5398
5673
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5399
5674
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5400
5675
|
free(ctx);
|
5401
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5402
5676
|
return NULL;
|
5403
5677
|
}
|
5404
5678
|
|
@@ -5409,17 +5683,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
5409
5683
|
|
5410
5684
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
5411
5685
|
return 32;
|
5686
|
+
|
5412
5687
|
GGML_UNUSED(buft);
|
5413
5688
|
}
|
5414
5689
|
|
5415
5690
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
5416
|
-
|
5417
|
-
const size_t max_size = device.maxBufferLength;
|
5418
|
-
ggml_backend_metal_device_rel(buft->device->context);
|
5691
|
+
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
5419
5692
|
|
5420
5693
|
return max_size;
|
5421
|
-
|
5422
|
-
GGML_UNUSED(buft);
|
5423
5694
|
}
|
5424
5695
|
|
5425
5696
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
@@ -5492,7 +5763,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
5492
5763
|
}
|
5493
5764
|
|
5494
5765
|
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
5495
|
-
|
5766
|
+
|
5767
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
5768
|
+
|
5769
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
5496
5770
|
|
5497
5771
|
// the buffer fits into the max buffer size allowed by the device
|
5498
5772
|
if (size_aligned <= device.maxBufferLength) {
|
@@ -5548,7 +5822,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
5548
5822
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5549
5823
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5550
5824
|
free(ctx);
|
5551
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5552
5825
|
return NULL;
|
5553
5826
|
}
|
5554
5827
|
|
@@ -5564,10 +5837,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
5564
5837
|
}
|
5565
5838
|
|
5566
5839
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
5567
|
-
struct ggml_backend_metal_context
|
5568
|
-
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
5840
|
+
struct ggml_backend_metal_context * ctx = backend->context;
|
5569
5841
|
|
5570
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5571
5842
|
ggml_metal_free(ctx);
|
5572
5843
|
|
5573
5844
|
free(backend);
|
@@ -5707,6 +5978,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
5707
5978
|
|
5708
5979
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
5709
5980
|
|
5981
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
5982
|
+
|
5710
5983
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
5711
5984
|
}
|
5712
5985
|
|
@@ -5726,10 +5999,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
5726
5999
|
}
|
5727
6000
|
|
5728
6001
|
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
5729
|
-
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
5730
6002
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
5731
|
-
ggml_backend_metal_device_acq(ctx_dev);
|
5732
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5733
6003
|
|
5734
6004
|
return ctx_dev->name;
|
5735
6005
|
}
|
@@ -5737,12 +6007,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
|
|
5737
6007
|
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
5738
6008
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
5739
6009
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
5740
|
-
id<MTLDevice> device =
|
6010
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
5741
6011
|
|
5742
6012
|
*total = device.recommendedMaxWorkingSetSize;
|
5743
6013
|
*free = *total - device.currentAllocatedSize;
|
5744
|
-
|
5745
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5746
6014
|
} else {
|
5747
6015
|
*free = 1;
|
5748
6016
|
*total = 1;
|
@@ -5820,7 +6088,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
5820
6088
|
}
|
5821
6089
|
|
5822
6090
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
5823
|
-
|
6091
|
+
|
6092
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
6093
|
+
|
6094
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
5824
6095
|
|
5825
6096
|
// the buffer fits into the max buffer size allowed by the device
|
5826
6097
|
if (size_aligned <= device.maxBufferLength) {
|
@@ -5876,7 +6147,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
5876
6147
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5877
6148
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5878
6149
|
free(ctx);
|
5879
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
5880
6150
|
return NULL;
|
5881
6151
|
}
|
5882
6152
|
|
@@ -5890,8 +6160,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
|
|
5890
6160
|
}
|
5891
6161
|
|
5892
6162
|
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
5893
|
-
return
|
5894
|
-
|
6163
|
+
return
|
6164
|
+
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
6165
|
+
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
5895
6166
|
|
5896
6167
|
GGML_UNUSED(dev);
|
5897
6168
|
}
|
@@ -5976,8 +6247,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
5976
6247
|
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
5977
6248
|
};
|
5978
6249
|
|
6250
|
+
// called upon program exit
|
6251
|
+
static void ggml_metal_cleanup(void) {
|
6252
|
+
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
|
6253
|
+
}
|
6254
|
+
|
6255
|
+
// TODO: make thread-safe
|
5979
6256
|
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
5980
|
-
|
6257
|
+
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
6258
|
+
|
6259
|
+
// register cleanup callback
|
6260
|
+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
6261
|
+
atexit(ggml_metal_cleanup);
|
6262
|
+
|
5981
6263
|
{
|
5982
6264
|
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
5983
6265
|
/* .api_version = */ GGML_BACKEND_API_VERSION,
|