whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244):
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
48
48
  int mtl_device_ref_count;
49
49
  id<MTLLibrary> mtl_library;
50
50
 
51
+ NSLock * mtl_lock;
52
+
51
53
  bool has_simdgroup_reduction;
52
54
  bool has_simdgroup_mm;
53
55
  bool has_residency_sets;
54
56
  bool has_bfloat;
55
57
  bool use_bfloat;
56
58
 
59
+ size_t max_size;
60
+
57
61
  char name[128];
58
62
  } g_ggml_ctx_dev_main = {
59
63
  /*.mtl_device =*/ nil,
60
64
  /*.mtl_device_ref_count =*/ 0,
61
65
  /*.mtl_library =*/ nil,
66
+ /*.mtl_lock =*/ nil,
62
67
  /*.has_simdgroup_reduction =*/ false,
63
68
  /*.has_simdgroup_mm =*/ false,
64
69
  /*.has_residency_sets =*/ false,
65
70
  /*.has_bfloat =*/ false,
66
71
  /*.use_bfloat =*/ false,
72
+ /*.max_size =*/ 0,
67
73
  /*.name =*/ "",
68
74
  };
69
75
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
71
77
  static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
72
78
  assert(ctx != NULL);
73
79
 
80
+ if (ctx->mtl_lock == nil) {
81
+ ctx->mtl_lock = [[NSLock alloc] init];
82
+ }
83
+
74
84
  if (ctx->mtl_device == nil) {
75
85
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
86
  }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
94
104
  ctx->use_bfloat = false;
95
105
  #endif
96
106
 
107
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
108
+
97
109
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98
110
  }
99
111
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110
122
  ctx->mtl_device_ref_count--;
111
123
 
112
124
  if (ctx->mtl_device_ref_count == 0) {
125
+ if (ctx->mtl_lock) {
126
+ [ctx->mtl_lock release];
127
+ ctx->mtl_lock = nil;
128
+ }
129
+
113
130
  if (ctx->mtl_library) {
114
131
  [ctx->mtl_library release];
115
132
  ctx->mtl_library = nil;
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
185
202
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
186
203
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
187
204
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
188
214
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
215
  GGML_METAL_KERNEL_TYPE_L2_NORM,
190
216
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
194
220
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
196
222
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
223
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
197
224
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
225
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
198
226
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
199
227
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
200
228
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
201
229
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
230
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
202
231
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
203
232
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
204
233
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -497,7 +526,11 @@ enum ggml_metal_kernel_type {
497
526
  GGML_METAL_KERNEL_TYPE_SIN,
498
527
  GGML_METAL_KERNEL_TYPE_COS,
499
528
  GGML_METAL_KERNEL_TYPE_NEG,
529
+ GGML_METAL_KERNEL_TYPE_REGLU,
530
+ GGML_METAL_KERNEL_TYPE_GEGLU,
531
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
500
532
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
533
+ GGML_METAL_KERNEL_TYPE_MEAN,
501
534
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502
535
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
503
536
  GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -976,7 +1009,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
976
1009
  struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
977
1010
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
978
1011
 
979
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
1012
+ id<MTLDevice> device = ctx_dev->mtl_device;
980
1013
 
981
1014
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
982
1015
 
@@ -990,9 +1023,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
990
1023
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
991
1024
 
992
1025
  // load library
993
- if (ctx_dev->mtl_library == nil) {
994
- ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1026
+ {
1027
+ [ctx_dev->mtl_lock lock];
1028
+
1029
+ if (ctx_dev->mtl_library == nil) {
1030
+ ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1031
+ }
1032
+
1033
+ [ctx_dev->mtl_lock unlock];
995
1034
  }
1035
+
996
1036
  id<MTLLibrary> metal_library = ctx_dev->mtl_library;
997
1037
  if (metal_library == nil) {
998
1038
  GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1141,6 +1181,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1141
1181
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1142
1182
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1143
1183
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1184
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1185
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1186
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1190
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1191
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1192
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1144
1193
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1145
1194
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1146
1195
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1150,11 +1199,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1150
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1151
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1152
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1202
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
1153
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1204
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
1154
1205
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1155
1206
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1156
1207
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1157
1208
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1209
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
1158
1210
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1159
1211
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1160
1212
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1453,7 +1505,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1453
1505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1454
1506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1455
1507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1508
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1509
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1510
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1456
1511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1512
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1457
1513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1458
1514
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
1459
1515
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1603,6 +1659,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1603
1659
  const bool use_bfloat = ctx_dev->use_bfloat;
1604
1660
 
1605
1661
  if (!use_bfloat) {
1662
+ if (op->type == GGML_TYPE_BF16) {
1663
+ return false;
1664
+ }
1665
+
1606
1666
  for (size_t i = 0, n = 3; i < n; ++i) {
1607
1667
  if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1608
1668
  return false;
@@ -1626,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1626
1686
  default:
1627
1687
  return false;
1628
1688
  }
1689
+ case GGML_OP_GLU:
1690
+ switch (ggml_get_glu_op(op)) {
1691
+ case GGML_GLU_OP_REGLU:
1692
+ case GGML_GLU_OP_GEGLU:
1693
+ case GGML_GLU_OP_SWIGLU:
1694
+ return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1695
+ default:
1696
+ return false;
1697
+ }
1629
1698
  case GGML_OP_NONE:
1630
1699
  case GGML_OP_RESHAPE:
1631
1700
  case GGML_OP_VIEW:
@@ -1653,6 +1722,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1653
1722
  case GGML_OP_LOG:
1654
1723
  return false; // TODO: implement
1655
1724
  case GGML_OP_SUM_ROWS:
1725
+ case GGML_OP_MEAN:
1656
1726
  case GGML_OP_SOFT_MAX:
1657
1727
  case GGML_OP_GROUP_NORM:
1658
1728
  return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -1771,6 +1841,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1771
1841
  {
1772
1842
  return op->ne[3] == 1;
1773
1843
  }
1844
+ case GGML_OP_SET_ROWS:
1845
+ {
1846
+ if (op->src[0]->type != GGML_TYPE_F32) {
1847
+ return false;
1848
+ }
1849
+
1850
+ switch (op->type) {
1851
+ case GGML_TYPE_F32:
1852
+ case GGML_TYPE_F16:
1853
+ case GGML_TYPE_BF16:
1854
+ case GGML_TYPE_Q8_0:
1855
+ case GGML_TYPE_Q4_0:
1856
+ case GGML_TYPE_Q4_1:
1857
+ case GGML_TYPE_Q5_0:
1858
+ case GGML_TYPE_Q5_1:
1859
+ case GGML_TYPE_IQ4_NL:
1860
+ return true;
1861
+ default:
1862
+ return false;
1863
+ };
1864
+ }
1774
1865
  default:
1775
1866
  return false;
1776
1867
  }
@@ -2343,6 +2434,62 @@ static bool ggml_metal_encode_node(
2343
2434
  GGML_ABORT("fatal error");
2344
2435
  }
2345
2436
  } break;
2437
+ case GGML_OP_GLU:
2438
+ {
2439
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2440
+
2441
+ if (src1) {
2442
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
2443
+ }
2444
+
2445
+ id<MTLComputePipelineState> pipeline = nil;
2446
+
2447
+ switch (ggml_get_glu_op(node)) {
2448
+ case GGML_GLU_OP_REGLU:
2449
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2450
+ break;
2451
+ case GGML_GLU_OP_GEGLU:
2452
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2453
+ break;
2454
+ case GGML_GLU_OP_SWIGLU:
2455
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2456
+ break;
2457
+ default:
2458
+ GGML_ABORT("fatal error");
2459
+ }
2460
+
2461
+ const int32_t swp = ((const int32_t *) dst->op_params)[1];
2462
+
2463
+ const int32_t i00 = swp ? ne0 : 0;
2464
+ const int32_t i10 = swp ? 0 : ne0;
2465
+
2466
+ ggml_metal_kargs_glu args = {
2467
+ /*.ne00 =*/ ne00,
2468
+ /*.nb01 =*/ nb01,
2469
+ /*.ne10 =*/ src1 ? ne10 : ne00,
2470
+ /*.nb11 =*/ src1 ? nb11 : nb01,
2471
+ /*.ne0 =*/ ne0,
2472
+ /*.nb1 =*/ nb1,
2473
+ /*.i00 =*/ src1 ? 0 : i00,
2474
+ /*.i10 =*/ src1 ? 0 : i10,
2475
+ };
2476
+
2477
+ [encoder setComputePipelineState:pipeline];
2478
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2479
+ if (src1) {
2480
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2481
+ } else {
2482
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2483
+ }
2484
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2485
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2486
+
2487
+ const int64_t nrows = ggml_nrows(src0);
2488
+
2489
+ const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2490
+
2491
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2492
+ } break;
2346
2493
  case GGML_OP_SQR:
2347
2494
  {
2348
2495
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -2400,11 +2547,31 @@ static bool ggml_metal_encode_node(
2400
2547
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2401
2548
  } break;
2402
2549
  case GGML_OP_SUM_ROWS:
2550
+ case GGML_OP_MEAN:
2403
2551
  {
2404
2552
  GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2405
2553
 
2406
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2554
+ id<MTLComputePipelineState> pipeline = nil;
2555
+
2556
+ switch (dst->op) {
2557
+ case GGML_OP_SUM_ROWS:
2558
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2559
+ break;
2560
+ case GGML_OP_MEAN:
2561
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
2562
+ break;
2563
+ default:
2564
+ GGML_ABORT("fatal error");
2565
+ }
2407
2566
 
2567
+ int nth = 32; // SIMD width
2568
+
2569
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2570
+ nth *= 2;
2571
+ }
2572
+
2573
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2574
+ nth = MIN(nth, ne00);
2408
2575
 
2409
2576
  ggml_metal_kargs_sum_rows args = {
2410
2577
  /*.ne00 =*/ ne00,
@@ -2434,11 +2601,12 @@ static bool ggml_metal_encode_node(
2434
2601
  };
2435
2602
 
2436
2603
  [encoder setComputePipelineState:pipeline];
2437
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2438
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2439
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
2604
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2605
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2606
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2607
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2440
2608
 
2441
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2609
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2442
2610
  } break;
2443
2611
  case GGML_OP_SOFT_MAX:
2444
2612
  {
@@ -3063,14 +3231,23 @@ static bool ggml_metal_encode_node(
3063
3231
  nsg = 1;
3064
3232
  nr0 = 1;
3065
3233
  nr1 = 4;
3066
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3234
+ if (ne00 == 4) {
3235
+ nr0 = 32;
3236
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3237
+ } else {
3238
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3239
+ }
3067
3240
  } break;
3068
3241
  case GGML_TYPE_F16:
3069
3242
  {
3070
3243
  nsg = 1;
3071
3244
  nr0 = 1;
3072
3245
  if (src1t == GGML_TYPE_F32) {
3073
- if (ne11 * ne12 < 4) {
3246
+ if (ne00 == 4) {
3247
+ nr0 = 32;
3248
+ nr1 = 4;
3249
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3250
+ } else if (ne11 * ne12 < 4) {
3074
3251
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
3075
3252
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3076
3253
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3089,7 +3266,11 @@ static bool ggml_metal_encode_node(
3089
3266
  nsg = 1;
3090
3267
  nr0 = 1;
3091
3268
  if (src1t == GGML_TYPE_F32) {
3092
- if (ne11 * ne12 < 4) {
3269
+ if (ne00 == 4) {
3270
+ nr0 = 32;
3271
+ nr1 = 4;
3272
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3273
+ } else if (ne11 * ne12 < 4) {
3093
3274
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
3094
3275
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3095
3276
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
@@ -3710,13 +3891,74 @@ static bool ggml_metal_encode_node(
3710
3891
  };
3711
3892
 
3712
3893
  [encoder setComputePipelineState:pipeline];
3713
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3714
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3715
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3716
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3894
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3895
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3896
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3897
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3717
3898
 
3718
3899
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3719
3900
  } break;
3901
+ case GGML_OP_SET_ROWS:
3902
+ {
3903
+ id<MTLComputePipelineState> pipeline = nil;
3904
+
3905
+ switch (dst->type) {
3906
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3907
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3908
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3909
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3910
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3911
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3912
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3913
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3914
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3915
+ default: GGML_ABORT("not implemented");
3916
+ }
3917
+
3918
+ const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3919
+
3920
+ int nth = 32; // SIMD width
3921
+
3922
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3923
+ nth *= 2;
3924
+ }
3925
+
3926
+ int nrptg = 1;
3927
+ if (nth > nk0) {
3928
+ nrptg = (nth + nk0 - 1)/nk0;
3929
+ nth = nk0;
3930
+
3931
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3932
+ nrptg--;
3933
+ }
3934
+ }
3935
+
3936
+ nth = MIN(nth, nk0);
3937
+
3938
+ ggml_metal_kargs_set_rows args = {
3939
+ /*.nk0 =*/ nk0,
3940
+ /*.ne01 =*/ ne01,
3941
+ /*.nb01 =*/ nb01,
3942
+ /*.nb02 =*/ nb02,
3943
+ /*.nb03 =*/ nb03,
3944
+ /*.ne11 =*/ ne11,
3945
+ /*.ne12 =*/ ne12,
3946
+ /*.nb10 =*/ nb10,
3947
+ /*.nb11 =*/ nb11,
3948
+ /*.nb12 =*/ nb12,
3949
+ /*.nb1 =*/ nb1,
3950
+ /*.nb2 =*/ nb2,
3951
+ /*.nb3 =*/ nb3,
3952
+ };
3953
+
3954
+ [encoder setComputePipelineState:pipeline];
3955
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3956
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3957
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3958
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3959
+
3960
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3961
+ } break;
3720
3962
  case GGML_OP_RMS_NORM:
3721
3963
  {
3722
3964
  GGML_ASSERT(ne00 % 4 == 0);
@@ -3733,6 +3975,7 @@ static bool ggml_metal_encode_node(
3733
3975
  nth *= 2;
3734
3976
  }
3735
3977
 
3978
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3736
3979
  nth = MIN(nth, ne00/4);
3737
3980
 
3738
3981
  ggml_metal_kargs_rms_norm args = {
@@ -3769,6 +4012,7 @@ static bool ggml_metal_encode_node(
3769
4012
  nth *= 2;
3770
4013
  }
3771
4014
 
4015
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3772
4016
  nth = MIN(nth, ne00/4);
3773
4017
 
3774
4018
  ggml_metal_kargs_l2_norm args = {
@@ -3841,6 +4085,7 @@ static bool ggml_metal_encode_node(
3841
4085
  nth *= 2;
3842
4086
  }
3843
4087
 
4088
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3844
4089
  nth = MIN(nth, ne00/4);
3845
4090
 
3846
4091
  ggml_metal_kargs_norm args = {
@@ -4766,6 +5011,8 @@ static bool ggml_metal_encode_node(
4766
5011
  GGML_ASSERT(nqptg % 8 == 0);
4767
5012
  GGML_ASSERT(ncpsg % 32 == 0);
4768
5013
 
5014
+ const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
5015
+
4769
5016
  // 2*(2*ncpsg + nqptg)*(nsg)
4770
5017
  // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
4771
5018
  //
@@ -4773,7 +5020,7 @@ static bool ggml_metal_encode_node(
4773
5020
  // the shared memory needed for the simdgroups to load the KV cache
4774
5021
  // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
4775
5022
  //
4776
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
5023
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
4777
5024
 
4778
5025
  int64_t nsgmax = 2;
4779
5026
 
@@ -4810,9 +5057,9 @@ static bool ggml_metal_encode_node(
4810
5057
  // and store the soft_max values and the mask
4811
5058
  //
4812
5059
  // ne00*(nsg)
4813
- // each simdgroup has a full f16 head vector in shared mem to accumulate results
5060
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
4814
5061
  //
4815
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
5062
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
4816
5063
 
4817
5064
  int64_t nsgmax = 2;
4818
5065
  while (true) {
@@ -4925,8 +5172,39 @@ static bool ggml_metal_encode_node(
4925
5172
  default: GGML_ABORT("not implemented");
4926
5173
  }
4927
5174
 
5175
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5176
+
5177
+ // TODO: support
5178
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
5179
+ const int32_t nk00 = ne00;
5180
+
5181
+ int nth = 32; // SIMD width
5182
+
5183
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
5184
+ nth *= 2;
5185
+ }
5186
+
5187
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
5188
+
5189
+ // when rows are small, we can batch them together in a single threadgroup
5190
+ int nrptg = 1;
5191
+
5192
+ // TODO: relax this constraint in the future
5193
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
5194
+ if (nth > nk00) {
5195
+ nrptg = (nth + nk00 - 1)/nk00;
5196
+ nth = nk00;
5197
+
5198
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5199
+ nrptg--;
5200
+ }
5201
+ }
5202
+ }
5203
+
5204
+ nth = MIN(nth, nk00);
5205
+
4928
5206
  ggml_metal_kargs_cpy args = {
4929
- /*.ne00 =*/ ne00,
5207
+ /*.ne00 =*/ nk00,
4930
5208
  /*.ne01 =*/ ne01,
4931
5209
  /*.ne02 =*/ ne02,
4932
5210
  /*.ne03 =*/ ne03,
@@ -4949,11 +5227,7 @@ static bool ggml_metal_encode_node(
4949
5227
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4950
5228
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4951
5229
 
4952
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4953
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4954
-
4955
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4956
-
5230
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4957
5231
  } break;
4958
5232
  case GGML_OP_SET:
4959
5233
  {
@@ -5259,7 +5533,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
5259
5533
  }
5260
5534
 
5261
5535
  ggml_backend_metal_buffer_rset_free(ctx);
5262
- ggml_backend_metal_device_rel(buffer->buft->device->context);
5263
5536
 
5264
5537
  if (ctx->owned) {
5265
5538
  #if TARGET_OS_OSX
@@ -5368,7 +5641,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5368
5641
  }
5369
5642
 
5370
5643
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5371
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5644
+
5645
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5646
+
5647
+ id<MTLDevice> device = ctx_dev->mtl_device;
5372
5648
 
5373
5649
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
5374
5650
  ctx->all_size = size_aligned;
@@ -5391,14 +5667,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5391
5667
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
5392
5668
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
5393
5669
  free(ctx);
5394
- ggml_backend_metal_device_rel(ctx_dev);
5395
5670
  return NULL;
5396
5671
  }
5397
5672
 
5398
5673
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5399
5674
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5400
5675
  free(ctx);
5401
- ggml_backend_metal_device_rel(ctx_dev);
5402
5676
  return NULL;
5403
5677
  }
5404
5678
 
@@ -5409,17 +5683,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5409
5683
 
5410
5684
  static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
5411
5685
  return 32;
5686
+
5412
5687
  GGML_UNUSED(buft);
5413
5688
  }
5414
5689
 
5415
5690
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
5416
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5417
- const size_t max_size = device.maxBufferLength;
5418
- ggml_backend_metal_device_rel(buft->device->context);
5691
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
5419
5692
 
5420
5693
  return max_size;
5421
-
5422
- GGML_UNUSED(buft);
5423
5694
  }
5424
5695
 
5425
5696
  static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5492,7 +5763,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5492
5763
  }
5493
5764
 
5494
5765
  struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5495
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5766
+
5767
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5768
+
5769
+ id<MTLDevice> device = ctx_dev->mtl_device;
5496
5770
 
5497
5771
  // the buffer fits into the max buffer size allowed by the device
5498
5772
  if (size_aligned <= device.maxBufferLength) {
@@ -5548,7 +5822,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5548
5822
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5549
5823
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5550
5824
  free(ctx);
5551
- ggml_backend_metal_device_rel(ctx_dev);
5552
5825
  return NULL;
5553
5826
  }
5554
5827
 
@@ -5564,10 +5837,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
5564
5837
  }
5565
5838
 
5566
5839
  static void ggml_backend_metal_free(ggml_backend_t backend) {
5567
- struct ggml_backend_metal_context * ctx = backend->context;
5568
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5840
+ struct ggml_backend_metal_context * ctx = backend->context;
5569
5841
 
5570
- ggml_backend_metal_device_rel(ctx_dev);
5571
5842
  ggml_metal_free(ctx);
5572
5843
 
5573
5844
  free(backend);
@@ -5707,6 +5978,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
5707
5978
 
5708
5979
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5709
5980
 
5981
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5982
+
5710
5983
  return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
5711
5984
  }
5712
5985
 
@@ -5726,10 +5999,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
5726
5999
  }
5727
6000
 
5728
6001
  static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
5729
- // acq/rel just to populate ctx->name in case it hasn't been done yet
5730
6002
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5731
- ggml_backend_metal_device_acq(ctx_dev);
5732
- ggml_backend_metal_device_rel(ctx_dev);
5733
6003
 
5734
6004
  return ctx_dev->name;
5735
6005
  }
@@ -5737,12 +6007,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
5737
6007
  static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
5738
6008
  if (@available(macOS 10.12, iOS 16.0, *)) {
5739
6009
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5740
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6010
+ id<MTLDevice> device = ctx_dev->mtl_device;
5741
6011
 
5742
6012
  *total = device.recommendedMaxWorkingSetSize;
5743
6013
  *free = *total - device.currentAllocatedSize;
5744
-
5745
- ggml_backend_metal_device_rel(ctx_dev);
5746
6014
  } else {
5747
6015
  *free = 1;
5748
6016
  *total = 1;
@@ -5820,7 +6088,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5820
6088
  }
5821
6089
 
5822
6090
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5823
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6091
+
6092
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
6093
+
6094
+ id<MTLDevice> device = ctx_dev->mtl_device;
5824
6095
 
5825
6096
  // the buffer fits into the max buffer size allowed by the device
5826
6097
  if (size_aligned <= device.maxBufferLength) {
@@ -5876,7 +6147,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5876
6147
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5877
6148
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5878
6149
  free(ctx);
5879
- ggml_backend_metal_device_rel(ctx_dev);
5880
6150
  return NULL;
5881
6151
  }
5882
6152
 
@@ -5890,8 +6160,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
5890
6160
  }
5891
6161
 
5892
6162
  static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
5893
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5894
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
6163
+ return
6164
+ buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
6165
+ buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
5895
6166
 
5896
6167
  GGML_UNUSED(dev);
5897
6168
  }
@@ -5976,8 +6247,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
5976
6247
  /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
5977
6248
  };
5978
6249
 
6250
+ // called upon program exit
6251
+ static void ggml_metal_cleanup(void) {
6252
+ ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
6253
+ }
6254
+
6255
+ // TODO: make thread-safe
5979
6256
  ggml_backend_reg_t ggml_backend_metal_reg(void) {
5980
- // TODO: make this thread-safe somehow?
6257
+ ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
6258
+
6259
+ // register cleanup callback
6260
+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6261
+ atexit(ggml_metal_cleanup);
6262
+
5981
6263
  {
5982
6264
  g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
5983
6265
  /* .api_version = */ GGML_BACKEND_API_VERSION,