@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
@@ -4,6 +4,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
+#include "ggml-cpu.h"
 #include "ggml.h"
 
 // FIXME: required here for quantization functions
@@ -382,58 +383,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
     }
 }
 
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-// currently, the ggml_cpu_has_* functions are entirely compile-time
 void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__F16C__)
-    //if (ggml_cpu_has_f16c()) {
-        for (; i + 7 < n; i += 8) {
-            __m256 x_vec = _mm256_loadu_ps(x + i);
-            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storeu_si128((__m128i *)(y + i), y_vec);
-        }
-        for(; i + 3 < n; i += 4) {
-            __m128 x_vec = _mm_loadu_ps(x + i);
-            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storel_epi64((__m128i *)(y + i), y_vec);
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_FP32_TO_FP16(x[i]);
     }
 }
 
 void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__AVX512F__)
-    //if (ggml_cpu_has_avx512()) {
-        for (; i + 16 <= n; i += 16) {
-            _mm512_storeu_ps(y + i,
-                            _mm512_castsi512_ps(
-                                _mm512_slli_epi32(
-                                    _mm512_cvtepu16_epi32(
-                                        _mm256_loadu_si256(
-                                            (const __m256i *)(x + i))),
-                                    16)));
-        }
-    //}
-#endif
-#if defined(__AVX2__)
-    //if (ggml_cpu_has_avx2()) {
-        for (; i + 8 <= n; i += 8) {
-            _mm256_storeu_ps(y + i,
-                            _mm256_castsi256_ps(
-                                _mm256_slli_epi32(
-                                    _mm256_cvtepu16_epi32(
-                                        _mm_loadu_si128(
-                                            (const __m128i *)(x + i))),
-                                    16)));
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_BF16_TO_FP32(x[i]);
     }
 }
@@ -956,6 +915,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -982,23 +942,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1055,6 +1010,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
    "im2col(x)",
     "im2col_back(x)",
+    "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1081,23 +1037,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1159,6 +1110,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
@@ -1342,12 +1299,23 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
     return ggml_is_contiguous_n(tensor, 2);
 }
 
+bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -2764,11 +2732,11 @@ void ggml_mul_mat_set_prec(
     c = ggml_mul_mat_id(ctx, as, b, ids);
 
     as  -> [cols, rows, n_expert]
-    ids -> [n_experts_used, n_tokens] (i32)
     b   -> [cols, n_expert_used, n_tokens]
+    ids -> [n_expert_used, n_tokens] (i32)
     c   -> [rows, n_expert_used, n_tokens]
 
-    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
 
     c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
 */
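The corrected comment pins down the operand shapes for ggml_mul_mat_id. A minimal shape sketch under those constraints (the dimension sizes below are hypothetical, not from the diff; only the shape relationships matter):

    // hypothetical sizes chosen only to satisfy the documented shapes
    const int64_t n_embd = 512, n_ff = 1024, n_expert = 8, n_expert_used = 2, n_tokens = 4;
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_ff,          n_expert);
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens);
    struct ggml_tensor * c   = ggml_mul_mat_id(ctx, as, b, ids);   // c -> [n_ff, n_expert_used, n_tokens]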
@@ -4054,6 +4022,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
     return result;
 }
 
+// ggml_conv_2d_dw_direct
+
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   stride0,
+        int                   stride1,
+        int                   pad0,
+        int                   pad1,
+        int                   dilation0,
+        int                   dilation1) {
+    GGML_ASSERT(a->ne[2] == 1);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+    ne[2] = b->ne[2];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    if (ggml_is_contiguous_channels(b)) {
+        // Result will be permuted the same way as input (CWHN order)
+        const int64_t type_size = ggml_type_size(result->type);
+        GGML_ASSERT(ggml_blck_size(result->type) == 1);
+        result->nb[0] = result->ne[2] * type_size;
+        result->nb[1] = result->ne[0] * result->nb[0];
+        result->nb[2] = type_size;
+    }
+
+    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CONV_2D_DW;
+    result->src[0] = a;
+    result->src[1] = b;
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
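A usage sketch for the new direct depthwise convolution, following the asserts above: the kernel `a` is [KW, KH, 1, C] and the input `b` is [W, H, C, N]. If `b` is stored channels-contiguous (CWHN, as tested by the new ggml_is_contiguous_channels), the result keeps that layout. All sizes here are hypothetical:

    // 3x3 depthwise kernel over C channels, stride 1, "same" padding
    const int64_t C = 32, W = 64, H = 64, N = 1;
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, C); // a->ne[2] == 1, a->ne[3] == C
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, N); // b->ne[2] == C
    struct ggml_tensor * out = ggml_conv_2d_dw_direct(ctx, a, b,
            /*stride0=*/1, /*stride1=*/1, /*pad0=*/1, /*pad1=*/1, /*dilation0=*/1, /*dilation1=*/1);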
@@ -4178,7 +4186,8 @@ static struct ggml_tensor * ggml_upscale_impl(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
@@ -4186,6 +4195,8 @@ static struct ggml_tensor * ggml_upscale_impl(
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4195,8 +4206,9 @@
 struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   scale_factor) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int                   scale_factor,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4205,8 +4217,9 @@ struct ggml_tensor * ggml_upscale_ext(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
 }
 
 // ggml_pad
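Callers of ggml_upscale and ggml_upscale_ext now pass an explicit interpolation mode, stored as the op's first i32 parameter. A sketch, assuming the ggml_scale_mode enum added to include/ggml.h alongside this change (GGML_SCALE_MODE_NEAREST keeps the old behaviour, GGML_SCALE_MODE_BILINEAR is the new option; `img`, `new_w`, `new_h` are hypothetical):

    // 2x integer upscale of a [W, H, C, N] tensor, nearest-neighbour as before
    struct ggml_tensor * up  = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST);
    // arbitrary target size via the _ext variant, e.g. a bilinear resize
    struct ggml_tensor * up2 = ggml_upscale_ext(ctx, img, new_w, new_h,
            (int) img->ne[2], (int) img->ne[3], GGML_SCALE_MODE_BILINEAR);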
@@ -4369,7 +4382,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     }
 
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
     float params[] = { scale, max_bias, logit_softcap };
 
@@ -4836,179 +4849,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun,
-        bool                         inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun,
-        bool                          inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5027,7 +4867,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5072,7 +4912,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5121,7 +4961,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5153,6 +4993,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type        type,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op     = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
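ggml_custom_4d generalizes the old fixed-arity map_custom ops: the result shape and type are given explicitly, and up to GGML_MAX_SRC inputs are passed as an array. A minimal sketch, assuming the new ggml_custom_op_t callback signature declared in include/ggml.h, void (*)(struct ggml_tensor * dst, int ith, int nth, void * userdata):

    // element-wise scale of one F32 input; rows are split across the nth threads
    static void scale_op(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
        const float s = *(const float *) userdata;
        const struct ggml_tensor * src = dst->src[0];
        for (int64_t r = ith; r < ggml_nrows(dst); r += nth) {
            const float * x = (const float *) ((const char *) src->data + r*src->nb[1]);
            float       * y = (float       *) ((char *)       dst->data + r*dst->nb[1]);
            for (int64_t i = 0; i < dst->ne[0]; ++i) {
                y[i] = s * x[i];
            }
        }
    }

    // graph construction: same shape as the input `a`, default task count
    static float scale = 2.0f;
    struct ggml_tensor * args[] = { a };
    struct ggml_tensor * out = ggml_custom_4d(ctx, GGML_TYPE_F32,
            a->ne[0], a->ne[1], a->ne[2], a->ne[3],
            args, 1, scale_op, GGML_N_TASKS_MAX, &scale);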
@@ -5599,7 +5499,7 @@ static void ggml_compute_backward(
             // tensor = src0 * 1 + src1 * 0
             if (src0_needs_grads) {
                 // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
             }
             if (src1_needs_grads) {
                 // dsrc1 = dtensor * 0 -> noop
@@ -5880,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 }
 
 void ggml_build_backward_expand(
-        struct ggml_context * ctx_static,
-        struct ggml_context * ctx_compute,
-        struct ggml_cgraph  * cgraph,
-        bool                  accumulate) {
+        struct ggml_context *  ctx,
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs) {
     GGML_ASSERT(cgraph->n_nodes > 0);
     GGML_ASSERT(cgraph->grads);
     GGML_ASSERT(cgraph->grad_accs);
@@ -5956,21 +5855,24 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
-        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
-        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
-            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
         }
-        grads_needed[igrad] = true;
+        grads_needed[ihash] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
     }
 
     free(grads_needed);
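The training entry point changes shape accordingly: instead of two contexts plus an accumulate flag, callers now pass a single context and an optional array of per-node gradient accumulators (indexed like the graph's node list, as the `grad_accs[i]` lookup above shows; NULL falls back to allocating accumulators only for loss tensors). A hedged call-site migration sketch:

    // before: ggml_build_backward_expand(ctx_static, ctx_compute, gb, /*accumulate=*/true);
    // after:  accumulators are supplied explicitly, or NULL for the default behaviour
    ggml_build_backward_expand(ctx, gb, /*grad_accs=*/NULL);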
@@ -6116,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     }
 }
 
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }
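The extra force_grads flag lets a caller duplicate a forward-only graph while still allocating gradient bookkeeping on the copy, which the backward pass needs. A sketch:

    // gf was built without grads; the duplicate gets grad/grad_acc arrays anyway
    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads=*/true);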
@@ -6136,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6445,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    GGML_UNUSED(ctx); // TODO: remove this parameter
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
 
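Call sites simply drop the context argument, and the new assert means a tensor must be marked as a parameter while its op is still GGML_OP_NONE, i.e. before it is used in any operation. A migration sketch (`n_in`, `n_out`, and the input `x` are hypothetical):

    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
    ggml_set_param(w);                                  // was: ggml_set_param(ctx, w);
    struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);   // using w after marking it is fine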