@fugood/llama.node 0.3.17 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -12,51 +12,30 @@ llama_add_compile_flags()
 
  # examples
 
- include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-
  if (EMSCRIPTEN)
  else()
- add_subdirectory(batched-bench)
  add_subdirectory(batched)
  add_subdirectory(embedding)
  add_subdirectory(eval-callback)
 
  add_subdirectory(gguf-hash)
- add_subdirectory(gguf-split)
  add_subdirectory(gguf)
  add_subdirectory(gritlm)
- add_subdirectory(imatrix)
- add_subdirectory(infill)
- add_subdirectory(llama-bench)
  add_subdirectory(lookahead)
  add_subdirectory(lookup)
- add_subdirectory(main)
  add_subdirectory(parallel)
  add_subdirectory(passkey)
- add_subdirectory(perplexity)
- add_subdirectory(quantize)
  add_subdirectory(retrieval)
- if (LLAMA_BUILD_SERVER)
- add_subdirectory(server)
- endif()
  add_subdirectory(save-load-state)
- add_subdirectory(run)
  add_subdirectory(simple)
  add_subdirectory(simple-chat)
  add_subdirectory(speculative)
  add_subdirectory(speculative-simple)
- add_subdirectory(tokenize)
- add_subdirectory(tts)
  add_subdirectory(gen-docs)
+ add_subdirectory(training)
  if (NOT GGML_BACKEND_DL)
- # these examples use the backends directly and cannot be built with dynamic loading
  add_subdirectory(convert-llama2c-to-ggml)
- add_subdirectory(cvector-generator)
- add_subdirectory(export-lora)
- add_subdirectory(llava)
- if (GGML_RPC)
- add_subdirectory(rpc)
- endif()
+ # these examples use the backends directly and cannot be built with dynamic loading
  if (GGML_SYCL)
  add_subdirectory(sycl)
  endif()
@@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 
  static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
  const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
- const struct llama_model * model = llama_get_model(ctx);
 
  // clear previous kv_cache values (irrelevant for embeddings)
  llama_kv_self_clear(ctx);
 
  // run model
  LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
- if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
- // encoder-only model
- if (llama_encode(ctx, batch) < 0) {
- LOG_ERR("%s : failed to encode\n", __func__);
- }
- } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
- // decoder-only model
- if (llama_decode(ctx, batch) < 0) {
- LOG_ERR("%s : failed to decode\n", __func__);
- }
+ if (llama_encode(ctx, batch) < 0) {
+ LOG_ERR("%s : failed to encode\n", __func__);
  }
 
  for (int i = 0; i < batch.n_tokens; i++) {
@@ -34,11 +34,61 @@ static std::string k_system =
  R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
  The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
 
- User: Recommend a nice restaurant in the area.
- Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
- User: Who is Richard Feynman?
- Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
- User:)";
+ User:
+ Recommend a nice restaurant in the area.
+ Assistant:
+ I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
+ User:
+ Who is Richard Feynman?
+ Assistant:
+ Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
+ )";
+
+ static std::vector<std::string> k_questions = {
+ "What is the tallest mountain in the world?",
+ "Who was the first person to win two Nobel Prizes?",
+ "Which country invented paper?",
+ "What organ is primarily responsible for pumping blood throughout the body?",
+ "Which planet is known for its prominent ring system?",
+ "Who directed the movie 'Inception'?",
+ "What is the freezing point of water in Fahrenheit?",
+ "Which animal is known to have the longest lifespan?",
+ "What language has the most native speakers worldwide?",
+ "What is the capital city of Canada?",
+ "Who is credited with inventing the World Wide Web?",
+ "Which metal is liquid at room temperature?",
+ "What is the term for an animal that eats both plants and meat?",
+ "Who painted 'The Starry Night'?",
+ "What gas do humans exhale that plants use for photosynthesis?",
+ "What year did World War II end?",
+ "Which continent has the most countries?",
+ "Who wrote the novel 'Frankenstein'?",
+ "What does DNA stand for?",
+ "What is the main ingredient in traditional Japanese miso soup?"
+ };
+
+ static std::vector<std::string> k_answers = {
+ "The tallest mountain in the world is Mount Everest.",
+ "Marie Curie was the first person to win two Nobel Prizes.",
+ "Paper was invented in China.",
+ "The heart is the organ responsible for pumping blood.",
+ "Saturn is known for its prominent ring system.",
+ "Christopher Nolan directed the movie 'Inception'.",
+ "The freezing point of water in Fahrenheit is 32°F.",
+ "The bowhead whale is known to have the longest lifespan among mammals.",
+ "Mandarin Chinese has the most native speakers in the world.",
+ "The capital city of Canada is Ottawa.",
+ "Tim Berners-Lee is credited with inventing the World Wide Web.",
+ "Mercury is the metal that is liquid at room temperature.",
+ "An animal that eats both plants and meat is called an omnivore.",
+ "'The Starry Night' was painted by Vincent van Gogh.",
+ "Humans exhale carbon dioxide, which plants use in photosynthesis.",
+ "World War II ended in 1945.",
+ "Africa is the continent with the most countries.",
+ "The novel 'Frankenstein' was written by Mary Shelley.",
+ "DNA stands for Deoxyribonucleic Acid.",
+ "The main ingredient in traditional Japanese miso soup is fermented soybean paste."
+ };
 
  static std::vector<std::string> k_prompts = {
  "What is the meaning of life?",
@@ -49,7 +99,7 @@ static std::vector<std::string> k_prompts = {
  "What is the best way to learn a new language?",
  "How to get a job at Google?",
  "If you could have any superpower, what would it be?",
- "I want to learn how to play the piano.",
+ "I want to learn how to play the piano. What would be the best way to do it?",
  };
 
  struct client {
@@ -68,6 +118,7 @@ struct client {
  int64_t t_start_prompt;
  int64_t t_start_gen;
 
+ int32_t n_past = 0;
  int32_t n_prompt = 0;
  int32_t n_decoded = 0;
  int32_t i_batch = -1;
@@ -107,6 +158,7 @@ int main(int argc, char ** argv) {
  common_params params;
 
  params.n_predict = 128;
+ params.n_junk = 0;
 
  if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
  return 1;
@@ -128,6 +180,12 @@ int main(int argc, char ** argv) {
 
  const bool dump_kv_cache = params.dump_kv_cache;
 
+ // is the system prompt shared in the cache
+ const bool is_sp_shared = params.is_pp_shared;
+
+ // extra text to insert in each client's prompt in order to make it larger
+ const int32_t n_junk = params.n_junk;
+
  // init llama.cpp
  llama_backend_init();
  llama_numa_init(params.numa);
@@ -169,6 +227,7 @@ int main(int argc, char ** argv) {
  }
 
  std::vector<llama_token> tokens_system;
+
  tokens_system = common_tokenize(ctx, k_system, true);
  const int32_t n_tokens_system = tokens_system.size();
 
@@ -190,7 +249,7 @@ int main(int argc, char ** argv) {
  LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
  LOG_INF("\n");
 
- {
+ if (is_sp_shared) {
  LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
  for (int32_t i = 0; i < n_tokens_system; ++i) {
@@ -228,7 +287,7 @@ int main(int argc, char ** argv) {
 
  client.i_batch = batch.n_tokens;
 
- common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
+ common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true);
 
  client.n_decoded += 1;
  }
@@ -254,9 +313,23 @@ int main(int argc, char ** argv) {
  client.t_start_gen = 0;
 
  client.input = k_prompts[rand() % k_prompts.size()];
- client.prompt = client.input + "\nAssistant:";
  client.response = "";
 
+ // construct the prompt:
+ // [system prompt] + [junk] + [user prompt]
+ client.n_past = 0;
+ client.prompt = "";
+ if (is_sp_shared) {
+ client.n_past = n_tokens_system;
+ } else {
+ client.prompt += k_system;
+ }
+ for (int i = 0; i < n_junk; ++i) {
+ const int r = rand() % k_questions.size();
+ client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n";
+ }
+ client.prompt += "User:\n" + client.input + "\nAssistant:\n";
+
  common_sampler_reset(client.smpl);
 
  // do not prepend BOS because we have a system prompt!
@@ -264,7 +337,7 @@ int main(int argc, char ** argv) {
  tokens_prompt = common_tokenize(ctx, client.prompt, false);
 
  for (size_t i = 0; i < tokens_prompt.size(); ++i) {
- common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
+ common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false);
  }
 
  // extract the logits only for the last token
@@ -363,10 +436,9 @@ int main(int argc, char ** argv) {
  // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
  if (client.n_decoded > 2 &&
- (llama_vocab_is_eog(vocab, id) ||
- (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
- client.response.find("User:") != std::string::npos ||
- client.response.find('\n') != std::string::npos)) {
+ (llama_vocab_is_eog(vocab, id) ||
+ (params.n_predict > 0 && client.n_decoded >= params.n_predict) ||
+ client.response.find("User:") != std::string::npos)) {
  // basic reverse prompt
  const size_t pos = client.response.find("User:");
  if (pos != std::string::npos) {
@@ -0,0 +1,5 @@
+ set(TARGET llama-finetune)
+ add_executable(${TARGET} finetune.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
@@ -0,0 +1,96 @@
+ #include "arg.h"
+ #include "common.h"
+ #include "log.h"
+ #include "llama.h"
+
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <ctime>
+ #include <vector>
+
+ #if defined(_MSC_VER)
+ #pragma warning(disable: 4244 4267) // possible loss of data
+ #endif
+
+ int main(int argc, char ** argv) {
+ common_params params;
+
+ params.escape = false;
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+ return 1;
+ }
+
+ if (params.use_mmap) {
+ LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+ params.use_mmap = false;
+ }
+ if (params.cache_type_k != GGML_TYPE_F32) {
+ LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+ params.cache_type_k = GGML_TYPE_F32;
+ }
+ if (params.cache_type_v != GGML_TYPE_F32) {
+ LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+ params.cache_type_v = GGML_TYPE_F32;
+ }
+
+ common_init();
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ // load the model and apply lora adapter, if any
+ common_init_result llama_init = common_init_from_params(params);
+ llama_model_ptr & model = llama_init.model;
+ llama_context_ptr & ctx = llama_init.context;
+
+ if (model == NULL) {
+ LOG_ERR("%s: unable to load model\n", __func__);
+ return 1;
+ }
+
+ // print system information
+ {
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+ }
+
+ constexpr float val_split = 0.05f;
+
+ std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
+ ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
+
+ struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
+ optimizer_params.adamw.alpha = 1e-7f; // learning rate
+
+ struct llama_opt_params lopt_params {
+ /*n_ctx_train =*/ 0,
+ /*param_filter =*/ llama_opt_param_filter_all,
+ /*param_filter_ud =*/ nullptr,
+ /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
+ /*get_opt_pars_ud =*/ &optimizer_params,
+ };
+ llama_opt_init(ctx.get(), model.get(), lopt_params);
+
+ const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+
+ ggml_opt_result_t result_train = ggml_opt_result_init();
+ ggml_opt_result_t result_eval = ggml_opt_result_init();
+
+ for (int epoch = 0; epoch < 2; ++epoch) {
+ llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
+ ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+ fprintf(stderr, "\n");
+
+ ggml_opt_result_reset(result_train);
+ ggml_opt_result_reset(result_eval);
+ }
+ ggml_opt_result_free(result_train);
+ ggml_opt_result_free(result_eval);
+
+ llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+
+ llama_backend_free();
+
+ return 0;
+ }
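The new finetune.cpp above (built as the llama-finetune target by the preceding CMake hunk) uses ggml_opt_get_constant_optimizer_params, so the AdamW settings stay fixed for the whole run. As a hedged sketch only: the same ggml_opt_get_optimizer_params callback mechanism could supply per-epoch values instead. The my_opt_pars name and the decay schedule below are hypothetical; the types and the adamw.alpha field come from this diff.

// Hypothetical callback: start from the library defaults and override only the
// AdamW learning rate, based on an epoch counter passed through userdata.
static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
    const int64_t epoch = *(const int64_t *) userdata;

    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(nullptr);
    p.adamw.alpha = 1e-7f / (float) (1 + epoch); // illustrative decay only
    return p;
}

Such a callback would be passed as get_opt_pars in llama_opt_params (with the epoch counter as get_opt_pars_ud) in place of ggml_opt_get_constant_optimizer_params; the ggml_opt_fit declaration later in this diff notes that its userdata is a pointer to the epoch of type int64_t.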
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
  option(GGML_SYCL "ggml: use SYCL" OFF)
  option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
  option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
+ option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
  set (GGML_SYCL_TARGET "INTEL" CACHE STRING
  "ggml: sycl target device")
  set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
@@ -360,3 +361,29 @@ write_basic_package_version_file(
  install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
  ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
  DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
+
+ if (MSVC)
+ set(MSVC_WARNING_FLAGS
+ /wd4005 # Macro redefinition
+ /wd4244 # Conversion from one type to another type, possible loss of data
+ /wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
+ /wd4996 # Disable POSIX deprecation warnings
+ /wd4702 # Unreachable code warnings
+ )
+ function(disable_msvc_warnings target_name)
+ if(TARGET ${target_name})
+ target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
+ endif()
+ endfunction()
+
+ disable_msvc_warnings(ggml-base)
+ disable_msvc_warnings(ggml)
+ disable_msvc_warnings(ggml-cpu)
+ disable_msvc_warnings(ggml-cpu-x64)
+ disable_msvc_warnings(ggml-cpu-sse42)
+ disable_msvc_warnings(ggml-cpu-sandybridge)
+ disable_msvc_warnings(ggml-cpu-haswell)
+ disable_msvc_warnings(ggml-cpu-skylakex)
+ disable_msvc_warnings(ggml-cpu-icelake)
+ disable_msvc_warnings(ggml-cpu-alderlake)
+ endif()
@@ -38,7 +38,7 @@ extern "C" {
  GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
  GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
  GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
- GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+ GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
  GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
  GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);
 
@@ -59,7 +59,7 @@ extern "C" {
  GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
  GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
  GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
- GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
  GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
  GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
  GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
@@ -248,7 +248,7 @@ extern "C" {
  // preferrably to run on the same backend as the buffer
  ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
 
- sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+ sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
 
  // initialize buffers from a max size graph (optional)
  reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
  typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
  // Initialize a backend scheduler, backends with low index are given priority over backends with high index
- GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
  GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
 
  // Initialize backend buffers from a measure graph
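Both ggml-backend.h hunks above reflect the same API change: ggml_backend_sched_new gains a trailing op_offload flag, and the usage example in the header comment now passes true for it. A minimal migration sketch; the backend_gpu and backend_cpu names are illustrative assumptions:

// Before: ggml_backend_sched_new(backends, NULL, n, GGML_DEFAULT_GRAPH_SIZE, false);
// After:  one extra boolean; the header's own example passes true.
ggml_backend_t backends[2] = { backend_gpu, backend_cpu };   // illustrative
ggml_backend_sched_t sched = ggml_backend_sched_new(
    backends, NULL, 2, GGML_DEFAULT_GRAPH_SIZE,
    /*parallel   =*/ false,
    /*op_offload =*/ true);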
@@ -24,7 +24,7 @@ typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
 
  struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
 
- typedef std::unique_ptr<ggml_gallocr_t, ggml_gallocr_deleter> ggml_gallocr_ptr;
+ typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;
 
  // ggml-backend
 
@@ -37,13 +37,16 @@ extern "C" {
  // ====== Dataset ======
 
  GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
- int64_t ne_datapoint, // number of elements per datapoint
- int64_t ne_label, // number of elements per label
- int64_t ndata, // total number of datapoints/labels
- int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+ enum ggml_type type_data, // the type for the internal data tensor
+ enum ggml_type type_label, // the type for the internal labels tensor
+ int64_t ne_datapoint, // number of elements per datapoint
+ int64_t ne_label, // number of elements per label
+ int64_t ndata, // total number of datapoints/labels
+ int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
  GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
  // get underlying tensors that store the data
+ GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
  GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
  GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
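Per the hunk above, ggml_opt_dataset_init now takes explicit ggml types for the internal data and label tensors, and the new ggml_opt_dataset_ndata accessor exposes the dataset size (the finetune example uses it to compute its train/validation split). A sketch of the updated call; the sizes are illustrative and not taken from this package:

// Old: ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
// New: the internal tensor types are passed first.
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
    GGML_TYPE_F32,            // type_data
    GGML_TYPE_F32,            // type_label
    /*ne_datapoint =*/ 784,   // illustrative
    /*ne_label     =*/ 10,    // illustrative
    /*ndata        =*/ 60000, // illustrative
    /*ndata_shard  =*/ 1);

const int64_t ndata = ggml_opt_dataset_ndata(dataset); // new accessor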
@@ -56,13 +59,19 @@ extern "C" {
  struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
  struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
  int64_t ibatch);
+ GGML_API void ggml_opt_dataset_get_batch_host(
+ ggml_opt_dataset_t dataset,
+ void * data_batch,
+ size_t nb_data_batch,
+ void * labels_batch,
+ int64_t ibatch);
 
  // ====== Model / Context ======
 
  enum ggml_opt_build_type {
- GGML_OPT_BUILD_TYPE_FORWARD,
- GGML_OPT_BUILD_TYPE_GRAD,
- GGML_OPT_BUILD_TYPE_OPT,
+ GGML_OPT_BUILD_TYPE_FORWARD = 10,
+ GGML_OPT_BUILD_TYPE_GRAD = 20,
+ GGML_OPT_BUILD_TYPE_OPT = 30,
  };
 
  // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
  // userdata can be used to pass arbitrary data
  typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
 
- // returns the default optimizer params (constant)
+ // returns the default optimizer params (constant, hard-coded values)
  // userdata is not used
  GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
 
+ // casts userdata to ggml_opt_optimizer_params and returns it
+ GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
  // parameters for initializing a new optimization context
  struct ggml_opt_params {
  ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
 
- struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
- // the forward graph is defined by inputs and outputs
- // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
- struct ggml_tensor * inputs;
- struct ggml_tensor * outputs;
+ // by default the forward graph needs to be reconstructed for each eval
+ // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+ struct ggml_context * ctx_compute;
+ struct ggml_tensor * inputs;
+ struct ggml_tensor * outputs;
 
  enum ggml_opt_loss_type loss_type;
  enum ggml_opt_build_type build_type;
@@ -107,12 +118,9 @@ extern "C" {
 
  // get parameters for an optimization context with defaults set where possible
  // parameters for which no sensible defaults exist are supplied as arguments to this function
- GGML_API ggml_opt_params ggml_opt_default_params(
- ggml_backend_sched_t backend_sched,
- struct ggml_context * ctx_compute,
- struct ggml_tensor * inputs,
- struct ggml_tensor * outputs,
- enum ggml_opt_loss_type loss_type);
+ GGML_API struct ggml_opt_params ggml_opt_default_params(
+ ggml_backend_sched_t backend_sched,
+ enum ggml_opt_loss_type loss_type);
 
  GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
  GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -121,6 +129,7 @@ extern "C" {
  GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
  // get underlying tensors that store data
+ // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
  GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
  GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
  GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
@@ -128,11 +137,12 @@ extern "C" {
  GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
  GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
 
+ // get the gradient accumulator for a node from the forward graph
  GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
  // ====== Optimization Result ======
 
- GGML_API ggml_opt_result_t ggml_opt_result_init();
+ GGML_API ggml_opt_result_t ggml_opt_result_init(void);
  GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
  GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
 
@@ -144,11 +154,20 @@ extern "C" {
 
  // ====== Computation ======
 
- // do forward pass, increment result if not NULL
- GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+ // if not using static graphs, this function must be called prior to ggml_opt_alloc
+ GGML_API void ggml_opt_prepare_alloc(
+ ggml_opt_context_t opt_ctx,
+ struct ggml_context * ctx_compute,
+ struct ggml_cgraph * gf,
+ struct ggml_tensor * inputs,
+ struct ggml_tensor * outputs);
+
+ // allocate the next graph for evaluation, either forward or forward + backward
+ // must be called exactly once prior to calling ggml_opt_eval
+ GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
 
- // do forward pass, increment result if not NULL, do backward pass
- GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+ // do forward pass, increment result if not NULL, do backward pass if allocated
+ GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
  // ############################################################################
  // ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +219,9 @@ extern "C" {
  // fit model defined by inputs and outputs to dataset
  GGML_API void ggml_opt_fit(
  ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
- ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
- ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
- ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+ struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
+ struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
+ struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
  ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
  enum ggml_opt_loss_type loss_type, // loss to minimize
  ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
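Taken together, the ggml-opt.h hunks replace the ggml_opt_forward / ggml_opt_forward_backward pair with a prepare/alloc/eval sequence, and ggml_opt_default_params now takes only the scheduler and the loss type (ctx_compute, inputs and outputs became optional fields that switch the context to static graphs). A hedged sketch of one optimization step with dynamically built graphs, assuming the caller rebuilds ctx_compute, gf, inputs and outputs each iteration as the new header comments describe:

// One step with the reworked API (variable names are illustrative).
ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs); // only needed without static graphs
ggml_opt_alloc(opt_ctx, /*backward =*/ true);                      // forward + backward; exactly once per eval
// ... fill ggml_opt_inputs(opt_ctx) / ggml_opt_labels(opt_ctx) with the current batch ...
ggml_opt_eval(opt_ctx, result);                                    // accumulates into result if non-NULL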
@@ -673,11 +673,15 @@ extern "C" {
  GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
  GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
 
+ // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
  GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor);
  GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
  GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
  GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
 
+ // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
+ GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);
+
  // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
  GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
 
@@ -764,7 +768,7 @@ extern "C" {
  // Tensor flags
  GGML_API void ggml_set_input(struct ggml_tensor * tensor);
  GGML_API void ggml_set_output(struct ggml_tensor * tensor);
- GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+ GGML_API void ggml_set_param(struct ggml_tensor * tensor);
  GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
  //
@@ -934,7 +938,7 @@ extern "C" {
  GGML_API struct ggml_tensor * ggml_repeat_back(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
- struct ggml_tensor * b);
+ struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
  // concat a and b along dim
  // used in stable-diffusion
 
@@ -2045,15 +2049,14 @@ extern "C" {
 
  GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
  GGML_API void ggml_build_backward_expand(
- struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
- struct ggml_context * ctx_compute, // context for gradient computation
- struct ggml_cgraph * cgraph,
- bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+ struct ggml_context * ctx, // context for gradient computation
+ struct ggml_cgraph * cgraph,
+ struct ggml_tensor ** grad_accs);
 
  // graph allocation in a context
  GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
  GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
- GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+ GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
  GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
  GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
  GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
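The ggml.h hunks change three graph-related signatures. A quick, non-authoritative migration sketch; the exact semantics of force_grads and grad_accs should be checked against the ggml sources, and the variable names here are illustrative:

// ggml_set_param no longer takes a context:
ggml_set_param(weight_tensor);                  // was: ggml_set_param(ctx, weight_tensor)

// ggml_graph_dup gains a force_grads flag (presumably forcing gradient storage in the copy):
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads =*/ true);

// ggml_build_backward_expand takes a single context plus explicit gradient-accumulator
// tensors, replacing the old ctx_static/accumulate pair:
ggml_build_backward_expand(ctx, gb, grad_accs); // grad_accs: struct ggml_tensor ** supplied by the caller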