whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml.c

@@ -61,9 +61,6 @@
 #define m512i(p) (__m512i)(p)
 #endif
 
-// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
-float ggml_table_f32_f16[1 << 16];
-
 #if defined(__linux__) || \
     defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
     (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
@@ -133,7 +130,7 @@ static void ggml_print_backtrace_symbols(void) {
 }
 #endif
 
-static void ggml_print_backtrace(void) {
+void ggml_print_backtrace(void) {
     const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
     if (GGML_NO_BACKTRACE) {
         return;
@@ -160,6 +157,10 @@ static void ggml_print_backtrace(void) {
     const int parent_pid = getpid();
     const int child_pid = fork();
     if (child_pid < 0) { // error
+#if defined(__linux__)
+        close(lock[1]);
+        close(lock[0]);
+#endif
         return;
     } else if (child_pid == 0) { // child
         char attach[32];
@@ -167,6 +168,7 @@ static void ggml_print_backtrace(void) {
 #if defined(__linux__)
         close(lock[1]);
         (void) !read(lock[0], lock, 1);
+        close(lock[0]);
 #endif
         // try gdb
         execlp("gdb", "gdb", "--batch",
@@ -195,7 +197,7 @@ static void ggml_print_backtrace(void) {
     }
 }
 #else
-static void ggml_print_backtrace(void) {
+void ggml_print_backtrace(void) {
     // platform not supported
 }
 #endif
@@ -216,6 +218,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
     abort();
 }
 
+// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
+
 //
 // logging
 //
@@ -881,12 +885,6 @@ struct ggml_context {
     struct ggml_object * objects_end;
 };
 
-struct ggml_context_container {
-    bool used;
-
-    struct ggml_context context;
-};
-
 //
 // data types
 //
@@ -935,6 +933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TRANSPOSE",
     "GET_ROWS",
     "GET_ROWS_BACK",
+    "SET_ROWS",
     "DIAG",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
@@ -946,6 +945,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -954,6 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "UPSCALE",
     "PAD",
     "PAD_REFLECT_1D",
+    "ROLL",
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
@@ -982,9 +983,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1030,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "transpose(x)",
     "get_rows(x)",
     "get_rows_back(x)",
+    "set_rows(x)",
     "diag(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
@@ -1041,6 +1045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1049,6 +1054,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "upscale(x)",
     "pad(x)",
     "pad_reflect_1d(x)",
+    "roll(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
@@ -1077,9 +1083,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1105,6 +1113,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
 static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
 
 
+static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
+    "REGLU",
+    "GEGLU",
+    "SWIGLU",
+};
+
+static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -1207,11 +1224,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
     return GGML_UNARY_OP_NAME[op];
 }
 
+const char * ggml_glu_op_name(enum ggml_glu_op op) {
+    return GGML_GLU_OP_NAME[op];
+}
+
 const char * ggml_op_desc(const struct ggml_tensor * t) {
     if (t->op == GGML_OP_UNARY) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
+    if (t->op == GGML_OP_GLU) {
+        enum ggml_glu_op gop = ggml_get_glu_op(t);
+        return ggml_glu_op_name(gop);
+    }
     return ggml_op_name(t->op);
 }
 
@@ -1348,6 +1373,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
         tensor->nb[2] == ggml_type_size(tensor->type);
 }
 
+bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -1419,14 +1450,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize time system (required on Windows)
         ggml_time_init();
 
-        for (int i = 0; i < (1 << 16); ++i) {
-            union {
-                uint16_t u16;
-                ggml_fp16_t fp16;
-            } u = {i};
-            ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
-        }
-
         is_first_call = false;
     }
 
@@ -1730,6 +1753,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
     return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
 }
 
+enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_GLU);
+    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -2312,6 +2340,26 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+struct ggml_tensor * ggml_repeat_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+    const bool can_repeat = ggml_is_empty(a) || (
+        (ne0 % a->ne[0] == 0) &&
+        (ne1 % a->ne[1] == 0) &&
+        (ne2 % a->ne[2] == 0) &&
+        (ne3 % a->ne[3] == 0)
+    );
+    GGML_ASSERT(can_repeat);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+
+    result->op     = GGML_OP_REPEAT;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_repeat_back
 
 struct ggml_tensor * ggml_repeat_back(
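Editor's note (not part of the diff): a minimal usage sketch of the new ggml_repeat_4d, assuming a live ggml_context; the helper name and sizes are illustrative.

    #include "ggml.h"

    // Tile a [4, 1, 1, 1] tensor to [4, 3, 2, 1]; each target dimension must be
    // a whole multiple of the corresponding source dimension (asserted above).
    static struct ggml_tensor * tile_example(struct ggml_context * ctx) {
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        return ggml_repeat_4d(ctx, a, 4, 3, 2, 1);
    }

Unlike ggml_repeat, no shape-carrying template tensor has to be allocated just to describe the target sizes.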
@@ -2589,6 +2637,114 @@ struct ggml_tensor * ggml_exp_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
 }
 
+// ggml_glu
+
+static struct ggml_tensor * ggml_glu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_glu_op op,
+        bool swapped) {
+    GGML_ASSERT(ggml_is_contiguous_1(a));
+
+    if (b) {
+        GGML_ASSERT(ggml_is_contiguous_1(b));
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        GGML_ASSERT(a->type == b->type);
+    }
+
+    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
+
+    result->op     = GGML_OP_GLU;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_glu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_glu_op op,
+        bool swapped) {
+    return ggml_glu_impl(ctx, a, NULL, op, swapped);
+}
+
+struct ggml_tensor * ggml_glu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_glu_op op) {
+    return ggml_glu_impl(ctx, a, b, op, false);
+}
+
+// ggml_reglu
+
+struct ggml_tensor * ggml_reglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
+}
+
+struct ggml_tensor * ggml_reglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
+}
+
+// ggml_geglu
+
+struct ggml_tensor * ggml_geglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
+}
+
+struct ggml_tensor * ggml_geglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
+}
+
+// ggml_swiglu
+
+struct ggml_tensor * ggml_swiglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
+}
+
+struct ggml_tensor * ggml_swiglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
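Editor's note (not part of the diff): a minimal sketch of the two GLU call styles this hunk adds, assuming a live ggml_context; the helper name and the n_ff/n_tokens sizes are illustrative.

    #include "ggml.h"

    static struct ggml_tensor * ffn_swiglu(struct ggml_context * ctx,
                                           int64_t n_ff, int64_t n_tokens) {
        // Fused form: gate and value share one tensor along dim 0,
        // so the result has ne[0]/2 elements per row.
        struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_ff, n_tokens);
        struct ggml_tensor * y  = ggml_swiglu(ctx, x);

        // Split form: gate and value come from two same-shape tensors.
        struct ggml_tensor * gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);
        struct ggml_tensor * up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);
        struct ggml_tensor * y2   = ggml_swiglu_split(ctx, gate, up);

        (void) y;
        return y2;
    }

The *_swapped variants only flip which half of the fused tensor acts as the gate (the `swapped` op param above).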
@@ -3372,6 +3528,35 @@ struct ggml_tensor * ggml_get_rows_back(
     return result;
 }
 
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c) {
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
+    GGML_ASSERT(b->ne[1] == c->ne[0]);
+    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+    GGML_ASSERT(c->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    GGML_ASSERT(c->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(b));
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_diag
 
 struct ggml_tensor * ggml_diag(
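Editor's note (not part of the diff): a minimal sketch of ggml_set_rows, assuming a live ggml_context; the helper name and sizes are illustrative. Per the asserts above, the source rows must be F32 and the indices I64; the destination's type is not constrained by them.

    #include "ggml.h"

    // Scatter n_put rows of src into dst at the given row indices.
    // The returned tensor is a view of dst.
    static struct ggml_tensor * scatter_rows(struct ggml_context * ctx,
                                             int64_t n_embd, int64_t n_rows, int64_t n_put) {
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_rows);
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_put);
        struct ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I64,  n_put);
        return ggml_set_rows(ctx, dst, src, idx);
    }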
@@ -4108,6 +4293,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
     return result;
 }
 
+// ggml_conv_2d_direct
+
+struct ggml_tensor * ggml_conv_2d_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,   // convolution kernel [KW, KH, IC, OC]
+        struct ggml_tensor * b,   // input data [W, H, C, N]
+        int s0,                   // stride dimension 0
+        int s1,                   // stride dimension 1
+        int p0,                   // padding dimension 0
+        int p1,                   // padding dimension 1
+        int d0,                   // dilation dimension 0
+        int d1) {                 // dilation dimension 1
+
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    //GGML_ASSERT(a->type == b->type);
+
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = a->ne[3];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    ggml_set_op_params_i32(result, 0, s0);
+    ggml_set_op_params_i32(result, 1, s1);
+    ggml_set_op_params_i32(result, 2, p0);
+    ggml_set_op_params_i32(result, 3, p1);
+    ggml_set_op_params_i32(result, 4, d0);
+    ggml_set_op_params_i32(result, 5, d1);
+
+    result->op     = GGML_OP_CONV_2D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
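Editor's note (not part of the diff): a minimal sketch of ggml_conv_2d_direct, the new direct (no im2col intermediate) convolution; helper name and all sizes are illustrative, and a live ggml_context is assumed.

    #include "ggml.h"

    // 3x3 convolution, stride 1, padding 1, dilation 1:
    // 128 filters over a 224x224 input with 64 channels.
    static struct ggml_tensor * conv_example(struct ggml_context * ctx) {
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 64, 128);   // [KW, KH, IC, OC]
        struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 224, 224, 64, 1); // [W, H, IC, N]
        return ggml_conv_2d_direct(ctx, k, x, 1, 1, 1, 1, 1, 1); // output [224, 224, 128, 1]
    }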
@@ -4224,24 +4447,21 @@ struct ggml_tensor * ggml_pool_2d_back(
     return result;
 }
 
-// ggml_upscale
+// ggml_upscale / ggml_interpolate
 
-static struct ggml_tensor * ggml_upscale_impl(
+static struct ggml_tensor * ggml_interpolate_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int ne0,
-        int ne1,
-        int ne2,
-        int ne3,
-        enum ggml_scale_mode mode) {
-    GGML_ASSERT(a->ne[0] <= ne0);
-    GGML_ASSERT(a->ne[1] <= ne1);
-    GGML_ASSERT(a->ne[2] <= ne2);
-    GGML_ASSERT(a->ne[3] <= ne3);
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        uint32_t mode) {
+    GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
-    ggml_set_op_params_i32(result, 0, mode);
+    ggml_set_op_params_i32(result, 0, (int32_t)mode);
 
     result->op = GGML_OP_UPSCALE;
     result->src[0] = a;
@@ -4254,7 +4474,8 @@ struct ggml_tensor * ggml_upscale(
         struct ggml_tensor * a,
         int scale_factor,
        enum ggml_scale_mode mode) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
+    GGML_ASSERT(scale_factor > 1);
+    return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4265,7 +4486,18 @@ struct ggml_tensor * ggml_upscale_ext(
         int ne2,
         int ne3,
         enum ggml_scale_mode mode) {
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
+    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
+}
+
+struct ggml_tensor * ggml_interpolate(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        uint32_t mode) {
+    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
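Editor's note (not part of the diff): a minimal sketch of the new ggml_interpolate, assuming a live ggml_context; sizes are illustrative. Because the old `a->ne[i] <= nei` asserts are gone, the target shape may also shrink the tensor.

    #include "ggml.h"

    // Downscale a 64x64 3-channel image to 32x32 with bilinear filtering.
    static struct ggml_tensor * resize_example(struct ggml_context * ctx) {
        struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 3, 1);
        return ggml_interpolate(ctx, img, 32, 32, 3, 1, GGML_SCALE_MODE_BILINEAR);
    }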
@@ -4320,6 +4552,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
     return result;
 }
 
+// ggml_roll
+
+struct ggml_tensor * ggml_roll(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int shift0,
+        int shift1,
+        int shift2,
+        int shift3) {
+    GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
+    GGML_ASSERT(abs(shift0) < a->ne[0]);
+    GGML_ASSERT(abs(shift1) < a->ne[1]);
+    GGML_ASSERT(abs(shift2) < a->ne[2]);
+    GGML_ASSERT(abs(shift3) < a->ne[3]);
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, shift0);
+    ggml_set_op_params_i32(result, 1, shift1);
+    ggml_set_op_params_i32(result, 2, shift2);
+    ggml_set_op_params_i32(result, 3, shift3);
+
+    result->op     = GGML_OP_ROLL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(
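Editor's note (not part of the diff): a minimal sketch of ggml_roll, assuming a live ggml_context; helper name and sizes are illustrative.

    #include "ggml.h"

    // Circularly shift a [8, 4] tensor by +1 along dim 0 and -2 along dim 1;
    // each |shift| must stay below the corresponding dimension (asserted above).
    static struct ggml_tensor * roll_example(struct ggml_context * ctx) {
        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        return ggml_roll(ctx, t, 1, -2, 0, 0);
    }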
@@ -5764,19 +6024,32 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     // check if already visited
-    if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
-        return;
+    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+        // This is the first time we see this node in the current graph.
+        cgraph->visited_hash_set.keys[node_hash_pos] = node;
+        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+        cgraph->use_counts[node_hash_pos] = 0;
+    } else {
+        // already visited
+        return node_hash_pos;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
-            /* unknown order, just fall back to using i*/ i;
-        if (node->src[k]) {
-            ggml_visit_parents(cgraph, node->src[k]);
+            /* unknown order, just fall back to using i */ i;
+
+        struct ggml_tensor * src = node->src[k];
+        if (src) {
+            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+            // Update the use count for this operand.
+            cgraph->use_counts[src_hash_pos]++;
         }
     }
 
@@ -5800,6 +6073,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->n_nodes++;
     }
+
+    return node_hash_pos;
 }
 
 static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -5937,6 +6212,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -5966,11 +6242,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
 
     void * p = cgraph + 1;
 
-    struct ggml_tensor ** nodes_ptr     = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** leafs_ptr     = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr     = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
-    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** nodes_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    int32_t             * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+    struct ggml_tensor ** hash_keys_ptr  = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
 
     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
@@ -5985,6 +6262,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.grads        =*/ grads_ptr,
         /*.grad_accs    =*/ grad_accs_ptr,
         /*.leafs        =*/ leafs_ptr,
+        /*.use_counts   =*/ use_counts_ptr,
        /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };
@@ -6011,7 +6289,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.grads            =*/ NULL, // gradients would need visited_hash_set
         /*.grad_accs        =*/ NULL,
         /*.leafs            =*/ NULL,
-        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.use_counts       =*/ cgraph0->use_counts,
+        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
         /*.order            =*/ cgraph0->order,
     };
 
@@ -6038,7 +6317,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
-            ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            dst->use_counts[new_hash_pos] = src->use_counts[i];
         }
     }
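Editor's note (not part of the diff): a minimal sketch of how the new per-node use counts get populated, assuming a live ggml_context; the helper name is illustrative, and the buffer-reuse motivation is an assumption, not stated in this diff.

    #include "ggml.h"

    // A tensor feeding two operands ends up with use_counts[...] == 2 after
    // graph construction (presumably for allocator/backend reuse decisions).
    static void use_count_example(struct ggml_context * ctx) {
        struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
        struct ggml_tensor * b  = ggml_mul(ctx, a, a); // 'a' is used twice
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, b);
    }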
 

data/ext/sources/ggml/src/ggml.cpp (new file)

@@ -0,0 +1,26 @@
+#include "ggml-impl.h"
+
+#include <cstdlib>
+#include <exception>
+
+static std::terminate_handler previous_terminate_handler;
+
+GGML_NORETURN static void ggml_uncaught_exception() {
+    ggml_print_backtrace();
+    if (previous_terminate_handler) {
+        previous_terminate_handler();
+    }
+    abort(); // unreachable unless previous_terminate_handler was nullptr
+}
+
+static bool ggml_uncaught_exception_init = []{
+    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
+    if (GGML_NO_BACKTRACE) {
+        return false;
+    }
+    const auto prev{std::get_terminate()};
+    GGML_ASSERT(prev != ggml_uncaught_exception);
+    previous_terminate_handler = prev;
+    std::set_terminate(ggml_uncaught_exception);
+    return true;
+}();