whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-cuda/unary.cu
@@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_p = (half *) src0_d;
+        half * src1_p = (half *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
+    } else {
+        float * src0_p = (float *) src0_d;
+        float * src1_p = (float *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
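For orientation (illustrative, not part of the diff): the gated kernel above computes dst[i] = op(x[j0]) * g[j1], where the gate either comes from a separate src1 tensor or from the second half of each row of src0, with the swapped op_param selecting which half holds the values and which the gate. A minimal CPU reference sketch of the single-tensor SwiGLU case, with hypothetical names:

// Reference for one row of length 2*nc when no src1 is given (assumption: this
// mirrors the kernel's indexing; names here are illustrative only).
#include <cmath>
#include <cstdint>

static float op_silu_ref(float x) { return x / (1.0f + std::exp(-x)); }

// x: one row of 2*nc floats holding value half and gate half; out: nc floats
void swiglu_row_ref(const float * x, float * out, int64_t nc, bool swapped) {
    const float * val  = x + (swapped ? nc : 0);  // value half of the row
    const float * gate = x + (swapped ? 0 : nc);  // gate half of the row
    for (int64_t i = 0; i < nc; ++i) {
        out[i] = op_silu_ref(val[i]) * gate[i];   // base op times gate
    }
}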
data/ext/sources/ggml/src/ggml-cuda/unary.cuh
@@ -15,6 +15,7 @@
 #define CUDA_SQRT_BLOCK_SIZE 256
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
+#define CUDA_GLU_BLOCK_SIZE 256
 
 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
     add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
 endif()
 
+if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()
data/ext/sources/ggml/src/ggml-impl.h
@@ -32,6 +32,8 @@
 extern "C" {
 #endif
 
+void ggml_print_backtrace(void);
+
 #ifndef MIN
 #    define MIN(a, b) ((a) < (b) ? (a) : (b))
 #endif
@@ -299,6 +301,7 @@ struct ggml_cgraph {
     struct ggml_tensor ** grads;      // the outputs of these tensors are the gradients of the nodes
     struct ggml_tensor ** grad_accs;  // accumulators for node gradients
     struct ggml_tensor ** leafs;      // tensors with constant data
+    int32_t             * use_counts;// number of uses of each tensor, indexed by hash table slot
 
     struct ggml_hash_set visited_hash_set;
 
@@ -315,203 +318,81 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
 GGML_API void * ggml_aligned_malloc(size_t size);
 GGML_API void ggml_aligned_free(void * ptr, size_t size);
 
-// FP16 to FP32 conversion
-
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
-//
-// for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616
-// for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
-//
-#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-    #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-
-    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        __fp16 tmp;
-        memcpy(&tmp, &h, sizeof(ggml_fp16_t));
-        return (float)tmp;
-    }
-
-    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-        ggml_fp16_t res;
-        __fp16 tmp = f;
-        memcpy(&res, &tmp, sizeof(ggml_fp16_t));
-        return res;
-    }
-
-#elif defined(__F16C__)
-
-    #ifdef _MSC_VER
-        #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-        #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-    #else
-        #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-        #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-    #endif
-
-#elif defined(__POWER9_VECTOR__)
-
-    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-    /* the inline asm below is about 12% faster than the lookup method */
-    #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-    #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        float f;
-        double d;
-        __asm__(
-            "mtfprd %0,%2\n"
-            "xscvhpdp %0,%0\n"
-            "frsp %1,%0\n" :
-            /* temp */ "=d"(d),
-            /* out */ "=f"(f):
-            /* in */ "r"(h));
-        return f;
-    }
-
-    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-        double d;
-        ggml_fp16_t r;
-        __asm__( /* xscvdphp can work on double or single precision */
-            "xscvdphp %0,%2\n"
-            "mffprd %1,%0\n" :
-            /* temp */ "=d"(d),
-            /* out */ "=r"(r):
-            /* in */ "f"(f));
-        return r;
-    }
-
-#elif defined(__riscv) && defined(__riscv_zfhmin)
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
 
-    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        float f;
-        __asm__(
-            "fmv.h.x %[f], %[h]\n\t"
-            "fcvt.s.h %[f], %[f]"
-            : [f] "=&f" (f)
-            : [h] "r" (h)
-        );
-        return f;
-    }
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
 
-    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-        ggml_fp16_t res;
-        __asm__(
-            "fcvt.h.s %[f], %[f]\n\t"
-            "fmv.x.h %[h], %[f]"
-            : [h] "=&r" (res)
-            : [f] "f" (f)
-        );
-        return res;
-    }
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
 
-    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-    #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-    #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
 
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float exp_scale = 0x1.0p-112f;
 #else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
 
-    // FP16 <-> FP32
-    // ref: https://github.com/Maratyszcza/FP16
-
-    static inline float fp32_from_bits(uint32_t w) {
-        union {
-            uint32_t as_bits;
-            float as_value;
-        } fp32;
-        fp32.as_bits = w;
-        return fp32.as_value;
-    }
-
-    static inline uint32_t fp32_to_bits(float f) {
-        union {
-            float as_value;
-            uint32_t as_bits;
-        } fp32;
-        fp32.as_value = f;
-        return fp32.as_bits;
-    }
-
-    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        const uint32_t w = (uint32_t) h << 16;
-        const uint32_t sign = w & UINT32_C(0x80000000);
-        const uint32_t two_w = w + w;
-
-        const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
-        const float exp_scale = 0x1.0p-112f;
-#else
-        const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
-        const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-        const uint32_t magic_mask = UINT32_C(126) << 23;
-        const float magic_bias = 0.5f;
-        const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
 
-        const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-        const uint32_t result = sign |
-            (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
-        return fp32_from_bits(result);
-    }
-
-    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
-        const float scale_to_inf = 0x1.0p+112f;
-        const float scale_to_zero = 0x1.0p-110f;
-#else
-        const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
-        const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
-        float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
-        const uint32_t w = fp32_to_bits(f);
-        const uint32_t shl1_w = w + w;
-        const uint32_t sign = w & UINT32_C(0x80000000);
-        uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-        if (bias < UINT32_C(0x71000000)) {
-            bias = UINT32_C(0x71000000);
-        }
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
 
-        base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-        const uint32_t bits = fp32_to_bits(base);
-        const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-        const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-        const uint32_t nonsign = exp_bits + mantissa_bits;
-        return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
     }
 
-    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#endif // defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
-// precomputed f32 table for f16 (256 KB)
-// defined in ggml.c, initialized in ggml_init()
-GGML_API float ggml_table_f32_f16[1 << 16];
-
-// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
-// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(GGML_FP16_TO_FP32)
-inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    uint16_t s;
-    memcpy(&s, &f, sizeof(uint16_t));
-    return ggml_table_f32_f16[s];
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
 }
 
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#endif
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
 
-#if !defined(GGML_FP32_TO_FP16)
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
 #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-#endif
 
 /**
  * Converts brain16 to float32.
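Illustrative only (not part of the diff): the ggml_compute_fp32_to_fp16 / ggml_compute_fp16_to_fp32 pair introduced above implements IEEE-754 half-precision conversion purely through bit manipulation, so a round trip is exact for values representable in fp16 and values far below the fp16 subnormal range flush to zero. A minimal check, assuming the definitions from this hunk are in scope:

// Round-trip sketch of the two helpers above (assumption: ggml_fp16_t and the
// two conversion functions are available as declared in this header).
#include <cassert>

void fp16_roundtrip_demo(void) {
    // 0.5f and -2.0f are exactly representable in fp16, so the round trip is lossless.
    assert(ggml_compute_fp16_to_fp32(ggml_compute_fp32_to_fp16( 0.5f)) ==  0.5f);
    assert(ggml_compute_fp16_to_fp32(ggml_compute_fp32_to_fp16(-2.0f)) == -2.0f);
    // 1e-8f lies well below the smallest fp16 subnormal (~6e-8), so it rounds to zero.
    assert(ggml_compute_fp16_to_fp32(ggml_compute_fp32_to_fp16(1e-8f)) == 0.0f);
}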
@@ -587,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
+// return true if the node's results are only used by N other nodes
+// and can be fused into their calculations.
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    // check the use count against how many we're replacing
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+        return false;
+    }
+
+    // if node is a view, some other node might be using the intermediate result
+    // via the view source.
+    if (node->view_src) {
+        return false;
+    }
+
+    // If the user requested output for the node, can't fuse
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        return false;
+    }
+
+    return true;
+}
+
+// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing GGML_OP_NONE nodes in between
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node->op != ops[i]) {
+            return false;
+        }
+        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+            return false;
+        }
+        if (i > 0) {
+            struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            if (node->src[0] != prev && node->src[1] != prev) {
+                return false;
+            }
+            if (!ggml_are_same_shape(node, prev)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
 #ifdef __cplusplus
 }
 #endif
 
 #ifdef __cplusplus
+#include <initializer_list>
 #include <vector>
 
+// nicer C++ syntax for ggml_can_fuse
+inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
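Illustrative only (not from this diff): together with the new use_counts field in ggml_cgraph, these helpers let a backend decide when adjacent graph nodes can be fused into one kernel. A hypothetical sketch using the std::initializer_list overload added above; the two launch functions are placeholders, not real ggml API:

// Hypothetical backend graph-compute loop: probe for a fusable RMS_NORM -> MUL
// pair before falling back to single-op kernels.
for (int i = 0; i < cgraph->n_nodes; ++i) {
    if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
        launch_fused_rms_norm_mul(cgraph->nodes[i], cgraph->nodes[i + 1]); // placeholder
        i += 1; // the MUL node was consumed by the fused kernel
        continue;
    }
    launch_single_op(cgraph->nodes[i]); // placeholder
}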
data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt
@@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY)
     set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
 
     add_custom_command(
-        OUTPUT ${METALLIB_EMBED_ASM}
+        OUTPUT "${METALLIB_EMBED_ASM}"
         COMMAND echo "Embedding Metal library"
-        COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP}
-        COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
-        COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
-        COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
-        COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
-        COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
-        COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
-        COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
+        COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
+        COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
+        COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
+        COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
+        COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
+        COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
+        COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
+        COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
         DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
         COMMENT "Generate assembly for embedded Metal library"
+        VERBATIM
     )
 
-    target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM})
+    target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
 else()
     if (GGML_METAL_SHADER_DEBUG)
         # custom command to do the following:
data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -422,6 +422,17 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct{
+    int32_t  ne00;
+    uint64_t nb01;
+    int32_t  ne10;
+    uint64_t nb11;
+    int32_t  ne0;
+    uint64_t nb1;
+    int32_t  i00;
+    int32_t  i10;
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
@@ -521,6 +532,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;