whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/include/ggml.h

@@ -470,6 +470,7 @@ extern "C" {
  GGML_OP_TRANSPOSE,
  GGML_OP_GET_ROWS,
  GGML_OP_GET_ROWS_BACK,
+ GGML_OP_SET_ROWS,
  GGML_OP_DIAG,
  GGML_OP_DIAG_MASK_INF,
  GGML_OP_DIAG_MASK_ZERO,

@@ -481,6 +482,7 @@ extern "C" {
  GGML_OP_CONV_TRANSPOSE_1D,
  GGML_OP_IM2COL,
  GGML_OP_IM2COL_BACK,
+ GGML_OP_CONV_2D,
  GGML_OP_CONV_2D_DW,
  GGML_OP_CONV_TRANSPOSE_2D,
  GGML_OP_POOL_1D,

@@ -489,6 +491,7 @@ extern "C" {
  GGML_OP_UPSCALE, // nearest interpolate
  GGML_OP_PAD,
  GGML_OP_PAD_REFLECT_1D,
+ GGML_OP_ROLL,
  GGML_OP_ARANGE,
  GGML_OP_TIMESTEP_EMBEDDING,
  GGML_OP_ARGSORT,

@@ -518,6 +521,8 @@ extern "C" {
  GGML_OP_CROSS_ENTROPY_LOSS_BACK,
  GGML_OP_OPT_STEP_ADAMW,

+ GGML_OP_GLU,
+
  GGML_OP_COUNT,
  };

@@ -541,6 +546,14 @@ extern "C" {
  GGML_UNARY_OP_COUNT,
  };

+ enum ggml_glu_op {
+ GGML_GLU_OP_REGLU,
+ GGML_GLU_OP_GEGLU,
+ GGML_GLU_OP_SWIGLU,
+
+ GGML_GLU_OP_COUNT,
+ };
+
  enum ggml_object_type {
  GGML_OBJECT_TYPE_TENSOR,
  GGML_OBJECT_TYPE_GRAPH,

@@ -656,6 +669,7 @@ extern "C" {
  GGML_API const char * ggml_op_symbol(enum ggml_op op);

  GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+ GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
  GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name

  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);

@@ -686,6 +700,9 @@ extern "C" {
  // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
  GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);

+ // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+ GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
  GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
  GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

@@ -757,6 +774,7 @@ extern "C" {
  GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

  GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+ GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);

  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

@@ -935,6 +953,15 @@ extern "C" {
  struct ggml_tensor * a,
  struct ggml_tensor * b);

+ // repeat a to the specified shape
+ GGML_API struct ggml_tensor * ggml_repeat_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
  // sums repetitions in a into shape of b
  GGML_API struct ggml_tensor * ggml_repeat_back(
  struct ggml_context * ctx,
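The new ggml_repeat_4d tiles a tensor to an explicit target shape, whereas ggml_repeat takes a second tensor only to borrow its shape. A minimal sketch based on the declaration above (tensor sizes are illustrative; filling data and computing the graph on a backend are omitted):

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // a single row of width 4 that we want repeated across 8 rows and 2 batches
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1);

    // equivalent to ggml_repeat(ctx, a, b) with b of shape [4, 8, 2, 1],
    // but without allocating the shape-carrying tensor b
    struct ggml_tensor * r = ggml_repeat_4d(ctx, a, 4, 8, 2, 1);

    (void) r; // graph build/compute with a backend omitted
    ggml_free(ctx);
    return 0;
}
```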
@@ -1076,6 +1103,63 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);

+ // gated linear unit ops
+ // A: n columns, r rows,
+ // result is n / 2 columns, r rows,
+ // expects gate in second half of row, unless swapped is true
+ GGML_API struct ggml_tensor * ggml_glu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_glu_op op,
+ bool swapped);
+
+ GGML_API struct ggml_tensor * ggml_reglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_reglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_geglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_geglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_swiglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // A: n columns, r rows,
+ // B: n columns, r rows,
+ GGML_API struct ggml_tensor * ggml_glu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_glu_op op);
+
+ GGML_API struct ggml_tensor * ggml_reglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_geglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_swiglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
  // normalize along rows
  GGML_API struct ggml_tensor * ggml_norm(
  struct ggml_context * ctx,
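Per the comments above, the fused GLU ops take rows of width n holding [value | gate] (gate first if swapped) and return rows of width n/2, while the *_split variants take value and gate as two separate tensors. A hedged sketch of a SwiGLU feed-forward block built from these declarations; the function and weight names (swiglu_mlp, w_up_gate, w_down) and the shapes are illustrative, not from the package:

```c
#include "ggml.h"

// Illustrative SwiGLU MLP: y = down( swiglu( up_gate(x) ) )
// Assumes the "up" and "gate" projections are packed into one 2*n_ff-wide matmul.
static struct ggml_tensor * swiglu_mlp(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,          // [n_embd, n_tokens]
        struct ggml_tensor  * w_up_gate,  // [n_embd, 2*n_ff]
        struct ggml_tensor  * w_down) {   // [n_ff, n_embd]
    // [2*n_ff, n_tokens]: first half of each row is the value, second half the gate
    struct ggml_tensor * up_gate = ggml_mul_mat(ctx, w_up_gate, x);

    // fused SiLU-gated linear unit -> [n_ff, n_tokens]
    struct ggml_tensor * gated = ggml_swiglu(ctx, up_gate);

    // with separate projection weights the split form would be used instead:
    //   ggml_swiglu_split(ctx, ggml_mul_mat(ctx, w_up, x), ggml_mul_mat(ctx, w_gate, x));

    return ggml_mul_mat(ctx, w_down, gated); // [n_embd, n_tokens]
}
```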
@@ -1365,6 +1449,23 @@ extern "C" {
  struct ggml_tensor * b, // row indices
  struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape

+ // a TD [n_embd, ne1, ne2, ne3]
+ // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+ // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
+ //
+ // undefined behavior if destination rows overlap
+ //
+ // broadcast:
+ // ne2 % ne11 == 0
+ // ne3 % ne12 == 0
+ //
+ // return view(a)
+ GGML_API struct ggml_tensor * ggml_set_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // destination
+ struct ggml_tensor * b, // source
+ struct ggml_tensor * c); // row indices
+
  GGML_API struct ggml_tensor * ggml_diag(
  struct ggml_context * ctx,
  struct ggml_tensor * a);
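ggml_set_rows is the scatter counterpart of ggml_get_rows: it writes the rows of b into a at the I64 indices given by c and returns a view of a. A minimal sketch based only on the declaration and comments above (index values and data filling are omitted):

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // destination: 16 rows of width 8
    struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 16);
    // source: 4 rows of width 8 to scatter into dst
    struct ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    // one index per source row, each in [0, 16)
    struct ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);

    // returns a view of dst with the selected rows overwritten;
    // per the header comment, overlapping destination rows are undefined behavior
    struct ggml_tensor * out = ggml_set_rows(ctx, dst, src, rows);

    (void) out; // graph build/compute with a backend omitted
    ggml_free(ctx);
    return 0;
}
```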
@@ -1713,6 +1814,17 @@ extern "C" {
  struct ggml_tensor * b,
  int stride);

+ GGML_API struct ggml_tensor * ggml_conv_2d_direct(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
+ struct ggml_tensor * b, // input data [W, H, C, N]
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1); // dilation dimension 1
+
  enum ggml_op_pool {
  GGML_OP_POOL_MAX,
  GGML_OP_POOL_AVG,
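ggml_conv_2d_direct takes the kernel in [KW, KH, IC, OC] layout and the input in [W, H, C, N], with per-axis stride, padding, and dilation. A hedged sketch of a 3x3, stride-1, pad-1 convolution; the sizes are illustrative:

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { 64*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // kernel: 3x3, 16 input channels, 32 output channels -> [KW, KH, IC, OC]
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 16, 32);
    // input: 64x64 image, 16 channels, batch of 1 -> [W, H, C, N]
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 16, 1);

    // stride 1, padding 1, dilation 1 keeps the 64x64 spatial size -> [64, 64, 32, 1]
    struct ggml_tensor * out = ggml_conv_2d_direct(ctx, kernel, input,
                                                   /*s0*/ 1, /*s1*/ 1,
                                                   /*p0*/ 1, /*p1*/ 1,
                                                   /*d0*/ 1, /*d1*/ 1);

    (void) out; // graph build/compute with a backend omitted
    ggml_free(ctx);
    return 0;
}
```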
@@ -1755,6 +1867,12 @@ extern "C" {
  enum ggml_scale_mode {
  GGML_SCALE_MODE_NEAREST = 0,
  GGML_SCALE_MODE_BILINEAR = 1,
+
+ GGML_SCALE_MODE_COUNT
+ };
+
+ enum ggml_scale_flag {
+ GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
  };

  // interpolate

@@ -1767,14 +1885,26 @@ extern "C" {

  // interpolate
  // interpolate scale to specified dimensions
- GGML_API struct ggml_tensor * ggml_upscale_ext(
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  int ne0,
  int ne1,
  int ne2,
  int ne3,
- enum ggml_scale_mode mode);
+ enum ggml_scale_mode mode),
+ "use ggml_interpolate instead");
+
+ // Up- or downsamples the input to the specified size.
+ // 2D scale modes (eg. bilinear) are applied to the first two dimensions.
+ GGML_API struct ggml_tensor * ggml_interpolate(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...]

  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
  GGML_API struct ggml_tensor * ggml_pad(
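ggml_upscale_ext is now deprecated in favor of ggml_interpolate, which takes the absolute target size in all four dimensions and packs the scale mode, plus optional flags such as GGML_SCALE_FLAG_ALIGN_CORNERS, into a single uint32_t. A hedged migration sketch (the helper name and target sizes are illustrative):

```c
#include "ggml.h"

// Old (now deprecated):
//   ggml_upscale_ext(ctx, a, 128, 128, 3, 1, GGML_SCALE_MODE_BILINEAR);
// New equivalent, optionally with align-corners sampling:
static struct ggml_tensor * resize_bilinear(struct ggml_context * ctx,
                                            struct ggml_tensor  * a) {
    return ggml_interpolate(ctx, a,
                            /*ne0*/ 128, /*ne1*/ 128, /*ne2*/ 3, /*ne3*/ 1,
                            GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
}
```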
@@ -1792,6 +1922,17 @@ extern "C" {
  int p0,
  int p1);

+ // Move tensor elements by an offset given for each dimension. Elements that
+ // are shifted beyond the last position are wrapped around to the beginning.
+ GGML_API struct ggml_tensor * ggml_roll(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int shift0,
+ int shift1,
+ int shift2,
+ int shift3);
+
+
  // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
  // timesteps: [N,]
  // return: [N, dim]
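ggml_roll applies a circular shift per axis: elements pushed past the end of a dimension reappear at its start. A minimal sketch using only the declaration above (sizes are illustrative):

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 32);

    // shift by +1 along dimension 1; the last row wraps around to position 0,
    // the other dimensions are left untouched
    struct ggml_tensor * rolled = ggml_roll(ctx, a, 0, 1, 0, 0);

    (void) rolled; // graph build/compute with a backend omitted
    ggml_free(ctx);
    return 0;
}
```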
@@ -2086,9 +2227,6 @@ extern "C" {
  GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
  GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);

- GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
- GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
-
  // print info and performance information for the graph
  GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);

@@ -2172,6 +2310,7 @@ extern "C" {

  // scheduling priorities
  enum ggml_sched_priority {
+ GGML_SCHED_PRIO_LOW = -1,
  GGML_SCHED_PRIO_NORMAL,
  GGML_SCHED_PRIO_MEDIUM,
  GGML_SCHED_PRIO_HIGH,
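GGML_SCHED_PRIO_LOW adds a below-normal scheduling priority ahead of the existing levels. As a hedged sketch only: assuming the threadpool-params API declared in ggml-cpu.h (ggml_threadpool_params_default and its prio field are not part of the hunk above), it could be used to keep background work from competing with interactive threads:

```c
#include "ggml-cpu.h"   // assumed location of the threadpool params API

int main(void) {
    // assumed API: default params for 4 worker threads, then lower their priority
    struct ggml_threadpool_params tpp = ggml_threadpool_params_default(4);
    tpp.prio = GGML_SCHED_PRIO_LOW; // run worker threads below normal priority

    // a threadpool created from tpp would then be passed to the CPU compute call
    (void) tpp;
    return 0;
}
```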
data/ext/sources/ggml/src/CMakeLists.txt

@@ -109,6 +109,8 @@ if (MSVC)
  else ()
  set(CMAKE_GENERATOR_PLATFORM_LWR "")
  endif ()
+ ggml_get_system_arch()
+ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")

  if (NOT MSVC)
  if (GGML_STATIC)

@@ -123,7 +125,6 @@ if (NOT MSVC)
  endif()

  if (MINGW)
- # Target Windows 8 for PrefetchVirtualMemory
  add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
  endif()

@@ -194,6 +195,7 @@ add_library(ggml-base
  ../include/ggml-opt.h
  ../include/gguf.h
  ggml.c
+ ggml.cpp
  ggml-alloc.c
  ggml-backend.cpp
  ggml-opt.cpp

@@ -210,6 +212,7 @@ endif()

  add_library(ggml
  ggml-backend-reg.cpp)
+ add_library(ggml::ggml ALIAS ggml)

  target_link_libraries(ggml PUBLIC ggml-base)

@@ -224,8 +227,8 @@ function(ggml_add_backend_library backend)
  set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
  target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
  add_dependencies(ggml ${backend})
- install(TARGETS ${backend} LIBRARY DESTINATION bin)
- else()
+ install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
+ else()
  add_library(${backend} ${ARGN})
  target_link_libraries(ggml PUBLIC ${backend})
  install(TARGETS ${backend} LIBRARY)

@@ -267,17 +270,27 @@ endfunction()
  function(ggml_add_cpu_backend_variant tag_name)
  set(GGML_CPU_TAG_NAME ${tag_name})
  # other: OPENMP LLAMAFILE CPU_HBM
- foreach (feat NATIVE
- SSE42
- AVX AVX2 BMI2 AVX_VNNI FMA F16C
- AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
- AMX_TILE AMX_INT8 AMX_BF16)
- set(GGML_${feat} OFF)
- endforeach()
-
- foreach (feat ${ARGN})
- set(GGML_${feat} ON)
- endforeach()
+ if (GGML_SYSTEM_ARCH STREQUAL "x86")
+ foreach (feat NATIVE
+ SSE42
+ AVX AVX2 BMI2 AVX_VNNI FMA F16C
+ AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
+ AMX_TILE AMX_INT8 AMX_BF16)
+ set(GGML_${feat} OFF)
+ endforeach()
+
+ foreach (feat ${ARGN})
+ set(GGML_${feat} ON)
+ endforeach()
+ elseif (GGML_SYSTEM_ARCH STREQUAL "ARM")
+ foreach (feat ${ARGN})
+ set(GGML_INTERNAL_${feat} ON)
+ endforeach()
+ elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+ foreach (feat ${ARGN})
+ set(GGML_INTERNAL_${feat} ON)
+ endforeach()
+ endif()

  ggml_add_cpu_backend_variant_impl(${tag_name})
  endfunction()

@@ -287,17 +300,62 @@ ggml_add_backend(CPU)
  if (GGML_CPU_ALL_VARIANTS)
  if (NOT GGML_BACKEND_DL)
  message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
+ elseif (GGML_CPU_ARM_ARCH)
+ message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
  endif()
- ggml_add_cpu_backend_variant(x64)
- ggml_add_cpu_backend_variant(sse42 SSE42)
- ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
- ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
- ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
- ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
- ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
- if (NOT MSVC)
- # MSVC doesn't support AMX
- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+ if (GGML_SYSTEM_ARCH STREQUAL "x86")
+ ggml_add_cpu_backend_variant(x64)
+ ggml_add_cpu_backend_variant(sse42 SSE42)
+ ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
+ ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
+ ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
+ ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+ ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+ if (NOT MSVC)
+ # MSVC doesn't support AMX
+ ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+ endif()
+ elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
+ if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ # Many of these features are optional so we build versions with popular
+ # combinations and name the backends based on the version they were
+ # first released with
+ ggml_add_cpu_backend_variant(armv8.0_1)
+ ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD)
+ ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
+ ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE)
+ ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)
+ ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)
+ ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)
+ ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)
+ elseif (CMAKE_SYSTEM_NAME MATCHES "Android")
+ # Android-specific backends with SoC-compatible feature sets
+ ggml_add_cpu_backend_variant(android_armv8.0_1)
+ ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
+ ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
+ ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
+ elseif (APPLE)
+ ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
+ ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
+ ggml_add_cpu_backend_variant(apple_m4 DOTPROD MATMUL_INT8 NOSVE SME)
+ else()
+ message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
+ endif()
+ elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+ if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ ggml_add_cpu_backend_variant(power0)
+ ggml_add_cpu_backend_variant(power7_1 POWER7)
+ ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
+ ggml_add_cpu_backend_variant(power8_1 POWER8)
+ ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
+ ggml_add_cpu_backend_variant(power9 POWER9 VSX)
+ ggml_add_cpu_backend_variant(power10 POWER10 VSX)
+ ggml_add_cpu_backend_variant(power11 POWER11 VSX)
+ else()
+ message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
+ endif()
+ else()
+ message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
  endif()
  elseif (GGML_CPU)
  ggml_add_cpu_backend_variant_impl("")
data/ext/sources/ggml/src/ggml-backend-reg.cpp

@@ -69,6 +69,9 @@
  #if defined(__clang__)
  # pragma clang diagnostic push
  # pragma clang diagnostic ignored "-Wdeprecated-declarations"
+ #elif defined(__GNUC__)
+ # pragma GCC diagnostic push
+ # pragma GCC diagnostic ignored "-Wdeprecated-declarations"
  #endif

  namespace fs = std::filesystem;

@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {

  #if defined(__clang__)
  # pragma clang diagnostic pop
+ #elif defined(__GNUC__)
+ # pragma GCC diagnostic pop
  #endif

  #ifdef _WIN32
data/ext/sources/ggml/src/ggml-backend.cpp

@@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
  }
  if (sched->debug > 1) {
  ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
- GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
- fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+ GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+ fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
+ graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
  for (int j = 0; j < GGML_MAX_SRC; j++) {
  struct ggml_tensor * src = node->src[j];
  if (src == NULL) {

@@ -1340,7 +1341,10 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
  // allocate graph
  if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
  // the re-allocation may cause the split inputs to be moved to a different address
- ggml_backend_sched_synchronize(sched);
+ // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
+ for (int i = 0; i < sched->n_backends; i++) {
+ ggml_backend_synchronize(sched->backends[i]);
+ }
  #ifndef NDEBUG
  GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
  #endif

@@ -1564,7 +1568,6 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra

  ggml_backend_sched_split_graph(sched, graph);

-
  if (!ggml_backend_sched_alloc_splits(sched)) {
  return false;
  }

@@ -1598,9 +1601,12 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
  for (int i = 0; i < sched->n_backends; i++) {
  ggml_backend_synchronize(sched->backends[i]);
  }
- // reset the current copy to 0 so that the graphs will be similar during generation
- // necessary for CUDA graphs
- sched->cur_copy = 0;
+ if (!sched->is_alloc) {
+ // if the graph is not already allocated, always use copy 0 after a synchronization
+ // this ensures that during generation the same copy is used every time,
+ // which avoids changes in the graph that could cause CUDA or other graphs to be disabled
+ sched->cur_copy = 0;
+ }
  }

  void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {

@@ -1821,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
  ggml_free(copy.ctx_unallocated);
  }

- bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
  struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
  if (copy.buffer == NULL) {
  return false;

@@ -1832,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t

  assert(g1->n_nodes == g2->n_nodes);

- for (int i = 0; i < g1->n_nodes; i++) {
- struct ggml_tensor * t1 = g1->nodes[i];
- struct ggml_tensor * t2 = g2->nodes[i];
+ if (test_node != nullptr) {
+ // Compute the whole graph and only test the output for a specific tensor
+ ggml_backend_graph_compute(backend1, g1);
+ ggml_backend_graph_compute(backend2, g2);

- assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+ int test_node_idx = -1;
+ for (int i = 0; i < g1->n_nodes; i++) {
+ struct ggml_tensor * t1 = g1->nodes[i];
+ if (t1 == test_node) {
+ test_node_idx = i;
+ break;
+ }
+ }
+ GGML_ASSERT(test_node_idx != -1);

- struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
- struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+ callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+ } else {
+ for (int i = 0; i < g1->n_nodes; i++) {
+ struct ggml_tensor * t1 = g1->nodes[i];
+ struct ggml_tensor * t2 = g2->nodes[i];

- ggml_backend_graph_compute(backend1, &g1v);
- ggml_backend_graph_compute(backend2, &g2v);
+ assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));

- if (ggml_is_view_op(t1->op)) {
- continue;
- }
+ struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+ struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);

- // compare results, calculate rms etc
- if (!callback(i, t1, t2, user_data)) {
- break;
+ ggml_backend_graph_compute(backend1, &g1v);
+ ggml_backend_graph_compute(backend2, &g2v);
+
+ if (ggml_is_view_op(t1->op)) {
+ continue;
+ }
+
+ // compare results, calculate rms etc
+ if (!callback(i, t1, t2, user_data)) {
+ break;
+ }
  }
  }
-
  ggml_backend_graph_copy_free(copy);

  return true;
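ggml_backend_compare_graph_backend gains a test_node parameter: when it is non-NULL, both graphs are computed in full and the callback is invoked only for that tensor, instead of computing and comparing node by node. A hedged sketch of a caller; the helper name compare_logits_only and the surrounding setup are illustrative, only the new parameter usage comes from the hunk above:

```c
#include "ggml-backend.h"

// Callback matching ggml_backend_eval_callback: return false to stop the comparison.
static bool check_node(int i, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) i; (void) user_data;
    // compare t1 (backend 1) against t2 (backend 2) here, e.g. RMS of the difference
    return t1 != NULL && t2 != NULL;
}

// Hypothetical helper: b1/b2, graph and logits are created elsewhere.
bool compare_logits_only(ggml_backend_t b1, ggml_backend_t b2,
                         struct ggml_cgraph * graph, struct ggml_tensor * logits) {
    // passing NULL as test_node keeps the old behavior (compare every node one by one);
    // passing `logits` computes both graphs once and checks just that tensor
    return ggml_backend_compare_graph_backend(b1, b2, graph, check_node, NULL, logits);
}
```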
data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt

@@ -81,7 +81,7 @@ if (BLAS_FOUND)
  target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
  target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
  else()
- message(ERROR "BLAS not found, please refer to "
- "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
- " to set correct GGML_BLAS_VENDOR")
+ message(FATAL_ERROR "BLAS not found, please refer to "
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+ " to set correct GGML_BLAS_VENDOR")
  endif()
data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt

@@ -30,6 +30,7 @@ string(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION need lower
  string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
  set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
  string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
+ message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")

  if (CANN_INSTALL_DIR)
  # Only Support Linux.
data/ext/sources/ggml/src/ggml-cann/common.h

@@ -37,6 +37,7 @@
  #include <thread>
  #include <unistd.h>
  #include <functional>
+ #include <optional>

  #include "../include/ggml-cann.h"
  #include "../include/ggml.h"

@@ -103,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info();
  void ggml_cann_set_device(int32_t device);
  int32_t ggml_cann_get_device();

+ std::optional<std::string> get_env(const std::string& name);
+ bool parse_bool(const std::string& value);
+
  /**
  * @brief Abstract base class for memory pools used by CANN.
  */

@@ -354,7 +358,8 @@ struct ggml_backend_cann_context {
  : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
  ggml_cann_set_device(device);
  description = aclrtGetSocName();
- async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
+
+ async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
  GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
  device, async_mode ? "ON" : "OFF");
  }