@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
55
55
  bool has_residency_sets;
56
56
  bool has_bfloat;
57
57
  bool use_bfloat;
58
+ bool use_fusion;
59
+
60
+ int debug_fusion;
61
+
62
+ // how many times a given op was fused
63
+ uint64_t fuse_cnt[GGML_OP_COUNT];
58
64
 
59
65
  size_t max_size;
60
66
 
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
69
75
  /*.has_residency_sets =*/ false,
70
76
  /*.has_bfloat =*/ false,
71
77
  /*.use_bfloat =*/ false,
78
+ /*.use_fusion =*/ true,
79
+ /*.debug_fusion =*/ 0,
80
+ /*.fuse_cnt =*/ { 0 },
72
81
  /*.max_size =*/ 0,
73
82
  /*.name =*/ "",
74
83
  };
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
83
92
 
84
93
  if (ctx->mtl_device == nil) {
85
94
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
86
- }
87
95
 
88
- if (ctx->mtl_device) {
89
96
  ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
90
97
  ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
91
98
 
92
99
  ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
93
100
 
94
101
  #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
95
- ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
102
+ ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
96
103
  #endif
97
104
 
98
105
  ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
103
110
  #else
104
111
  ctx->use_bfloat = false;
105
112
  #endif
113
+ ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
114
+
115
+ {
116
+ const char * val = getenv("GGML_METAL_FUSION_DEBUG");
117
+ ctx->debug_fusion = val ? atoi(val) : 0;
118
+ }
119
+
120
+ memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
106
121
 
107
122
  ctx->max_size = ctx->mtl_device.maxBufferLength;
108
123
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
122
137
  ctx->mtl_device_ref_count--;
123
138
 
124
139
  if (ctx->mtl_device_ref_count == 0) {
140
+ if (ctx->debug_fusion > 0) {
141
+ fprintf(stderr, "%s: fusion stats:\n", __func__);
142
+ for (int i = 0; i < GGML_OP_COUNT; i++) {
143
+ if (ctx->fuse_cnt[i] == 0) {
144
+ continue;
145
+ }
146
+
147
+ // note: cannot use ggml_log here
148
+ fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
149
+ }
150
+ }
151
+
125
152
  if (ctx->mtl_lock) {
126
153
  [ctx->mtl_lock release];
127
154
  ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
147
174
 
148
175
  enum ggml_metal_kernel_type {
149
176
  GGML_METAL_KERNEL_TYPE_ADD,
150
- GGML_METAL_KERNEL_TYPE_ADD_ROW,
177
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
178
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
179
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
180
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
181
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
182
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
183
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
184
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
185
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
186
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
187
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
188
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
189
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
190
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
191
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
151
192
  GGML_METAL_KERNEL_TYPE_SUB,
152
- GGML_METAL_KERNEL_TYPE_SUB_ROW,
193
+ GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
153
194
  GGML_METAL_KERNEL_TYPE_MUL,
154
- GGML_METAL_KERNEL_TYPE_MUL_ROW,
195
+ GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
155
196
  GGML_METAL_KERNEL_TYPE_DIV,
156
- GGML_METAL_KERNEL_TYPE_DIV_ROW,
197
+ GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
157
198
  GGML_METAL_KERNEL_TYPE_REPEAT_F32,
158
199
  GGML_METAL_KERNEL_TYPE_REPEAT_F16,
159
200
  GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -173,6 +214,12 @@ enum ggml_metal_kernel_type {
173
214
  GGML_METAL_KERNEL_TYPE_SILU,
174
215
  GGML_METAL_KERNEL_TYPE_SILU_4,
175
216
  GGML_METAL_KERNEL_TYPE_ELU,
217
+ GGML_METAL_KERNEL_TYPE_ABS,
218
+ GGML_METAL_KERNEL_TYPE_SGN,
219
+ GGML_METAL_KERNEL_TYPE_STEP,
220
+ GGML_METAL_KERNEL_TYPE_HARDSWISH,
221
+ GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
222
+ GGML_METAL_KERNEL_TYPE_EXP,
176
223
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
177
224
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
178
225
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -212,6 +259,8 @@ enum ggml_metal_kernel_type {
212
259
  GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
260
  GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
214
261
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
262
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
263
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
215
264
  GGML_METAL_KERNEL_TYPE_L2_NORM,
216
265
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
217
266
  GGML_METAL_KERNEL_TYPE_NORM,
@@ -1129,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1129
1178
  // simd_sum and simd_max requires MTLGPUFamilyApple7
1130
1179
 
1131
1180
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
1132
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
1181
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
1182
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
1183
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
1184
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
1185
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
1186
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
1190
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
1191
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
1192
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
1193
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
1194
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
1195
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
1133
1196
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
1134
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
1197
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
1135
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
1136
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
1199
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
1137
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1138
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
1201
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
1139
1202
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
1140
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
1141
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1155,6 +1218,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1155
1218
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
1156
1219
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1157
1220
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
1221
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
1222
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
1223
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
1224
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
1225
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
1226
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
1158
1227
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
1159
1228
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
1160
1229
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1194,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1194
1263
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1195
1264
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1196
1265
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1266
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1267
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1197
1268
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1198
1269
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1199
1270
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1688,6 +1759,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1688
1759
  case GGML_UNARY_OP_SILU:
1689
1760
  case GGML_UNARY_OP_ELU:
1690
1761
  case GGML_UNARY_OP_NEG:
1762
+ case GGML_UNARY_OP_ABS:
1763
+ case GGML_UNARY_OP_SGN:
1764
+ case GGML_UNARY_OP_STEP:
1765
+ case GGML_UNARY_OP_HARDSWISH:
1766
+ case GGML_UNARY_OP_HARDSIGMOID:
1767
+ case GGML_UNARY_OP_EXP:
1691
1768
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1692
1769
  default:
1693
1770
  return false;
@@ -1875,9 +1952,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1875
1952
  }
1876
1953
  }
1877
1954
 
1878
- static bool ggml_metal_encode_node(
1955
+ static int ggml_metal_encode_node(
1879
1956
  ggml_backend_t backend,
1880
1957
  int idx,
1958
+ int idx_end,
1881
1959
  id<MTLComputeCommandEncoder> encoder,
1882
1960
  struct ggml_metal_mem_pool * mem_pool) {
1883
1961
  struct ggml_backend_metal_context * ctx = backend->context;
@@ -1885,7 +1963,10 @@ static bool ggml_metal_encode_node(
1885
1963
 
1886
1964
  struct ggml_cgraph * gf = ctx->gf;
1887
1965
 
1888
- struct ggml_tensor * node = ggml_graph_node(gf, idx);
1966
+ enum ggml_op ops[8];
1967
+
1968
+ struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
1969
+ struct ggml_tensor * node = nodes[0];
1889
1970
 
1890
1971
  //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
1891
1972
 
@@ -1895,7 +1976,7 @@ static bool ggml_metal_encode_node(
1895
1976
  struct ggml_tensor * dst = node;
1896
1977
 
1897
1978
  if (ggml_is_empty(dst)) {
1898
- return true;
1979
+ return 1;
1899
1980
  }
1900
1981
 
1901
1982
  switch (dst->op) {
@@ -1906,7 +1987,7 @@ static bool ggml_metal_encode_node(
1906
1987
  case GGML_OP_PERMUTE:
1907
1988
  {
1908
1989
  // noop -> next node
1909
- } return true;
1990
+ } return 1;
1910
1991
  default:
1911
1992
  {
1912
1993
  } break;
@@ -1973,6 +2054,8 @@ static bool ggml_metal_encode_node(
1973
2054
  id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
1974
2055
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
1975
2056
 
2057
+ int n_fuse = 1;
2058
+
1976
2059
  #if 0
1977
2060
  GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1978
2061
  if (src0) {
@@ -2044,37 +2127,15 @@ static bool ggml_metal_encode_node(
2044
2127
  GGML_ASSERT(src0t == GGML_TYPE_F32);
2045
2128
  GGML_ASSERT(src1t == GGML_TYPE_F32);
2046
2129
 
2130
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
2131
+ GGML_ASSERT(ggml_is_contiguous_rows(src1));
2132
+
2047
2133
  const size_t offs = 0;
2048
2134
 
2049
2135
  bool bcast_row = false;
2050
2136
 
2051
2137
  id<MTLComputePipelineState> pipeline = nil;
2052
2138
 
2053
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2054
- GGML_ASSERT(ggml_is_contiguous(src0));
2055
-
2056
- // src1 is a row
2057
- GGML_ASSERT(ne11 == 1);
2058
-
2059
- switch (dst->op) {
2060
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
2061
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
2062
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
2063
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
2064
- default: GGML_ABORT("fatal error");
2065
- }
2066
-
2067
- bcast_row = true;
2068
- } else {
2069
- switch (dst->op) {
2070
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
2071
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
2072
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
2073
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
2074
- default: GGML_ABORT("fatal error");
2075
- }
2076
- }
2077
-
2078
2139
  ggml_metal_kargs_bin args = {
2079
2140
  /*.ne00 =*/ ne00,
2080
2141
  /*.ne01 =*/ ne01,
@@ -2101,12 +2162,119 @@ static bool ggml_metal_encode_node(
2101
2162
  /*.nb2 =*/ nb2,
2102
2163
  /*.nb3 =*/ nb3,
2103
2164
  /*.offs =*/ offs,
2165
+ /*.o1 =*/ { offs_src1 },
2104
2166
  };
2105
2167
 
2168
+ // c[0] = add(a, b[0])
2169
+ // c[1] = add(c[0], b[1])
2170
+ // c[2] = add(c[1], b[2])
2171
+ // ...
2172
+ if (ctx_dev->use_fusion) {
2173
+ ops[0] = GGML_OP_ADD;
2174
+ ops[1] = GGML_OP_ADD;
2175
+ ops[2] = GGML_OP_ADD;
2176
+ ops[3] = GGML_OP_ADD;
2177
+ ops[4] = GGML_OP_ADD;
2178
+ ops[5] = GGML_OP_ADD;
2179
+ ops[6] = GGML_OP_ADD;
2180
+ ops[7] = GGML_OP_ADD;
2181
+
2182
+ size_t offs_fuse;
2183
+ id<MTLBuffer> id_fuse;
2184
+
2185
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2186
+ // across splits. idx_end indicates the last node in the current split
2187
+ for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
2188
+ if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
2189
+ break;
2190
+ }
2191
+
2192
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
2193
+ break;
2194
+ }
2195
+
2196
+ // b[0] === b[1] === ...
2197
+ if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
2198
+ break;
2199
+ }
2200
+
2201
+ // only fuse nodes if src1 is in the same Metal buffer
2202
+ id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
2203
+ if (id_fuse != id_src1) {
2204
+ break;
2205
+ }
2206
+
2207
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
2208
+
2209
+ args.o1[n_fuse + 1] = offs_fuse;
2210
+ }
2211
+
2212
+ ++n_fuse;
2213
+
2214
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
2215
+ GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2216
+ }
2217
+ }
2218
+
2219
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2220
+ GGML_ASSERT(ggml_is_contiguous(src0));
2221
+
2222
+ // src1 is a row
2223
+ GGML_ASSERT(ne11 == 1);
2224
+
2225
+ switch (dst->op) {
2226
+ case GGML_OP_ADD:
2227
+ {
2228
+ switch (n_fuse) {
2229
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
2230
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
2231
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
2232
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
2233
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
2234
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
2235
+ case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
2236
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
2237
+ default: GGML_ABORT("fatal error");
2238
+ }
2239
+ } break;
2240
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
2241
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
2242
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
2243
+ default: GGML_ABORT("fatal error");
2244
+ }
2245
+
2246
+ bcast_row = true;
2247
+ } else {
2248
+ switch (dst->op) {
2249
+ case GGML_OP_ADD:
2250
+ {
2251
+ switch (n_fuse) {
2252
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
2253
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
2254
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
2255
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
2256
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
2257
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
2258
+ case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
2259
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
2260
+ default: GGML_ABORT("fatal error");
2261
+ }
2262
+ } break;
2263
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
2264
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
2265
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
2266
+ default: GGML_ABORT("fatal error");
2267
+ }
2268
+ }
2269
+
2270
+ if (n_fuse > 1) {
2271
+ id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
2272
+ }
2273
+
2106
2274
  [encoder setComputePipelineState:pipeline];
2107
2275
  [encoder setBytes:&args length:sizeof(args) atIndex:0];
2108
2276
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2109
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2277
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
2110
2278
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2111
2279
 
2112
2280
  if (bcast_row) {
@@ -2114,7 +2282,11 @@ static bool ggml_metal_encode_node(
2114
2282
 
2115
2283
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2116
2284
  } else {
2117
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2285
+ int nth = 32;
2286
+
2287
+ while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2288
+ nth *= 2;
2289
+ }
2118
2290
 
2119
2291
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2120
2292
  }
@@ -2239,12 +2411,13 @@ static bool ggml_metal_encode_node(
2239
2411
  /*.nb2 =*/ pnb2,
2240
2412
  /*.nb3 =*/ pnb3,
2241
2413
  /*.offs =*/ offs,
2414
+ /*.o1 =*/ { offs_src1},
2242
2415
  };
2243
2416
 
2244
2417
  [encoder setComputePipelineState:pipeline];
2245
2418
  [encoder setBytes:&args length:sizeof(args) atIndex:0];
2246
2419
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2247
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2420
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
2248
2421
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2249
2422
 
2250
2423
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2439,6 +2612,78 @@ static bool ggml_metal_encode_node(
2439
2612
 
2440
2613
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2441
2614
  } break;
2615
+ case GGML_UNARY_OP_ABS:
2616
+ {
2617
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
2618
+
2619
+ [encoder setComputePipelineState:pipeline];
2620
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2621
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2622
+
2623
+ const int64_t n = ggml_nelements(dst);
2624
+
2625
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2626
+ } break;
2627
+ case GGML_UNARY_OP_SGN:
2628
+ {
2629
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
2630
+
2631
+ [encoder setComputePipelineState:pipeline];
2632
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2633
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2634
+
2635
+ const int64_t n = ggml_nelements(dst);
2636
+
2637
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2638
+ } break;
2639
+ case GGML_UNARY_OP_STEP:
2640
+ {
2641
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
2642
+
2643
+ [encoder setComputePipelineState:pipeline];
2644
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2645
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2646
+
2647
+ const int64_t n = ggml_nelements(dst);
2648
+
2649
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2650
+ } break;
2651
+ case GGML_UNARY_OP_HARDSWISH:
2652
+ {
2653
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
2654
+
2655
+ [encoder setComputePipelineState:pipeline];
2656
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2657
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2658
+
2659
+ const int64_t n = ggml_nelements(dst);
2660
+
2661
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2662
+ } break;
2663
+ case GGML_UNARY_OP_HARDSIGMOID:
2664
+ {
2665
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
2666
+
2667
+ [encoder setComputePipelineState:pipeline];
2668
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2669
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2670
+
2671
+ const int64_t n = ggml_nelements(dst);
2672
+
2673
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2674
+ } break;
2675
+ case GGML_UNARY_OP_EXP:
2676
+ {
2677
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
2678
+
2679
+ [encoder setComputePipelineState:pipeline];
2680
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2681
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2682
+
2683
+ const int64_t n = ggml_nelements(dst);
2684
+
2685
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2686
+ } break;
2442
2687
  default:
2443
2688
  {
2444
2689
  GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@@ -2674,7 +2919,7 @@ static bool ggml_metal_encode_node(
2674
2919
  id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2675
2920
  if (!h_src0) {
2676
2921
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677
- return false;
2922
+ return 0;
2678
2923
  }
2679
2924
 
2680
2925
  offs_src0 = 0;
@@ -2896,6 +3141,7 @@ static bool ggml_metal_encode_node(
2896
3141
  /*.n_group =*/ n_group,
2897
3142
  /*.n_seq_tokens =*/ n_seq_tokens,
2898
3143
  /*.n_seqs =*/ n_seqs,
3144
+ /*.s_off =*/ ggml_nelements(src1) * sizeof(float),
2899
3145
  /*.nb01 =*/ nb01,
2900
3146
  /*.nb02 =*/ nb02,
2901
3147
  /*.nb03 =*/ nb03,
@@ -2924,12 +3170,22 @@ static bool ggml_metal_encode_node(
2924
3170
  [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
2925
3171
  [encoder setBytes:&args length:sizeof(args) atIndex:8];
2926
3172
 
3173
+ // One shared memory bucket for each simd group in the threadgroup
3174
+ // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
3175
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3176
+ if (d_state >= 32) {
3177
+ GGML_ASSERT((int64_t)(d_state / 32) <= 32);
3178
+ const int64_t shmem_size = 32;
3179
+ GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
3180
+ [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
3181
+ }
3182
+
2927
3183
  if (ne30 == 1) {
2928
3184
  // Mamba-2
2929
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3185
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
2930
3186
  } else {
2931
3187
  GGML_ASSERT(d_inner == 1);
2932
- [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3188
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
2933
3189
  }
2934
3190
  } break;
2935
3191
  case GGML_OP_RWKV_WKV6:
@@ -3550,7 +3806,7 @@ static bool ggml_metal_encode_node(
3550
3806
  id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3551
3807
  if (!h_src1) {
3552
3808
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3553
- return false;
3809
+ return 0;
3554
3810
  }
3555
3811
 
3556
3812
  const int64_t neh0 = ne0;
@@ -3566,7 +3822,7 @@ static bool ggml_metal_encode_node(
3566
3822
  id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3567
3823
  if (!h_dst) {
3568
3824
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3569
- return false;
3825
+ return 0;
3570
3826
  }
3571
3827
 
3572
3828
  // tokens per expert
@@ -3574,7 +3830,7 @@ static bool ggml_metal_encode_node(
3574
3830
  id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
3575
3831
  if (!h_tpe) {
3576
3832
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3577
- return false;
3833
+ return 0;
3578
3834
  }
3579
3835
 
3580
3836
  // id map
@@ -3583,7 +3839,7 @@ static bool ggml_metal_encode_node(
3583
3839
  id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
3584
3840
  if (!h_ids) {
3585
3841
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3586
- return false;
3842
+ return 0;
3587
3843
  }
3588
3844
 
3589
3845
  {
@@ -4015,12 +4271,95 @@ static bool ggml_metal_encode_node(
4015
4271
  case GGML_OP_RMS_NORM:
4016
4272
  {
4017
4273
  GGML_ASSERT(ne00 % 4 == 0);
4018
- GGML_ASSERT(ggml_is_contiguous_1(src0));
4274
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
4019
4275
 
4020
4276
  float eps;
4021
4277
  memcpy(&eps, dst->op_params, sizeof(float));
4022
4278
 
4023
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
4279
+ ggml_metal_kargs_rms_norm args = {
4280
+ /*.ne00 =*/ ne00,
4281
+ /*.ne00_4 =*/ ne00/4,
4282
+ /*.nb1 =*/ nb1,
4283
+ /*.nb2 =*/ nb2,
4284
+ /*.nb3 =*/ nb3,
4285
+ /*.eps =*/ eps,
4286
+ /*.nef1 =*/ { ne01 },
4287
+ /*.nef2 =*/ { ne02 },
4288
+ /*.nef3 =*/ { ne03 },
4289
+ /*.nbf1 =*/ { nb01 },
4290
+ /*.nbf2 =*/ { nb02 },
4291
+ /*.nbf3 =*/ { nb03 },
4292
+ };
4293
+
4294
+ size_t offs_fuse[2] = { 0, 0 };
4295
+ id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
4296
+
4297
+ // d[0] = rms_norm(a)
4298
+ // d[1] = mul(d[0], b)
4299
+ // d[2] = add(d[1], c)
4300
+ if (ctx_dev->use_fusion) {
4301
+ ops[0] = GGML_OP_RMS_NORM;
4302
+ ops[1] = GGML_OP_MUL;
4303
+ ops[2] = GGML_OP_ADD;
4304
+
4305
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
4306
+ if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4307
+ break;
4308
+ }
4309
+
4310
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
4311
+ break;
4312
+ }
4313
+
4314
+ if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
4315
+ break;
4316
+ }
4317
+
4318
+ if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
4319
+ break;
4320
+ }
4321
+
4322
+ if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
4323
+ break;
4324
+ }
4325
+
4326
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
4327
+
4328
+ id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
4329
+
4330
+ args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
4331
+ args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
4332
+ args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
4333
+
4334
+ args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
4335
+ args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
4336
+ args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
4337
+ }
4338
+
4339
+ ++n_fuse;
4340
+
4341
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
4342
+ if (n_fuse == 2) {
4343
+ GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
4344
+ }
4345
+ if (n_fuse == 3) {
4346
+ GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
4347
+ }
4348
+ }
4349
+ }
4350
+
4351
+ if (n_fuse > 1) {
4352
+ id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
4353
+ }
4354
+
4355
+ id<MTLComputePipelineState> pipeline;
4356
+
4357
+ switch (n_fuse) {
4358
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
4359
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
4360
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
4361
+ default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
4362
+ }
4024
4363
 
4025
4364
  int nth = 32; // SIMD width
4026
4365
 
@@ -4031,23 +4370,16 @@ static bool ggml_metal_encode_node(
4031
4370
  nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
4032
4371
  nth = MIN(nth, ne00/4);
4033
4372
 
4034
- ggml_metal_kargs_rms_norm args = {
4035
- /*.ne00 =*/ ne00,
4036
- /*.ne00_4 =*/ ne00/4,
4037
- /*.nb01 =*/ nb01,
4038
- /*.eps =*/ eps,
4039
- };
4040
-
4041
4373
  [encoder setComputePipelineState:pipeline];
4042
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
4043
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4044
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4374
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
4375
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4376
+ [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
4377
+ [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
4378
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
4045
4379
 
4046
4380
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
4047
4381
 
4048
- const int64_t nrows = ggml_nrows(src0);
4049
-
4050
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4382
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4051
4383
  } break;
4052
4384
  case GGML_OP_L2_NORM:
4053
4385
  {
@@ -5442,7 +5774,7 @@ static bool ggml_metal_encode_node(
5442
5774
  }
5443
5775
  }
5444
5776
 
5445
- return true;
5777
+ return n_fuse;
5446
5778
  }
5447
5779
 
5448
5780
  static enum ggml_status ggml_metal_graph_compute(
@@ -5948,20 +6280,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5948
6280
  struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
5949
6281
  ggml_metal_mem_pool_reset(mem_pool);
5950
6282
 
5951
- for (int idx = node_start; idx < node_end; ++idx) {
6283
+ for (int idx = node_start; idx < node_end;) {
5952
6284
  if (should_capture) {
5953
6285
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5954
6286
  }
5955
6287
 
5956
- const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
6288
+ const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
6289
+ if (idx + res > node_end) {
6290
+ GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
6291
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
6292
+ }
5957
6293
 
5958
6294
  if (should_capture) {
5959
6295
  [encoder popDebugGroup];
5960
6296
  }
5961
6297
 
5962
- if (!res) {
6298
+ if (res == 0) {
5963
6299
  break;
5964
6300
  }
6301
+
6302
+ idx += res;
5965
6303
  }
5966
6304
 
5967
6305
  [encoder endEncoding];