@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -25,7 +25,6 @@
  #include <cstdint>
  #include <cstring>
  #include <cinttypes>
- #include <functional>
  #include <memory>
  #include <random>
  #include <stdio.h>
@@ -133,7 +132,7 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
  std::vector<uint8_t> buf(ggml_nbytes(t));
  ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));

- ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+ const auto * tt = ggml_get_type_traits(t->type);
  size_t bs = ggml_blck_size(t->type);
  std::vector<float> vq(ggml_blck_size(t->type));
  bool quantized = ggml_is_quantized(t->type);
@@ -159,7 +158,7 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
  } else if (t->type == GGML_TYPE_I8) {
  tv.push_back((float)*(int8_t *) &buf[i]);
  } else if (quantized) {
- tt.to_float(&buf[i], vq.data(), bs);
+ tt->to_float(&buf[i], vq.data(), bs);
  tv.insert(tv.end(), vq.begin(), vq.end());
  } else {
  GGML_ABORT("fatal error");
@@ -638,19 +637,20 @@ struct test_case {

  // determine number of runs
  int n_runs;
+ bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
  if (op_flops(out) > 0) {
  // based on flops
  const uint64_t GFLOP = 1000 * 1000 * 1000;
  const uint64_t target_flops_cpu = 8ULL * GFLOP;
  const uint64_t target_flops_gpu = 100ULL * GFLOP;
- uint64_t target_flops = ggml_backend_is_cpu(backend) ? target_flops_cpu : target_flops_gpu;
+ uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
  n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
  } else {
  // based on memory size
  const size_t GB = 1ULL << 30;
  const size_t target_size_cpu = 8 * GB;
  const size_t target_size_gpu = 32 * GB;
- size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu;
+ size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
  n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
  }

@@ -680,6 +680,7 @@ struct test_case {

  // run
  int64_t total_time_us = 0;
+ int64_t total_mem = 0;
  int total_runs = 0;
  do {
  int64_t start_time = ggml_time_us();
@@ -687,6 +688,7 @@ struct test_case {
  int64_t end_time = ggml_time_us();

  total_time_us += end_time - start_time;
+ total_mem += mem;
  total_runs += n_runs;
  } while (total_time_us < 1000*1000); // run for at least 1 second

@@ -716,7 +718,7 @@ struct test_case {
  } else {
  printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
  op_size(out) / 1024,
- mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+ total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
  }
  printf("\n");

@@ -808,15 +810,14 @@ struct test_case {

  ggml_build_forward_expand(gf, out);
  ggml_graph_cpy(gf, gb);
- ggml_build_backward_expand(ctx, gf, gb, false);
+ ggml_build_backward_expand(ctx, ctx, gb, false);
  if (expect.size() != 1 || expect[0] != 0.0f) {
  GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
  for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE);
+ GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
  }
  }

- // TODO: refactor so that this check is only needed once
  for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
  if (!ggml_backend_supports_op(backend, t)) {
  printf("not supported [%s] ", ggml_backend_name(backend));
@@ -859,7 +860,13 @@ struct test_case {
  const char * bn = ggml_backend_name(backend);
  const int64_t ne = ggml_nelements(t);

- std::vector<float> ga = tensor_to_float(t->grad);
+ std::vector<float> ga;
+ struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);
+ if (grad) {
+ ga = tensor_to_float(grad);
+ } else {
+ ga.resize(ne); // default value is 0.0f
+ }

  for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
  // check for nans
@@ -1146,6 +1153,26 @@ struct test_argmax : public test_case {
  return out;
  }

+ void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_F32) {
+ // initialize with unique values to avoid ties
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<float> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+
  double max_nmse_err() override {
  return 0.0;
  }
@@ -1613,8 +1640,8 @@ struct test_ssm_scan : public test_case {
  }
  };

- // GGML_OP_RWKV_WKV
- struct test_rwkv_wkv : public test_case {
+ // GGML_OP_RWKV_WKV6
+ struct test_rwkv_wkv6 : public test_case {
  const ggml_type type;

  const int64_t head_count;
@@ -1626,7 +1653,7 @@ struct test_rwkv_wkv : public test_case {
  return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
  }

- test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
+ test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
  int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
  : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

@@ -1638,7 +1665,7 @@ struct test_rwkv_wkv : public test_case {
  ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
  ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
  ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
- ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
+ ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
  return out;
  }
  };
@@ -1650,11 +1677,12 @@ struct test_mul_mat : public test_case {
  const int64_t m;
  const int64_t n;
  const int64_t k;
- const std::array<int64_t, 2> bs; // dims 3 and 4
- const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+ const std::array<int64_t, 2> bs; // dims 3 and 4
+ const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+ const std::array<int64_t, 4> per; // permutation of dimensions

  std::string vars() override {
- return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
+ return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
  }

  double max_nmse_err() override {
@@ -1669,17 +1697,44 @@ struct test_mul_mat : public test_case {
  test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
  int64_t m = 32, int64_t n = 32, int64_t k = 32,
  std::array<int64_t, 2> bs = {10, 10},
- std::array<int64_t, 2> nr = {2, 2})
- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
+ std::array<int64_t, 2> nr = {2, 2},
+ std::array<int64_t, 4> per = {0, 1, 2, 3})
+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  // C^T = A * B^T: (k, m) * (k, n) => (m, n)
- ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
- ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
- ggml_set_param(ctx, a);
- ggml_set_param(ctx, b);
- ggml_set_name(a, "a");
- ggml_set_name(b, "b");
+ ggml_tensor * a;
+ ggml_tensor * b;
+
+ const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
+ if (npermuted > 0) {
+ GGML_ASSERT(npermuted == 2);
+ GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
+ GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
+
+ // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
+ const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
+ const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
+
+ a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
+ b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
+ ggml_set_param(ctx, a);
+ ggml_set_param(ctx, b);
+ ggml_set_name(a, "a");
+ ggml_set_name(b, "b");
+
+ a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
+ b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
+ ggml_set_name(a, "a_permuted");
+ ggml_set_name(b, "b_permuted");
+ } else {
+ a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+ b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+ ggml_set_param(ctx, a);
+ ggml_set_param(ctx, b);
+ ggml_set_name(a, "a");
+ ggml_set_name(b, "b");
+ }

  ggml_tensor * out = ggml_mul_mat(ctx, a, b);
  ggml_set_name(out, "out");
@@ -2146,7 +2201,15 @@ struct test_rope : public test_case {
  ggml_set_name(a, "a");
  }

- ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ ggml_tensor * pos;
+ if (is_mrope || is_vision) {
+ pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+ } else {
+ pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+ }
  ggml_set_name(pos, "pos");

  ggml_tensor * freq = nullptr;
@@ -2155,7 +2218,20 @@ struct test_rope : public test_case {
  ggml_set_name(freq, "freq");
  }

- ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ ggml_tensor * out;
+ if (is_mrope) {
+ if (is_vision) {
+ GGML_ASSERT(n_dims/4 > 0);
+ int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+ out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ } else {
+ GGML_ASSERT(n_dims/3 > 0);
+ int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+ out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
+ } else {
+ out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
  ggml_set_name(out, "out");

  return out;
@@ -2165,11 +2241,12 @@ struct test_rope : public test_case {
  for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
  if (t->type == GGML_TYPE_I32) {
  // pos
- std::vector<int> data(ne_a[2]);
- for (int i = 0; i < ne_a[2]; i++) {
+ const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+ std::vector<int> data(num_pos_ids);
+ for (int i = 0; i < num_pos_ids; i++) {
  data[i] = rand() % n_ctx;
  }
- ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
+ ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
  } else {
  if (t->ne[0] == n_dims/2) {
  // frequency factors in the range [0.9f, 1.1f]
@@ -2469,6 +2546,35 @@ struct test_sum_rows : public test_case {
  }
  };

+ // GGML_OP_MEAN
+ struct test_mean : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_mean(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 5, 4, 3})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(ctx, a);
+ ggml_set_name(a, "a");
+
+ ggml_tensor * out = ggml_mean(ctx, a);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ float grad_eps() override {
+ return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+ }
+ };
+
  // GGML_OP_UPSCALE
  struct test_upscale : public test_case {
  const ggml_type type;
@@ -2613,6 +2719,33 @@ struct test_pad : public test_case {
  }
  };

+ // GGML_OP_PAD_REFLECT_1D
+ struct test_pad_reflect_1d : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne_a;
+ const int pad_0;
+ const int pad_1;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+ }
+
+ test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
+ int pad_0 = 10, int pad_1 = 9)
+ : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
+ ggml_set_name(a, "a");
+
+ ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+ };
+
  // GGML_OP_ARANGE
  struct test_arange : public test_case {
  const ggml_type type;
@@ -2711,6 +2844,13 @@ struct test_flash_attn_ext : public test_case {
  return 5e-4;
  }

+ uint64_t op_flops(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ // Just counting matmul costs:
+ // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
+ return 2 * 2 * nh * nb * hs * kv;
+ }
+
  test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
  bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
  : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
@@ -2796,24 +2936,14 @@ struct test_cross_entropy_loss : public test_case {
  struct test_opt_step_adamw : public test_case {
  const ggml_type type;
  const std::array<int64_t, 4> ne;
- const float alpha;
- const float beta1;
- const float beta2;
- const float eps;
- const float wd;

  std::string vars() override {
- return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd);
+ return VARS_TO_STR2(type, ne);
  }

  test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {10, 5, 4, 3},
- float alpha = 1e-3f,
- float beta1 = 0.9f,
- float beta2 = 0.999f,
- float eps = 1e-8f,
- float wd = 0.0f)
- : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {}
+ std::array<int64_t, 4> ne = {10, 5, 4, 3})
+ : type(type), ne(ne) {}

  ggml_tensor * build_graph(ggml_context * ctx) override {
  ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
@@ -2823,7 +2953,16 @@ struct test_opt_step_adamw : public test_case {
  ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
  ggml_set_name(grad, "grad");

- ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
+ ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+ ggml_set_name(grad_m, "grad_m");
+
+ ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+ ggml_set_name(grad_v, "grad_v");
+
+ ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);
+ ggml_set_name(adamw_params, "adamw_params");
+
+ ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);
  ggml_set_name(out, "out");

  return out;
@@ -2831,7 +2970,7 @@ struct test_opt_step_adamw : public test_case {

  void initialize_tensors(ggml_context * ctx) override {
  for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values.
+ init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.
  }
  }

@@ -3244,7 +3383,9 @@ static const ggml_type all_types[] = {

  static const ggml_type base_types[] = {
  GGML_TYPE_F32, GGML_TYPE_F16,
+ GGML_TYPE_Q8_0, // for I8MM tests
  GGML_TYPE_Q4_0,
+ GGML_TYPE_Q4_1, // for I8MM tests
  GGML_TYPE_Q4_K,
  GGML_TYPE_IQ2_XXS
  };
@@ -3308,13 +3449,49 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }
  }

- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
- // test cases for 1D im2col
+ // im2col 1D
  test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
  test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
  test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+ for (int s0 : {1, 3}) {
+ for (int p0 : {0, 3}) {
+ for (int d0 : {1, 3}) {
+ test_cases.emplace_back(new test_im2col(
+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
+ s0, 0, p0, 0, d0, 0, false));
+ }
+ }
+ }
+
+ // im2col 2D
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+ for (int s0 : {1, 3}) {
+ for (int s1 : {1, 3}) {
+ for (int p0 : {0, 3}) {
+ for (int p1 : {0, 3}) {
+ for (int d0 : {1, 3}) {
+ for (int d1 : {1, 3}) {
+ test_cases.emplace_back(new test_im2col(
+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
+ s0, s1, p0, p1, d0, d1, true));
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // extra tests for im2col 2D
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));

  // sycl backend will limit task global_range < MAX_INT
  // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
@@ -3332,9 +3509,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
  test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

- test_cases.emplace_back(new test_argmax());
  test_cases.emplace_back(new test_count_equal());

+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
+
  for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
  test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
  test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
@@ -3360,10 +3543,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
  }

+ for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+ test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+ }
+
  for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
  for (ggml_type type_dst : all_types) {
- test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
- test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
  }
  }
  for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -3434,21 +3621,35 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

  test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));

- test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
- test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
- test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
- test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
+ test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
+ test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
+ test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
+ test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
+
+ for (int i = 1; i < 9; ++i) {
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+ }

  #if 1
  for (ggml_type type_a : base_types) {
  for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
+ // test cases without permutation
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));

  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
@@ -3457,6 +3658,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
  test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+
+ // test cases with permutation
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
  }
  }
  for (ggml_type type_a : other_types) {
@@ -3520,7 +3734,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  for (int n_mats : {4}) {
  for (int n_used : {2}) {
  for (bool b : {false}) {
- for (int n : {1}) {
+ for (int n : {1, 32}) {
  int m = 512;
  int k = 256;
  test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
@@ -3621,6 +3835,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
  }

+ if (all) {
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
+ test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
+ }
+
  test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
  }
  }
@@ -3647,12 +3867,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

  test_cases.emplace_back(new test_sum());
  test_cases.emplace_back(new test_sum_rows());
+ test_cases.emplace_back(new test_mean());
  test_cases.emplace_back(new test_upscale());
  test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
  test_cases.emplace_back(new test_upscale_ext());
- test_cases.emplace_back(new test_group_norm());
+ test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+ test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
  test_cases.emplace_back(new test_acc());
  test_cases.emplace_back(new test_pad());
+ test_cases.emplace_back(new test_pad_reflect_1d());
  test_cases.emplace_back(new test_arange());
  test_cases.emplace_back(new test_timestep_embedding());
  test_cases.emplace_back(new test_leaky_relu());
@@ -3666,7 +3889,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  for (int nh : { 32, }) {
  for (int kv : { 512, 1024, }) {
  for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
  test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
  }
  }
@@ -3678,9 +3901,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
  }

  test_cases.emplace_back(new test_cross_entropy_loss());
- for (float wd : {0.0f, 1e-2f}) {
- test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd));
- }
+ test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));

  // these tests are disabled to save execution time, but they can be handy for debugging
  #if 0
@@ -3700,6 +3921,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
  test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
  test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));

+ test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+ test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
+
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+ test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
+
  for (int bs : {1, 512}) {
  for (ggml_type type_a : all_types) {
  for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -3714,7 +3951,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
  static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
  if (mode == MODE_TEST) {
  auto test_cases = make_test_cases_eval();
- ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+ ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
+ if (backend_cpu == NULL) {
+ printf(" Failed to initialize CPU backend\n");
+ return false;
+ }

  size_t n_ok = 0;
  for (auto & test : test_cases) {
@@ -3794,7 +4035,9 @@ int main(int argc, char ** argv) {
  }
  }

- // enumerate backends
+ // load and enumerate backends
+ ggml_backend_load_all();
+
  printf("Testing %zu devices\n\n", ggml_backend_dev_count());

  size_t n_ok = 0;
@@ -3810,19 +4053,20 @@ int main(int argc, char ** argv) {
  continue;
  }

- ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
- GGML_ASSERT(backend != NULL);
-
- if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) {
+ if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
  printf(" Skipping CPU backend\n");
- ggml_backend_free(backend);
  n_ok++;
  continue;
  }

- if (ggml_backend_is_cpu(backend)) {
+ ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
+ GGML_ASSERT(backend != NULL);
+
+ ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+ if (ggml_backend_set_n_threads_fn) {
  // TODO: better value for n_threads
- ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+ ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
  }

  printf(" Device description: %s\n", ggml_backend_dev_description(dev));
@@ -3846,6 +4090,8 @@ int main(int argc, char ** argv) {
  ggml_backend_free(backend);
  }

+ ggml_quantize_free();
+
  printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());

  if (n_ok != ggml_backend_dev_count()) {
@@ -3853,8 +4099,6 @@ int main(int argc, char ** argv) {
  return 1;
  }

- ggml_quantize_free();
-
  printf("\033[1;32mOK\033[0m\n");
  return 0;
  }
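The final hunks above move test-backend-ops.cpp off the CPU-specific calls (ggml_backend_is_cpu, ggml_backend_cpu_set_n_threads) and onto the backend registry: modules are loaded with ggml_backend_load_all(), devices are enumerated and type-checked, and the thread-count setter is resolved with ggml_backend_reg_get_proc_address. The sketch below is illustrative only, not part of this package diff, and assumes the ggml-backend.h header bundled under package/src/llama.cpp in this release.

// Hypothetical usage sketch of the backend-registry pattern used in the hunks above.
#include "ggml-backend.h"

#include <cstdio>
#include <thread>

int main() {
    // load all available backend modules, then walk the registered devices
    ggml_backend_load_all();

    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev     = ggml_backend_dev_get(i);
        ggml_backend_t     backend = ggml_backend_dev_init(dev, NULL);
        if (backend == NULL) {
            continue;
        }

        // the thread count is no longer set via a CPU-specific call; the setter is
        // looked up through the device's backend registry, as in the diff above
        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
        auto set_n_threads_fn = (ggml_backend_set_n_threads_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_n_threads_fn) {
            set_n_threads_fn(backend, std::thread::hardware_concurrency());
        }

        printf("device: %s - %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
        ggml_backend_free(backend);
    }
    return 0;
}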