@novastera-oss/llamarn 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +4 -3
- package/cpp/llama.cpp/common/arg.cpp +45 -1
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +18 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
- package/cpp/llama.cpp/include/llama.h +15 -7
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
- package/cpp/llama.cpp/src/llama-arch.h +5 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
- package/cpp/llama.cpp/src/llama-batch.h +24 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
- package/cpp/llama.cpp/src/llama-chat.h +2 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
- package/cpp/llama.cpp/src/llama-graph.h +147 -72
- package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
- package/cpp/llama.cpp/src/llama-hparams.h +10 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
- package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
- package/cpp/llama.cpp/src/llama-model.h +3 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
- package/cpp/llama.cpp/src/llama-vocab.h +2 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/common.h +18 -4
- package/ios/include/llama.h +15 -7
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
|
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
|
|
|
55
55
|
bool has_residency_sets;
|
|
56
56
|
bool has_bfloat;
|
|
57
57
|
bool use_bfloat;
|
|
58
|
+
bool use_fusion;
|
|
59
|
+
|
|
60
|
+
int debug_fusion;
|
|
61
|
+
|
|
62
|
+
// how many times a given op was fused
|
|
63
|
+
uint64_t fuse_cnt[GGML_OP_COUNT];
|
|
58
64
|
|
|
59
65
|
size_t max_size;
|
|
60
66
|
|
|
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
|
|
|
69
75
|
/*.has_residency_sets =*/ false,
|
|
70
76
|
/*.has_bfloat =*/ false,
|
|
71
77
|
/*.use_bfloat =*/ false,
|
|
78
|
+
/*.use_fusion =*/ true,
|
|
79
|
+
/*.debug_fusion =*/ 0,
|
|
80
|
+
/*.fuse_cnt =*/ { 0 },
|
|
72
81
|
/*.max_size =*/ 0,
|
|
73
82
|
/*.name =*/ "",
|
|
74
83
|
};
|
|
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
83
92
|
|
|
84
93
|
if (ctx->mtl_device == nil) {
|
|
85
94
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
86
|
-
}
|
|
87
95
|
|
|
88
|
-
if (ctx->mtl_device) {
|
|
89
96
|
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
90
97
|
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
91
98
|
|
|
92
99
|
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
93
100
|
|
|
94
101
|
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
|
95
|
-
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") ==
|
|
102
|
+
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
|
96
103
|
#endif
|
|
97
104
|
|
|
98
105
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
103
110
|
#else
|
|
104
111
|
ctx->use_bfloat = false;
|
|
105
112
|
#endif
|
|
113
|
+
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
|
114
|
+
|
|
115
|
+
{
|
|
116
|
+
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
|
117
|
+
ctx->debug_fusion = val ? atoi(val) : 0;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
|
106
121
|
|
|
107
122
|
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
123
|
|
|
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
|
122
137
|
ctx->mtl_device_ref_count--;
|
|
123
138
|
|
|
124
139
|
if (ctx->mtl_device_ref_count == 0) {
|
|
140
|
+
if (ctx->debug_fusion > 0) {
|
|
141
|
+
fprintf(stderr, "%s: fusion stats:\n", __func__);
|
|
142
|
+
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
|
143
|
+
if (ctx->fuse_cnt[i] == 0) {
|
|
144
|
+
continue;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// note: cannot use ggml_log here
|
|
148
|
+
fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
125
152
|
if (ctx->mtl_lock) {
|
|
126
153
|
[ctx->mtl_lock release];
|
|
127
154
|
ctx->mtl_lock = nil;
|
|
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
|
|
|
147
174
|
|
|
148
175
|
enum ggml_metal_kernel_type {
|
|
149
176
|
GGML_METAL_KERNEL_TYPE_ADD,
|
|
150
|
-
|
|
177
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
|
|
178
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
|
|
179
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
|
|
180
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
|
|
181
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
|
|
182
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
|
|
183
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
|
|
184
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
|
|
185
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
|
|
186
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
|
|
187
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
|
|
188
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
|
|
189
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
|
|
190
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
|
|
191
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
|
|
151
192
|
GGML_METAL_KERNEL_TYPE_SUB,
|
|
152
|
-
|
|
193
|
+
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
|
|
153
194
|
GGML_METAL_KERNEL_TYPE_MUL,
|
|
154
|
-
|
|
195
|
+
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
|
155
196
|
GGML_METAL_KERNEL_TYPE_DIV,
|
|
156
|
-
|
|
197
|
+
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
|
157
198
|
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
|
158
199
|
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
|
159
200
|
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
@@ -173,6 +214,12 @@ enum ggml_metal_kernel_type {
|
|
|
173
214
|
GGML_METAL_KERNEL_TYPE_SILU,
|
|
174
215
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
|
175
216
|
GGML_METAL_KERNEL_TYPE_ELU,
|
|
217
|
+
GGML_METAL_KERNEL_TYPE_ABS,
|
|
218
|
+
GGML_METAL_KERNEL_TYPE_SGN,
|
|
219
|
+
GGML_METAL_KERNEL_TYPE_STEP,
|
|
220
|
+
GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
|
221
|
+
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
|
222
|
+
GGML_METAL_KERNEL_TYPE_EXP,
|
|
176
223
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
|
177
224
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
|
178
225
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
@@ -212,6 +259,8 @@ enum ggml_metal_kernel_type {
|
|
|
212
259
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
260
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
214
261
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
262
|
+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
|
|
263
|
+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
|
|
215
264
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
216
265
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
217
266
|
GGML_METAL_KERNEL_TYPE_NORM,
|
|
@@ -1129,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1129
1178
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
1130
1179
|
|
|
1131
1180
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
1132
|
-
GGML_METAL_ADD_KERNEL(
|
|
1181
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
|
|
1182
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
|
|
1183
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
|
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
|
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
|
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
|
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
|
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
|
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
|
|
1190
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
|
|
1191
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
|
|
1192
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
|
|
1193
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
|
|
1194
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
|
|
1195
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
|
|
1133
1196
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
|
1134
|
-
GGML_METAL_ADD_KERNEL(
|
|
1197
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
|
|
1135
1198
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
1136
|
-
GGML_METAL_ADD_KERNEL(
|
|
1199
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
|
1137
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
1138
|
-
GGML_METAL_ADD_KERNEL(
|
|
1201
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
|
1139
1202
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
|
1140
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
|
1141
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
@@ -1155,6 +1218,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1155
1218
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
1156
1219
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
|
1157
1220
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
|
1221
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
|
1222
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
|
1223
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
|
1224
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
|
1225
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
|
1226
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
|
1158
1227
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
|
1159
1228
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
|
1160
1229
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
@@ -1194,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1194
1263
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1195
1264
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1196
1265
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1266
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
|
|
1267
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
|
|
1197
1268
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1198
1269
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
1199
1270
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
@@ -1688,6 +1759,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1688
1759
|
case GGML_UNARY_OP_SILU:
|
|
1689
1760
|
case GGML_UNARY_OP_ELU:
|
|
1690
1761
|
case GGML_UNARY_OP_NEG:
|
|
1762
|
+
case GGML_UNARY_OP_ABS:
|
|
1763
|
+
case GGML_UNARY_OP_SGN:
|
|
1764
|
+
case GGML_UNARY_OP_STEP:
|
|
1765
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
1766
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
1767
|
+
case GGML_UNARY_OP_EXP:
|
|
1691
1768
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
1692
1769
|
default:
|
|
1693
1770
|
return false;
|
|
@@ -1875,9 +1952,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1875
1952
|
}
|
|
1876
1953
|
}
|
|
1877
1954
|
|
|
1878
|
-
static
|
|
1955
|
+
static int ggml_metal_encode_node(
|
|
1879
1956
|
ggml_backend_t backend,
|
|
1880
1957
|
int idx,
|
|
1958
|
+
int idx_end,
|
|
1881
1959
|
id<MTLComputeCommandEncoder> encoder,
|
|
1882
1960
|
struct ggml_metal_mem_pool * mem_pool) {
|
|
1883
1961
|
struct ggml_backend_metal_context * ctx = backend->context;
|
|
@@ -1885,7 +1963,10 @@ static bool ggml_metal_encode_node(
|
|
|
1885
1963
|
|
|
1886
1964
|
struct ggml_cgraph * gf = ctx->gf;
|
|
1887
1965
|
|
|
1888
|
-
|
|
1966
|
+
enum ggml_op ops[8];
|
|
1967
|
+
|
|
1968
|
+
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
|
|
1969
|
+
struct ggml_tensor * node = nodes[0];
|
|
1889
1970
|
|
|
1890
1971
|
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
|
1891
1972
|
|
|
@@ -1895,7 +1976,7 @@ static bool ggml_metal_encode_node(
|
|
|
1895
1976
|
struct ggml_tensor * dst = node;
|
|
1896
1977
|
|
|
1897
1978
|
if (ggml_is_empty(dst)) {
|
|
1898
|
-
return
|
|
1979
|
+
return 1;
|
|
1899
1980
|
}
|
|
1900
1981
|
|
|
1901
1982
|
switch (dst->op) {
|
|
@@ -1906,7 +1987,7 @@ static bool ggml_metal_encode_node(
|
|
|
1906
1987
|
case GGML_OP_PERMUTE:
|
|
1907
1988
|
{
|
|
1908
1989
|
// noop -> next node
|
|
1909
|
-
} return
|
|
1990
|
+
} return 1;
|
|
1910
1991
|
default:
|
|
1911
1992
|
{
|
|
1912
1993
|
} break;
|
|
@@ -1973,6 +2054,8 @@ static bool ggml_metal_encode_node(
|
|
|
1973
2054
|
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
|
1974
2055
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
|
1975
2056
|
|
|
2057
|
+
int n_fuse = 1;
|
|
2058
|
+
|
|
1976
2059
|
#if 0
|
|
1977
2060
|
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
1978
2061
|
if (src0) {
|
|
@@ -2044,37 +2127,15 @@ static bool ggml_metal_encode_node(
|
|
|
2044
2127
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
2045
2128
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
2046
2129
|
|
|
2130
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
|
|
2131
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src1));
|
|
2132
|
+
|
|
2047
2133
|
const size_t offs = 0;
|
|
2048
2134
|
|
|
2049
2135
|
bool bcast_row = false;
|
|
2050
2136
|
|
|
2051
2137
|
id<MTLComputePipelineState> pipeline = nil;
|
|
2052
2138
|
|
|
2053
|
-
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2054
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2055
|
-
|
|
2056
|
-
// src1 is a row
|
|
2057
|
-
GGML_ASSERT(ne11 == 1);
|
|
2058
|
-
|
|
2059
|
-
switch (dst->op) {
|
|
2060
|
-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
|
2061
|
-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
|
2062
|
-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
|
2063
|
-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
|
2064
|
-
default: GGML_ABORT("fatal error");
|
|
2065
|
-
}
|
|
2066
|
-
|
|
2067
|
-
bcast_row = true;
|
|
2068
|
-
} else {
|
|
2069
|
-
switch (dst->op) {
|
|
2070
|
-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
|
2071
|
-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2072
|
-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2073
|
-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2074
|
-
default: GGML_ABORT("fatal error");
|
|
2075
|
-
}
|
|
2076
|
-
}
|
|
2077
|
-
|
|
2078
2139
|
ggml_metal_kargs_bin args = {
|
|
2079
2140
|
/*.ne00 =*/ ne00,
|
|
2080
2141
|
/*.ne01 =*/ ne01,
|
|
@@ -2101,12 +2162,119 @@ static bool ggml_metal_encode_node(
|
|
|
2101
2162
|
/*.nb2 =*/ nb2,
|
|
2102
2163
|
/*.nb3 =*/ nb3,
|
|
2103
2164
|
/*.offs =*/ offs,
|
|
2165
|
+
/*.o1 =*/ { offs_src1 },
|
|
2104
2166
|
};
|
|
2105
2167
|
|
|
2168
|
+
// c[0] = add(a, b[0])
|
|
2169
|
+
// c[1] = add(c[0], b[1])
|
|
2170
|
+
// c[2] = add(c[1], b[2])
|
|
2171
|
+
// ...
|
|
2172
|
+
if (ctx_dev->use_fusion) {
|
|
2173
|
+
ops[0] = GGML_OP_ADD;
|
|
2174
|
+
ops[1] = GGML_OP_ADD;
|
|
2175
|
+
ops[2] = GGML_OP_ADD;
|
|
2176
|
+
ops[3] = GGML_OP_ADD;
|
|
2177
|
+
ops[4] = GGML_OP_ADD;
|
|
2178
|
+
ops[5] = GGML_OP_ADD;
|
|
2179
|
+
ops[6] = GGML_OP_ADD;
|
|
2180
|
+
ops[7] = GGML_OP_ADD;
|
|
2181
|
+
|
|
2182
|
+
size_t offs_fuse;
|
|
2183
|
+
id<MTLBuffer> id_fuse;
|
|
2184
|
+
|
|
2185
|
+
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
|
|
2186
|
+
// across splits. idx_end indicates the last node in the current split
|
|
2187
|
+
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
2188
|
+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
2189
|
+
break;
|
|
2190
|
+
}
|
|
2191
|
+
|
|
2192
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
2193
|
+
break;
|
|
2194
|
+
}
|
|
2195
|
+
|
|
2196
|
+
// b[0] === b[1] === ...
|
|
2197
|
+
if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
|
|
2198
|
+
break;
|
|
2199
|
+
}
|
|
2200
|
+
|
|
2201
|
+
// only fuse nodes if src1 is in the same Metal buffer
|
|
2202
|
+
id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
|
|
2203
|
+
if (id_fuse != id_src1) {
|
|
2204
|
+
break;
|
|
2205
|
+
}
|
|
2206
|
+
|
|
2207
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
2208
|
+
|
|
2209
|
+
args.o1[n_fuse + 1] = offs_fuse;
|
|
2210
|
+
}
|
|
2211
|
+
|
|
2212
|
+
++n_fuse;
|
|
2213
|
+
|
|
2214
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
2215
|
+
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
|
|
2216
|
+
}
|
|
2217
|
+
}
|
|
2218
|
+
|
|
2219
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2220
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2221
|
+
|
|
2222
|
+
// src1 is a row
|
|
2223
|
+
GGML_ASSERT(ne11 == 1);
|
|
2224
|
+
|
|
2225
|
+
switch (dst->op) {
|
|
2226
|
+
case GGML_OP_ADD:
|
|
2227
|
+
{
|
|
2228
|
+
switch (n_fuse) {
|
|
2229
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
|
|
2230
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
|
|
2231
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
|
|
2232
|
+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
|
|
2233
|
+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
|
|
2234
|
+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
|
|
2235
|
+
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
|
|
2236
|
+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
|
|
2237
|
+
default: GGML_ABORT("fatal error");
|
|
2238
|
+
}
|
|
2239
|
+
} break;
|
|
2240
|
+
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
|
|
2241
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
|
|
2242
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
|
|
2243
|
+
default: GGML_ABORT("fatal error");
|
|
2244
|
+
}
|
|
2245
|
+
|
|
2246
|
+
bcast_row = true;
|
|
2247
|
+
} else {
|
|
2248
|
+
switch (dst->op) {
|
|
2249
|
+
case GGML_OP_ADD:
|
|
2250
|
+
{
|
|
2251
|
+
switch (n_fuse) {
|
|
2252
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
|
|
2253
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
|
|
2254
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
|
|
2255
|
+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
|
|
2256
|
+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
|
|
2257
|
+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
|
|
2258
|
+
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
|
|
2259
|
+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
|
|
2260
|
+
default: GGML_ABORT("fatal error");
|
|
2261
|
+
}
|
|
2262
|
+
} break;
|
|
2263
|
+
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2264
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2265
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2266
|
+
default: GGML_ABORT("fatal error");
|
|
2267
|
+
}
|
|
2268
|
+
}
|
|
2269
|
+
|
|
2270
|
+
if (n_fuse > 1) {
|
|
2271
|
+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
2272
|
+
}
|
|
2273
|
+
|
|
2106
2274
|
[encoder setComputePipelineState:pipeline];
|
|
2107
2275
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2108
2276
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2109
|
-
[encoder setBuffer:id_src1 offset:
|
|
2277
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2110
2278
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2111
2279
|
|
|
2112
2280
|
if (bcast_row) {
|
|
@@ -2114,7 +2282,11 @@ static bool ggml_metal_encode_node(
|
|
|
2114
2282
|
|
|
2115
2283
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2116
2284
|
} else {
|
|
2117
|
-
|
|
2285
|
+
int nth = 32;
|
|
2286
|
+
|
|
2287
|
+
while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2288
|
+
nth *= 2;
|
|
2289
|
+
}
|
|
2118
2290
|
|
|
2119
2291
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2120
2292
|
}
|
|
@@ -2239,12 +2411,13 @@ static bool ggml_metal_encode_node(
|
|
|
2239
2411
|
/*.nb2 =*/ pnb2,
|
|
2240
2412
|
/*.nb3 =*/ pnb3,
|
|
2241
2413
|
/*.offs =*/ offs,
|
|
2414
|
+
/*.o1 =*/ { offs_src1},
|
|
2242
2415
|
};
|
|
2243
2416
|
|
|
2244
2417
|
[encoder setComputePipelineState:pipeline];
|
|
2245
2418
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2246
2419
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2247
|
-
[encoder setBuffer:id_src1 offset:
|
|
2420
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2248
2421
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2249
2422
|
|
|
2250
2423
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
@@ -2439,6 +2612,78 @@ static bool ggml_metal_encode_node(
|
|
|
2439
2612
|
|
|
2440
2613
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2441
2614
|
} break;
|
|
2615
|
+
case GGML_UNARY_OP_ABS:
|
|
2616
|
+
{
|
|
2617
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
|
|
2618
|
+
|
|
2619
|
+
[encoder setComputePipelineState:pipeline];
|
|
2620
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2621
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2622
|
+
|
|
2623
|
+
const int64_t n = ggml_nelements(dst);
|
|
2624
|
+
|
|
2625
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2626
|
+
} break;
|
|
2627
|
+
case GGML_UNARY_OP_SGN:
|
|
2628
|
+
{
|
|
2629
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
|
|
2630
|
+
|
|
2631
|
+
[encoder setComputePipelineState:pipeline];
|
|
2632
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2633
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2634
|
+
|
|
2635
|
+
const int64_t n = ggml_nelements(dst);
|
|
2636
|
+
|
|
2637
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2638
|
+
} break;
|
|
2639
|
+
case GGML_UNARY_OP_STEP:
|
|
2640
|
+
{
|
|
2641
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
|
|
2642
|
+
|
|
2643
|
+
[encoder setComputePipelineState:pipeline];
|
|
2644
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2645
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2646
|
+
|
|
2647
|
+
const int64_t n = ggml_nelements(dst);
|
|
2648
|
+
|
|
2649
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2650
|
+
} break;
|
|
2651
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
2652
|
+
{
|
|
2653
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
|
|
2654
|
+
|
|
2655
|
+
[encoder setComputePipelineState:pipeline];
|
|
2656
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2657
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2658
|
+
|
|
2659
|
+
const int64_t n = ggml_nelements(dst);
|
|
2660
|
+
|
|
2661
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2662
|
+
} break;
|
|
2663
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
2664
|
+
{
|
|
2665
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
|
|
2666
|
+
|
|
2667
|
+
[encoder setComputePipelineState:pipeline];
|
|
2668
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2669
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2670
|
+
|
|
2671
|
+
const int64_t n = ggml_nelements(dst);
|
|
2672
|
+
|
|
2673
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2674
|
+
} break;
|
|
2675
|
+
case GGML_UNARY_OP_EXP:
|
|
2676
|
+
{
|
|
2677
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
|
|
2678
|
+
|
|
2679
|
+
[encoder setComputePipelineState:pipeline];
|
|
2680
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2681
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2682
|
+
|
|
2683
|
+
const int64_t n = ggml_nelements(dst);
|
|
2684
|
+
|
|
2685
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2686
|
+
} break;
|
|
2442
2687
|
default:
|
|
2443
2688
|
{
|
|
2444
2689
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
|
@@ -2674,7 +2919,7 @@ static bool ggml_metal_encode_node(
|
|
|
2674
2919
|
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
|
2675
2920
|
if (!h_src0) {
|
|
2676
2921
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
|
2677
|
-
return
|
|
2922
|
+
return 0;
|
|
2678
2923
|
}
|
|
2679
2924
|
|
|
2680
2925
|
offs_src0 = 0;
|
|
@@ -2896,6 +3141,7 @@ static bool ggml_metal_encode_node(
|
|
|
2896
3141
|
/*.n_group =*/ n_group,
|
|
2897
3142
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
2898
3143
|
/*.n_seqs =*/ n_seqs,
|
|
3144
|
+
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
|
|
2899
3145
|
/*.nb01 =*/ nb01,
|
|
2900
3146
|
/*.nb02 =*/ nb02,
|
|
2901
3147
|
/*.nb03 =*/ nb03,
|
|
@@ -2924,12 +3170,22 @@ static bool ggml_metal_encode_node(
|
|
|
2924
3170
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
|
2925
3171
|
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
|
2926
3172
|
|
|
3173
|
+
// One shared memory bucket for each simd group in the threadgroup
|
|
3174
|
+
// NOTE: Metal kernels require the buffer size to be a multiple of 16 bytes
|
|
3175
|
+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
3176
|
+
if (d_state >= 32) {
|
|
3177
|
+
GGML_ASSERT((int64_t)(d_state / 32) <= 32);
|
|
3178
|
+
const int64_t shmem_size = 32;
|
|
3179
|
+
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
|
|
3180
|
+
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
|
|
3181
|
+
}
|
|
3182
|
+
|
|
2927
3183
|
if (ne30 == 1) {
|
|
2928
3184
|
// Mamba-2
|
|
2929
|
-
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(
|
|
3185
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
2930
3186
|
} else {
|
|
2931
3187
|
GGML_ASSERT(d_inner == 1);
|
|
2932
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(
|
|
3188
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
2933
3189
|
}
|
|
2934
3190
|
} break;
|
|
2935
3191
|
case GGML_OP_RWKV_WKV6:
|
|
@@ -3550,7 +3806,7 @@ static bool ggml_metal_encode_node(
|
|
|
3550
3806
|
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
|
3551
3807
|
if (!h_src1) {
|
|
3552
3808
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
|
3553
|
-
return
|
|
3809
|
+
return 0;
|
|
3554
3810
|
}
|
|
3555
3811
|
|
|
3556
3812
|
const int64_t neh0 = ne0;
|
|
@@ -3566,7 +3822,7 @@ static bool ggml_metal_encode_node(
|
|
|
3566
3822
|
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
|
3567
3823
|
if (!h_dst) {
|
|
3568
3824
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
|
3569
|
-
return
|
|
3825
|
+
return 0;
|
|
3570
3826
|
}
|
|
3571
3827
|
|
|
3572
3828
|
// tokens per expert
|
|
@@ -3574,7 +3830,7 @@ static bool ggml_metal_encode_node(
|
|
|
3574
3830
|
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
|
3575
3831
|
if (!h_tpe) {
|
|
3576
3832
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
|
3577
|
-
return
|
|
3833
|
+
return 0;
|
|
3578
3834
|
}
|
|
3579
3835
|
|
|
3580
3836
|
// id map
|
|
@@ -3583,7 +3839,7 @@ static bool ggml_metal_encode_node(
|
|
|
3583
3839
|
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
|
3584
3840
|
if (!h_ids) {
|
|
3585
3841
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
|
3586
|
-
return
|
|
3842
|
+
return 0;
|
|
3587
3843
|
}
|
|
3588
3844
|
|
|
3589
3845
|
{
|
|
@@ -4015,12 +4271,95 @@ static bool ggml_metal_encode_node(
|
|
|
4015
4271
|
case GGML_OP_RMS_NORM:
|
|
4016
4272
|
{
|
|
4017
4273
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
4018
|
-
GGML_ASSERT(
|
|
4274
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
|
|
4019
4275
|
|
|
4020
4276
|
float eps;
|
|
4021
4277
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
4022
4278
|
|
|
4023
|
-
|
|
4279
|
+
ggml_metal_kargs_rms_norm args = {
|
|
4280
|
+
/*.ne00 =*/ ne00,
|
|
4281
|
+
/*.ne00_4 =*/ ne00/4,
|
|
4282
|
+
/*.nb1 =*/ nb1,
|
|
4283
|
+
/*.nb2 =*/ nb2,
|
|
4284
|
+
/*.nb3 =*/ nb3,
|
|
4285
|
+
/*.eps =*/ eps,
|
|
4286
|
+
/*.nef1 =*/ { ne01 },
|
|
4287
|
+
/*.nef2 =*/ { ne02 },
|
|
4288
|
+
/*.nef3 =*/ { ne03 },
|
|
4289
|
+
/*.nbf1 =*/ { nb01 },
|
|
4290
|
+
/*.nbf2 =*/ { nb02 },
|
|
4291
|
+
/*.nbf3 =*/ { nb03 },
|
|
4292
|
+
};
|
|
4293
|
+
|
|
4294
|
+
size_t offs_fuse[2] = { 0, 0 };
|
|
4295
|
+
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
|
|
4296
|
+
|
|
4297
|
+
// d[0] = rms_norm(a)
|
|
4298
|
+
// d[1] = mul(d[0], b)
|
|
4299
|
+
// d[2] = add(d[1], c)
|
|
4300
|
+
if (ctx_dev->use_fusion) {
|
|
4301
|
+
ops[0] = GGML_OP_RMS_NORM;
|
|
4302
|
+
ops[1] = GGML_OP_MUL;
|
|
4303
|
+
ops[2] = GGML_OP_ADD;
|
|
4304
|
+
|
|
4305
|
+
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
4306
|
+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
4307
|
+
break;
|
|
4308
|
+
}
|
|
4309
|
+
|
|
4310
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
4311
|
+
break;
|
|
4312
|
+
}
|
|
4313
|
+
|
|
4314
|
+
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
|
|
4315
|
+
break;
|
|
4316
|
+
}
|
|
4317
|
+
|
|
4318
|
+
if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
|
|
4319
|
+
break;
|
|
4320
|
+
}
|
|
4321
|
+
|
|
4322
|
+
if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
|
|
4323
|
+
break;
|
|
4324
|
+
}
|
|
4325
|
+
|
|
4326
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
4327
|
+
|
|
4328
|
+
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
|
|
4329
|
+
|
|
4330
|
+
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
|
|
4331
|
+
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
|
|
4332
|
+
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
|
|
4333
|
+
|
|
4334
|
+
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
|
|
4335
|
+
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
|
|
4336
|
+
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
|
|
4337
|
+
}
|
|
4338
|
+
|
|
4339
|
+
++n_fuse;
|
|
4340
|
+
|
|
4341
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
4342
|
+
if (n_fuse == 2) {
|
|
4343
|
+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
|
|
4344
|
+
}
|
|
4345
|
+
if (n_fuse == 3) {
|
|
4346
|
+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
|
|
4347
|
+
}
|
|
4348
|
+
}
|
|
4349
|
+
}
|
|
4350
|
+
|
|
4351
|
+
if (n_fuse > 1) {
|
|
4352
|
+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
4353
|
+
}
|
|
4354
|
+
|
|
4355
|
+
id<MTLComputePipelineState> pipeline;
|
|
4356
|
+
|
|
4357
|
+
switch (n_fuse) {
|
|
4358
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
|
|
4359
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
|
|
4360
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
|
|
4361
|
+
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
|
|
4362
|
+
}
|
|
4024
4363
|
|
|
4025
4364
|
int nth = 32; // SIMD width
|
|
4026
4365
|
|
|
@@ -4031,23 +4370,16 @@ static bool ggml_metal_encode_node(
|
|
|
4031
4370
|
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
4032
4371
|
nth = MIN(nth, ne00/4);
|
|
4033
4372
|
|
|
4034
|
-
ggml_metal_kargs_rms_norm args = {
|
|
4035
|
-
/*.ne00 =*/ ne00,
|
|
4036
|
-
/*.ne00_4 =*/ ne00/4,
|
|
4037
|
-
/*.nb01 =*/ nb01,
|
|
4038
|
-
/*.eps =*/ eps,
|
|
4039
|
-
};
|
|
4040
|
-
|
|
4041
4373
|
[encoder setComputePipelineState:pipeline];
|
|
4042
|
-
[encoder setBytes:&args length:sizeof(args)
|
|
4043
|
-
[encoder setBuffer:id_src0
|
|
4044
|
-
[encoder setBuffer:
|
|
4374
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
4375
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4376
|
+
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
|
|
4377
|
+
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
|
|
4378
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
4045
4379
|
|
|
4046
4380
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
4047
4381
|
|
|
4048
|
-
|
|
4049
|
-
|
|
4050
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4382
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4051
4383
|
} break;
|
|
4052
4384
|
case GGML_OP_L2_NORM:
|
|
4053
4385
|
{
|
|
@@ -5442,7 +5774,7 @@ static bool ggml_metal_encode_node(
|
|
|
5442
5774
|
}
|
|
5443
5775
|
}
|
|
5444
5776
|
|
|
5445
|
-
return
|
|
5777
|
+
return n_fuse;
|
|
5446
5778
|
}
|
|
5447
5779
|
|
|
5448
5780
|
static enum ggml_status ggml_metal_graph_compute(
|
|
@@ -5948,20 +6280,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
5948
6280
|
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
|
5949
6281
|
ggml_metal_mem_pool_reset(mem_pool);
|
|
5950
6282
|
|
|
5951
|
-
for (int idx = node_start; idx < node_end;
|
|
6283
|
+
for (int idx = node_start; idx < node_end;) {
|
|
5952
6284
|
if (should_capture) {
|
|
5953
6285
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
5954
6286
|
}
|
|
5955
6287
|
|
|
5956
|
-
const
|
|
6288
|
+
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
|
|
6289
|
+
if (idx + res > node_end) {
|
|
6290
|
+
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
|
6291
|
+
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
|
6292
|
+
}
|
|
5957
6293
|
|
|
5958
6294
|
if (should_capture) {
|
|
5959
6295
|
[encoder popDebugGroup];
|
|
5960
6296
|
}
|
|
5961
6297
|
|
|
5962
|
-
if (
|
|
6298
|
+
if (res == 0) {
|
|
5963
6299
|
break;
|
|
5964
6300
|
}
|
|
6301
|
+
|
|
6302
|
+
idx += res;
|
|
5965
6303
|
}
|
|
5966
6304
|
|
|
5967
6305
|
[encoder endEncoding];
|