@novastera-oss/llamarn 0.2.7 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/common/arg.cpp +7 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +1 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
- package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -3
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
- package/cpp/llama.cpp/src/llama-batch.h +98 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
- package/cpp/llama.cpp/src/llama-graph.h +44 -32
- package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-hparams.h +8 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
- package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.h +18 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
- package/cpp/llama.cpp/src/llama-model.h +22 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/common.h +1 -0
- package/ios/include/llama.h +8 -3
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
|
|
|
48
48
|
int mtl_device_ref_count;
|
|
49
49
|
id<MTLLibrary> mtl_library;
|
|
50
50
|
|
|
51
|
+
NSLock * mtl_lock;
|
|
52
|
+
|
|
51
53
|
bool has_simdgroup_reduction;
|
|
52
54
|
bool has_simdgroup_mm;
|
|
53
55
|
bool has_residency_sets;
|
|
54
56
|
bool has_bfloat;
|
|
55
57
|
bool use_bfloat;
|
|
56
58
|
|
|
59
|
+
size_t max_size;
|
|
60
|
+
|
|
57
61
|
char name[128];
|
|
58
62
|
} g_ggml_ctx_dev_main = {
|
|
59
63
|
/*.mtl_device =*/ nil,
|
|
60
64
|
/*.mtl_device_ref_count =*/ 0,
|
|
61
65
|
/*.mtl_library =*/ nil,
|
|
66
|
+
/*.mtl_lock =*/ nil,
|
|
62
67
|
/*.has_simdgroup_reduction =*/ false,
|
|
63
68
|
/*.has_simdgroup_mm =*/ false,
|
|
64
69
|
/*.has_residency_sets =*/ false,
|
|
65
70
|
/*.has_bfloat =*/ false,
|
|
66
71
|
/*.use_bfloat =*/ false,
|
|
72
|
+
/*.max_size =*/ 0,
|
|
67
73
|
/*.name =*/ "",
|
|
68
74
|
};
|
|
69
75
|
|
|
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
|
|
|
71
77
|
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
|
72
78
|
assert(ctx != NULL);
|
|
73
79
|
|
|
80
|
+
if (ctx->mtl_lock == nil) {
|
|
81
|
+
ctx->mtl_lock = [[NSLock alloc] init];
|
|
82
|
+
}
|
|
83
|
+
|
|
74
84
|
if (ctx->mtl_device == nil) {
|
|
75
85
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
76
86
|
}
|
|
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
94
104
|
ctx->use_bfloat = false;
|
|
95
105
|
#endif
|
|
96
106
|
|
|
107
|
+
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
|
+
|
|
97
109
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
98
110
|
}
|
|
99
111
|
|
|
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
|
110
122
|
ctx->mtl_device_ref_count--;
|
|
111
123
|
|
|
112
124
|
if (ctx->mtl_device_ref_count == 0) {
|
|
125
|
+
if (ctx->mtl_lock) {
|
|
126
|
+
[ctx->mtl_lock release];
|
|
127
|
+
ctx->mtl_lock = nil;
|
|
128
|
+
}
|
|
129
|
+
|
|
113
130
|
if (ctx->mtl_library) {
|
|
114
131
|
[ctx->mtl_library release];
|
|
115
132
|
ctx->mtl_library = nil;
|
|
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
|
|
|
185
202
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
186
203
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
187
204
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
205
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
|
206
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
|
207
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
|
208
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
|
209
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
|
210
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
|
211
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
|
212
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
188
214
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
189
215
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
190
216
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
|
|
|
194
220
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
195
221
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
196
222
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
223
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
|
197
224
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
225
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
|
198
226
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
|
199
227
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
200
228
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
201
229
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
|
230
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
|
202
231
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
|
203
232
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
|
204
233
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
@@ -977,7 +1006,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
977
1006
|
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
|
978
1007
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
979
1008
|
|
|
980
|
-
id<MTLDevice> device =
|
|
1009
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
981
1010
|
|
|
982
1011
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
983
1012
|
|
|
@@ -991,9 +1020,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
991
1020
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
992
1021
|
|
|
993
1022
|
// load library
|
|
994
|
-
|
|
995
|
-
ctx_dev->
|
|
1023
|
+
{
|
|
1024
|
+
[ctx_dev->mtl_lock lock];
|
|
1025
|
+
|
|
1026
|
+
if (ctx_dev->mtl_library == nil) {
|
|
1027
|
+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
|
1028
|
+
}
|
|
1029
|
+
|
|
1030
|
+
[ctx_dev->mtl_lock unlock];
|
|
996
1031
|
}
|
|
1032
|
+
|
|
997
1033
|
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
|
998
1034
|
if (metal_library == nil) {
|
|
999
1035
|
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
|
@@ -1142,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1142
1178
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
1143
1179
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
1144
1180
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
1181
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
|
1182
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
|
1183
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1145
1190
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1146
1191
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1147
1192
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
@@ -1151,11 +1196,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1151
1196
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1152
1197
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1153
1198
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
1199
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
|
1154
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
1201
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
|
1155
1202
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
|
1156
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
|
1157
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
|
1158
1205
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
1206
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
|
1159
1207
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
|
1160
1208
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
|
1161
1209
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
@@ -1605,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1605
1653
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
|
1606
1654
|
|
|
1607
1655
|
if (!use_bfloat) {
|
|
1656
|
+
if (op->type == GGML_TYPE_BF16) {
|
|
1657
|
+
return false;
|
|
1658
|
+
}
|
|
1659
|
+
|
|
1608
1660
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
1609
1661
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
1610
1662
|
return false;
|
|
@@ -1774,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1774
1826
|
{
|
|
1775
1827
|
return op->ne[3] == 1;
|
|
1776
1828
|
}
|
|
1829
|
+
case GGML_OP_SET_ROWS:
|
|
1830
|
+
{
|
|
1831
|
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
1832
|
+
return false;
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
switch (op->type) {
|
|
1836
|
+
case GGML_TYPE_F32:
|
|
1837
|
+
case GGML_TYPE_F16:
|
|
1838
|
+
case GGML_TYPE_BF16:
|
|
1839
|
+
case GGML_TYPE_Q8_0:
|
|
1840
|
+
case GGML_TYPE_Q4_0:
|
|
1841
|
+
case GGML_TYPE_Q4_1:
|
|
1842
|
+
case GGML_TYPE_Q5_0:
|
|
1843
|
+
case GGML_TYPE_Q5_1:
|
|
1844
|
+
case GGML_TYPE_IQ4_NL:
|
|
1845
|
+
return true;
|
|
1846
|
+
default:
|
|
1847
|
+
return false;
|
|
1848
|
+
};
|
|
1849
|
+
}
|
|
1777
1850
|
default:
|
|
1778
1851
|
return false;
|
|
1779
1852
|
}
|
|
@@ -2426,6 +2499,7 @@ static bool ggml_metal_encode_node(
|
|
|
2426
2499
|
nth *= 2;
|
|
2427
2500
|
}
|
|
2428
2501
|
|
|
2502
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
2429
2503
|
nth = MIN(nth, ne00);
|
|
2430
2504
|
|
|
2431
2505
|
ggml_metal_kargs_sum_rows args = {
|
|
@@ -3086,14 +3160,23 @@ static bool ggml_metal_encode_node(
|
|
|
3086
3160
|
nsg = 1;
|
|
3087
3161
|
nr0 = 1;
|
|
3088
3162
|
nr1 = 4;
|
|
3089
|
-
|
|
3163
|
+
if (ne00 == 4) {
|
|
3164
|
+
nr0 = 32;
|
|
3165
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
|
3166
|
+
} else {
|
|
3167
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
3168
|
+
}
|
|
3090
3169
|
} break;
|
|
3091
3170
|
case GGML_TYPE_F16:
|
|
3092
3171
|
{
|
|
3093
3172
|
nsg = 1;
|
|
3094
3173
|
nr0 = 1;
|
|
3095
3174
|
if (src1t == GGML_TYPE_F32) {
|
|
3096
|
-
if (
|
|
3175
|
+
if (ne00 == 4) {
|
|
3176
|
+
nr0 = 32;
|
|
3177
|
+
nr1 = 4;
|
|
3178
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
|
3179
|
+
} else if (ne11 * ne12 < 4) {
|
|
3097
3180
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
3098
3181
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3099
3182
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
@@ -3112,7 +3195,11 @@ static bool ggml_metal_encode_node(
|
|
|
3112
3195
|
nsg = 1;
|
|
3113
3196
|
nr0 = 1;
|
|
3114
3197
|
if (src1t == GGML_TYPE_F32) {
|
|
3115
|
-
if (
|
|
3198
|
+
if (ne00 == 4) {
|
|
3199
|
+
nr0 = 32;
|
|
3200
|
+
nr1 = 4;
|
|
3201
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
|
3202
|
+
} else if (ne11 * ne12 < 4) {
|
|
3116
3203
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
3117
3204
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3118
3205
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
@@ -3733,13 +3820,74 @@ static bool ggml_metal_encode_node(
|
|
|
3733
3820
|
};
|
|
3734
3821
|
|
|
3735
3822
|
[encoder setComputePipelineState:pipeline];
|
|
3736
|
-
[encoder
|
|
3737
|
-
[encoder setBuffer:
|
|
3738
|
-
[encoder setBuffer:
|
|
3739
|
-
[encoder
|
|
3823
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3824
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3825
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3826
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3740
3827
|
|
|
3741
3828
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
3742
3829
|
} break;
|
|
3830
|
+
case GGML_OP_SET_ROWS:
|
|
3831
|
+
{
|
|
3832
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3833
|
+
|
|
3834
|
+
switch (dst->type) {
|
|
3835
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
|
3836
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
|
3837
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
|
3838
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
|
3839
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
|
3840
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
|
3841
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
|
3842
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
|
3843
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
|
3844
|
+
default: GGML_ABORT("not implemented");
|
|
3845
|
+
}
|
|
3846
|
+
|
|
3847
|
+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
|
3848
|
+
|
|
3849
|
+
int nth = 32; // SIMD width
|
|
3850
|
+
|
|
3851
|
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3852
|
+
nth *= 2;
|
|
3853
|
+
}
|
|
3854
|
+
|
|
3855
|
+
int nrptg = 1;
|
|
3856
|
+
if (nth > nk0) {
|
|
3857
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
3858
|
+
nth = nk0;
|
|
3859
|
+
|
|
3860
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3861
|
+
nrptg--;
|
|
3862
|
+
}
|
|
3863
|
+
}
|
|
3864
|
+
|
|
3865
|
+
nth = MIN(nth, nk0);
|
|
3866
|
+
|
|
3867
|
+
ggml_metal_kargs_set_rows args = {
|
|
3868
|
+
/*.nk0 =*/ nk0,
|
|
3869
|
+
/*.ne01 =*/ ne01,
|
|
3870
|
+
/*.nb01 =*/ nb01,
|
|
3871
|
+
/*.nb02 =*/ nb02,
|
|
3872
|
+
/*.nb03 =*/ nb03,
|
|
3873
|
+
/*.ne11 =*/ ne11,
|
|
3874
|
+
/*.ne12 =*/ ne12,
|
|
3875
|
+
/*.nb10 =*/ nb10,
|
|
3876
|
+
/*.nb11 =*/ nb11,
|
|
3877
|
+
/*.nb12 =*/ nb12,
|
|
3878
|
+
/*.nb1 =*/ nb1,
|
|
3879
|
+
/*.nb2 =*/ nb2,
|
|
3880
|
+
/*.nb3 =*/ nb3,
|
|
3881
|
+
};
|
|
3882
|
+
|
|
3883
|
+
[encoder setComputePipelineState:pipeline];
|
|
3884
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3885
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3886
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3887
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3888
|
+
|
|
3889
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
3890
|
+
} break;
|
|
3743
3891
|
case GGML_OP_RMS_NORM:
|
|
3744
3892
|
{
|
|
3745
3893
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
@@ -3756,6 +3904,7 @@ static bool ggml_metal_encode_node(
|
|
|
3756
3904
|
nth *= 2;
|
|
3757
3905
|
}
|
|
3758
3906
|
|
|
3907
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3759
3908
|
nth = MIN(nth, ne00/4);
|
|
3760
3909
|
|
|
3761
3910
|
ggml_metal_kargs_rms_norm args = {
|
|
@@ -3792,6 +3941,7 @@ static bool ggml_metal_encode_node(
|
|
|
3792
3941
|
nth *= 2;
|
|
3793
3942
|
}
|
|
3794
3943
|
|
|
3944
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3795
3945
|
nth = MIN(nth, ne00/4);
|
|
3796
3946
|
|
|
3797
3947
|
ggml_metal_kargs_l2_norm args = {
|
|
@@ -3864,6 +4014,7 @@ static bool ggml_metal_encode_node(
|
|
|
3864
4014
|
nth *= 2;
|
|
3865
4015
|
}
|
|
3866
4016
|
|
|
4017
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3867
4018
|
nth = MIN(nth, ne00/4);
|
|
3868
4019
|
|
|
3869
4020
|
ggml_metal_kargs_norm args = {
|
|
@@ -4950,8 +5101,39 @@ static bool ggml_metal_encode_node(
|
|
|
4950
5101
|
default: GGML_ABORT("not implemented");
|
|
4951
5102
|
}
|
|
4952
5103
|
|
|
5104
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
5105
|
+
|
|
5106
|
+
// TODO: support
|
|
5107
|
+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
|
5108
|
+
const int32_t nk00 = ne00;
|
|
5109
|
+
|
|
5110
|
+
int nth = 32; // SIMD width
|
|
5111
|
+
|
|
5112
|
+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5113
|
+
nth *= 2;
|
|
5114
|
+
}
|
|
5115
|
+
|
|
5116
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
5117
|
+
|
|
5118
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
5119
|
+
int nrptg = 1;
|
|
5120
|
+
|
|
5121
|
+
// TODO: relax this constraint in the future
|
|
5122
|
+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
|
5123
|
+
if (nth > nk00) {
|
|
5124
|
+
nrptg = (nth + nk00 - 1)/nk00;
|
|
5125
|
+
nth = nk00;
|
|
5126
|
+
|
|
5127
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5128
|
+
nrptg--;
|
|
5129
|
+
}
|
|
5130
|
+
}
|
|
5131
|
+
}
|
|
5132
|
+
|
|
5133
|
+
nth = MIN(nth, nk00);
|
|
5134
|
+
|
|
4953
5135
|
ggml_metal_kargs_cpy args = {
|
|
4954
|
-
/*.ne00 =*/
|
|
5136
|
+
/*.ne00 =*/ nk00,
|
|
4955
5137
|
/*.ne01 =*/ ne01,
|
|
4956
5138
|
/*.ne02 =*/ ne02,
|
|
4957
5139
|
/*.ne03 =*/ ne03,
|
|
@@ -4974,11 +5156,7 @@ static bool ggml_metal_encode_node(
|
|
|
4974
5156
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4975
5157
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
4976
5158
|
|
|
4977
|
-
|
|
4978
|
-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
|
4979
|
-
|
|
4980
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4981
|
-
|
|
5159
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
4982
5160
|
} break;
|
|
4983
5161
|
case GGML_OP_SET:
|
|
4984
5162
|
{
|
|
@@ -5284,7 +5462,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
|
5284
5462
|
}
|
|
5285
5463
|
|
|
5286
5464
|
ggml_backend_metal_buffer_rset_free(ctx);
|
|
5287
|
-
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
5288
5465
|
|
|
5289
5466
|
if (ctx->owned) {
|
|
5290
5467
|
#if TARGET_OS_OSX
|
|
@@ -5393,7 +5570,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5393
5570
|
}
|
|
5394
5571
|
|
|
5395
5572
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
|
5396
|
-
|
|
5573
|
+
|
|
5574
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5575
|
+
|
|
5576
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5397
5577
|
|
|
5398
5578
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
|
5399
5579
|
ctx->all_size = size_aligned;
|
|
@@ -5416,14 +5596,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5416
5596
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
5417
5597
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
5418
5598
|
free(ctx);
|
|
5419
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5420
5599
|
return NULL;
|
|
5421
5600
|
}
|
|
5422
5601
|
|
|
5423
5602
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5424
5603
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5425
5604
|
free(ctx);
|
|
5426
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5427
5605
|
return NULL;
|
|
5428
5606
|
}
|
|
5429
5607
|
|
|
@@ -5434,17 +5612,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5434
5612
|
|
|
5435
5613
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
5436
5614
|
return 32;
|
|
5615
|
+
|
|
5437
5616
|
GGML_UNUSED(buft);
|
|
5438
5617
|
}
|
|
5439
5618
|
|
|
5440
5619
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
5441
|
-
|
|
5442
|
-
const size_t max_size = device.maxBufferLength;
|
|
5443
|
-
ggml_backend_metal_device_rel(buft->device->context);
|
|
5620
|
+
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
|
5444
5621
|
|
|
5445
5622
|
return max_size;
|
|
5446
|
-
|
|
5447
|
-
GGML_UNUSED(buft);
|
|
5448
5623
|
}
|
|
5449
5624
|
|
|
5450
5625
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
|
@@ -5517,7 +5692,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5517
5692
|
}
|
|
5518
5693
|
|
|
5519
5694
|
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
|
5520
|
-
|
|
5695
|
+
|
|
5696
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5697
|
+
|
|
5698
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5521
5699
|
|
|
5522
5700
|
// the buffer fits into the max buffer size allowed by the device
|
|
5523
5701
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5573,7 +5751,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5573
5751
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5574
5752
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5575
5753
|
free(ctx);
|
|
5576
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5577
5754
|
return NULL;
|
|
5578
5755
|
}
|
|
5579
5756
|
|
|
@@ -5589,10 +5766,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
|
5589
5766
|
}
|
|
5590
5767
|
|
|
5591
5768
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
5592
|
-
struct ggml_backend_metal_context
|
|
5593
|
-
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5769
|
+
struct ggml_backend_metal_context * ctx = backend->context;
|
|
5594
5770
|
|
|
5595
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5596
5771
|
ggml_metal_free(ctx);
|
|
5597
5772
|
|
|
5598
5773
|
free(backend);
|
|
@@ -5732,6 +5907,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
5732
5907
|
|
|
5733
5908
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5734
5909
|
|
|
5910
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5911
|
+
|
|
5735
5912
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
5736
5913
|
}
|
|
5737
5914
|
|
|
@@ -5751,10 +5928,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
|
5751
5928
|
}
|
|
5752
5929
|
|
|
5753
5930
|
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
|
5754
|
-
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
5755
5931
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5756
|
-
ggml_backend_metal_device_acq(ctx_dev);
|
|
5757
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5758
5932
|
|
|
5759
5933
|
return ctx_dev->name;
|
|
5760
5934
|
}
|
|
@@ -5762,12 +5936,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
|
|
|
5762
5936
|
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
5763
5937
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
5764
5938
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5765
|
-
id<MTLDevice> device =
|
|
5939
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5766
5940
|
|
|
5767
5941
|
*total = device.recommendedMaxWorkingSetSize;
|
|
5768
5942
|
*free = *total - device.currentAllocatedSize;
|
|
5769
|
-
|
|
5770
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5771
5943
|
} else {
|
|
5772
5944
|
*free = 1;
|
|
5773
5945
|
*total = 1;
|
|
@@ -5845,7 +6017,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5845
6017
|
}
|
|
5846
6018
|
|
|
5847
6019
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5848
|
-
|
|
6020
|
+
|
|
6021
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
6022
|
+
|
|
6023
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5849
6024
|
|
|
5850
6025
|
// the buffer fits into the max buffer size allowed by the device
|
|
5851
6026
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5901,7 +6076,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5901
6076
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5902
6077
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5903
6078
|
free(ctx);
|
|
5904
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5905
6079
|
return NULL;
|
|
5906
6080
|
}
|
|
5907
6081
|
|
|
@@ -5915,8 +6089,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
5915
6089
|
}
|
|
5916
6090
|
|
|
5917
6091
|
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
5918
|
-
return
|
|
5919
|
-
|
|
6092
|
+
return
|
|
6093
|
+
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
|
6094
|
+
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
|
5920
6095
|
|
|
5921
6096
|
GGML_UNUSED(dev);
|
|
5922
6097
|
}
|
|
@@ -6001,8 +6176,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
|
6001
6176
|
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
|
6002
6177
|
};
|
|
6003
6178
|
|
|
6179
|
+
// called upon program exit
|
|
6180
|
+
static void ggml_metal_cleanup(void) {
|
|
6181
|
+
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
|
|
6182
|
+
}
|
|
6183
|
+
|
|
6184
|
+
// TODO: make thread-safe
|
|
6004
6185
|
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|
6005
|
-
|
|
6186
|
+
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
|
6187
|
+
|
|
6188
|
+
// register cleanup callback
|
|
6189
|
+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
|
6190
|
+
atexit(ggml_metal_cleanup);
|
|
6191
|
+
|
|
6006
6192
|
{
|
|
6007
6193
|
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
|
6008
6194
|
/* .api_version = */ GGML_BACKEND_API_VERSION,
|