cui-llama.rn 1.6.1 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
  3. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +41 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
@@ -41,6 +41,16 @@ struct completion_token_output
41
41
  llama_token tok;
42
42
  };
43
43
 
44
+ struct llama_rn_context_mtmd;
45
+
46
+ struct llama_rn_tokenize_result {
47
+ std::vector<llama_token> tokens;
48
+ bool has_media = false;
49
+ std::vector<std::string> bitmap_hashes;
50
+ std::vector<size_t> chunk_pos; // both text and media
51
+ std::vector<size_t> chunk_pos_media; // media only
52
+ };
53
+
44
54
  // Main context class
45
55
  struct llama_rn_context {
46
56
  bool is_predicting = false;
@@ -51,8 +61,9 @@ struct llama_rn_context {
51
61
 
52
62
  size_t num_prompt_tokens = 0;
53
63
  size_t num_tokens_predicted = 0;
54
- size_t n_past = 0;
64
+ llama_pos n_past = 0;
55
65
  size_t n_remain = 0;
66
+ std::vector<std::string> mtmd_bitmap_past_hashes;
56
67
 
57
68
  std::vector<llama_token> embd;
58
69
  common_params params;
@@ -78,6 +89,9 @@ struct llama_rn_context {
78
89
 
79
90
  std::vector<common_adapter_lora_info> lora;
80
91
 
92
+ llama_rn_context_mtmd *mtmd_wrapper = nullptr;
93
+ bool has_multimodal = false;
94
+
81
95
  ~llama_rn_context();
82
96
 
83
97
  void rewind();
@@ -97,8 +111,9 @@ struct llama_rn_context {
97
111
  const std::string &chat_template
98
112
  ) const;
99
113
  void truncatePrompt(std::vector<llama_token> &prompt_tokens);
100
- void loadPrompt();
114
+ void loadPrompt(const std::vector<std::string> &media_paths);
101
115
  void beginCompletion();
116
+ void endCompletion();
102
117
  completion_token_output nextToken();
103
118
  size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type);
104
119
  completion_token_output doCompletion();
@@ -107,7 +122,22 @@ struct llama_rn_context {
107
122
  int applyLoraAdapters(std::vector<common_adapter_lora_info> lora);
108
123
  void removeLoraAdapters();
109
124
  std::vector<common_adapter_lora_info> getLoadedLoraAdapters();
110
- };\
125
+
126
+ // Multimodal methods
127
+ bool initMultimodal(const std::string &mmproj_path, bool use_gpu);
128
+ bool isMultimodalEnabled() const;
129
+ bool isMultimodalSupportVision() const;
130
+ bool isMultimodalSupportAudio() const;
131
+ void releaseMultimodal();
132
+
133
+ // Process multiple media and add them to the context
134
+ void processMedia(
135
+ const std::string &prompt,
136
+ const std::vector<std::string> &media_paths
137
+ );
138
+
139
+ llama_rn_tokenize_result tokenize(const std::string &text, const std::vector<std::string> &media_paths);
140
+ };
111
141
 
112
142
  // Logging macros
113
143
  extern bool rnllama_verbose;
@@ -6,7 +6,7 @@
6
6
  <dict>
7
7
  <key>Info.plist</key>
8
8
  <data>
9
- vcGrMKyzdqHg1AItKt7zUs4VeGs=
9
+ 8sGY2PdZU05Cx7NGpUmOkf+S4H4=
10
10
  </data>
11
11
  </dict>
12
12
  <key>files2</key>
@@ -3,6 +3,7 @@
3
3
  #pragma once
4
4
 
5
5
  #include "common.h"
6
+ #include <chrono>
6
7
  #include <string>
7
8
  #include <vector>
8
9
  #include "minja/chat-template.hpp"
@@ -79,6 +80,7 @@ struct common_chat_templates_inputs {
79
80
  common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
80
81
  bool parallel_tool_calls = false;
81
82
  bool extract_reasoning = true;
83
+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
82
84
  };
83
85
 
84
86
  struct common_chat_params {
@@ -6,6 +6,7 @@
6
6
 
7
7
  #include <set>
8
8
  #include <string>
9
+ #include <string_view>
9
10
  #include <vector>
10
11
  #include <sstream>
11
12
 
@@ -77,7 +78,6 @@ enum llama_example {
77
78
  LLAMA_EXAMPLE_COMMON,
78
79
  LLAMA_EXAMPLE_SPECULATIVE,
79
80
  LLAMA_EXAMPLE_MAIN,
80
- LLAMA_EXAMPLE_INFILL,
81
81
  LLAMA_EXAMPLE_EMBEDDING,
82
82
  LLAMA_EXAMPLE_PERPLEXITY,
83
83
  LLAMA_EXAMPLE_RETRIEVAL,
@@ -87,7 +87,7 @@ enum llama_example {
87
87
  LLAMA_EXAMPLE_SERVER,
88
88
  LLAMA_EXAMPLE_CVECTOR_GENERATOR,
89
89
  LLAMA_EXAMPLE_EXPORT_LORA,
90
- LLAMA_EXAMPLE_LLAVA,
90
+ LLAMA_EXAMPLE_MTMD,
91
91
  LLAMA_EXAMPLE_LOOKUP,
92
92
  LLAMA_EXAMPLE_PARALLEL,
93
93
  LLAMA_EXAMPLE_TTS,
@@ -107,6 +107,7 @@ enum common_sampler_type {
107
107
  COMMON_SAMPLER_TYPE_XTC = 8,
108
108
  COMMON_SAMPLER_TYPE_INFILL = 9,
109
109
  COMMON_SAMPLER_TYPE_PENALTIES = 10,
110
+ COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
110
111
  };
111
112
 
112
113
  // dimensionality reduction methods, used by cvector-generator
@@ -172,6 +173,7 @@ struct common_params_sampling {
172
173
  std::vector<enum common_sampler_type> samplers = {
173
174
  COMMON_SAMPLER_TYPE_PENALTIES,
174
175
  COMMON_SAMPLER_TYPE_DRY,
176
+ COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
175
177
  COMMON_SAMPLER_TYPE_TOP_K,
176
178
  COMMON_SAMPLER_TYPE_TYPICAL_P,
177
179
  COMMON_SAMPLER_TYPE_TOP_P,
@@ -336,17 +338,17 @@ struct common_params {
336
338
  bool flash_attn = false; // flash attention
337
339
  bool no_perf = false; // disable performance metrics
338
340
  bool ctx_shift = true; // context shift on inifinite text generation
341
+ bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
339
342
 
340
343
  bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
341
- bool logits_all = false; // return logits for all tokens in the batch
342
344
  bool use_mmap = true; // use mmap for faster loads
343
345
  bool use_mlock = false; // use mlock to keep model in memory
344
346
  bool verbose_prompt = false; // print prompt tokens before generation
345
347
  bool display_prompt = true; // print prompt before generation
346
- bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
347
348
  bool no_kv_offload = false; // disable KV offloading
348
349
  bool warmup = true; // warmup run
349
350
  bool check_tensors = false; // validate tensor data
351
+ bool no_op_offload = false; // globally disable offload host tensor operations to device
350
352
 
351
353
  bool single_turn = false; // single turn chat conversation
352
354
 
@@ -355,7 +357,7 @@ struct common_params {
355
357
 
356
358
  common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
357
359
 
358
- // multimodal models (see tools/llava)
360
+ // multimodal models (see tools/mtmd)
359
361
  struct common_params_model mmproj;
360
362
  bool mmproj_use_gpu = true; // use GPU for multimodal model
361
363
  bool no_mmproj = false; // explicitly disable multimodal model
@@ -381,6 +383,7 @@ struct common_params {
381
383
  bool use_jinja = false; // NOLINT
382
384
  bool enable_chat_template = true;
383
385
  common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
386
+ bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
384
387
 
385
388
  std::vector<std::string> api_keys;
386
389
 
@@ -424,6 +427,7 @@ struct common_params {
424
427
 
425
428
  bool process_output = false; // collect data for the output tensor
426
429
  bool compute_ppl = true; // whether to compute perplexity
430
+ bool parse_special = false; // whether to parse special tokens during imatrix tokenization
427
431
 
428
432
  // cvector-generator params
429
433
  int n_pca_batch = 100;
@@ -439,6 +443,11 @@ struct common_params {
439
443
 
440
444
  // common params
441
445
  std::string out_file; // output filename for all example programs
446
+ // optional callback for model loading progress and cancellation:
447
+ // called with a progress value between 0.0 and 1.0.
448
+ // return false from callback to abort model loading or true to continue
449
+ llama_progress_callback load_progress_callback = NULL;
450
+ void * load_progress_callback_user_data = NULL;
442
451
  };
443
452
 
444
453
  // call once at the start of a program if it uses libcommon
@@ -516,10 +525,9 @@ static bool string_starts_with(const std::string & str,
516
525
  return str.rfind(prefix, 0) == 0;
517
526
  }
518
527
 
519
- static bool string_ends_with(const std::string & str,
520
- const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
521
- return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
522
- }
528
+ // While we wait for C++20's std::string::ends_with...
529
+ bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
530
+ size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
523
531
 
524
532
  bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
525
533
  void string_process_escapes(std::string & input);
@@ -628,16 +636,6 @@ std::string common_detokenize(
628
636
  const std::vector<llama_token> & tokens,
629
637
  bool special = true);
630
638
 
631
- //
632
- // KV cache utils
633
- //
634
-
635
- // Dump the KV cache view with the number of sequences per cell.
636
- void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
637
-
638
- // Dump the KV cache view showing individual sequences in each cell (long output).
639
- void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
640
-
641
639
  //
642
640
  // Embedding utils
643
641
  //
@@ -679,3 +677,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
679
677
  const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
680
678
 
681
679
  }
680
+
681
+ //
682
+ // training utils
683
+ //
684
+
685
+ lm_ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
@@ -38,7 +38,7 @@ extern "C" {
38
38
  LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer (lm_ggml_backend_buffer_type_t buft, size_t size);
39
39
  LM_GGML_API size_t lm_ggml_backend_buft_get_alignment (lm_ggml_backend_buffer_type_t buft);
40
40
  LM_GGML_API size_t lm_ggml_backend_buft_get_max_size (lm_ggml_backend_buffer_type_t buft);
41
- LM_GGML_API size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor);
41
+ LM_GGML_API size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor);
42
42
  LM_GGML_API bool lm_ggml_backend_buft_is_host (lm_ggml_backend_buffer_type_t buft);
43
43
  LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_buft_get_device (lm_ggml_backend_buffer_type_t buft);
44
44
 
@@ -59,7 +59,7 @@ extern "C" {
59
59
  LM_GGML_API enum lm_ggml_status lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor);
60
60
  LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer);
61
61
  LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer);
62
- LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor);
62
+ LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor);
63
63
  LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value);
64
64
  LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer);
65
65
  LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage);
@@ -248,7 +248,7 @@ extern "C" {
248
248
  // preferrably to run on the same backend as the buffer
249
249
  lm_ggml_backend_buffer_set_usage(buf_weights, LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
250
250
 
251
- sched = lm_ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, LM_GGML_DEFAULT_GRAPH_SIZE, false);
251
+ sched = lm_ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, LM_GGML_DEFAULT_GRAPH_SIZE, false, true);
252
252
 
253
253
  // initialize buffers from a max size graph (optional)
254
254
  reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
289
289
  typedef bool (*lm_ggml_backend_sched_eval_callback)(struct lm_ggml_tensor * t, bool ask, void * user_data);
290
290
 
291
291
  // Initialize a backend scheduler, backends with low index are given priority over backends with high index
292
- LM_GGML_API lm_ggml_backend_sched_t lm_ggml_backend_sched_new(lm_ggml_backend_t * backends, lm_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
292
+ LM_GGML_API lm_ggml_backend_sched_t lm_ggml_backend_sched_new(lm_ggml_backend_t * backends, lm_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
293
293
  LM_GGML_API void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched);
294
294
 
295
295
  // Initialize backend buffers from a measure graph
@@ -207,6 +207,10 @@ typedef struct {
207
207
  float attn_factor;
208
208
  float beta_fast;
209
209
  float beta_slow;
210
+ int32_t sect_0;
211
+ int32_t sect_1;
212
+ int32_t sect_2;
213
+ int32_t sect_3;
210
214
  } lm_ggml_metal_kargs_rope;
211
215
 
212
216
  typedef struct {
@@ -299,21 +303,42 @@ typedef struct {
299
303
  } lm_ggml_metal_kargs_mul_mv_ext;
300
304
 
301
305
  typedef struct {
302
- int32_t nei0;
303
- int32_t nei1;
304
- uint64_t nbi1;
306
+ int32_t ne10;
307
+ int32_t ne11; // n_expert_used (bcast)
308
+ uint64_t nb11;
309
+ uint64_t nb12;
310
+ int32_t neh11; // n_tokens
311
+ uint64_t nbh11;
312
+ int32_t ne20; // n_expert_used
313
+ uint64_t nb21;
314
+ } lm_ggml_metal_kargs_mul_mm_id_map0;
315
+
316
+ typedef struct {
317
+ int32_t ne20; // n_expert_used
318
+ int32_t neh0;
319
+ int32_t neh1;
320
+ uint64_t nbh1;
321
+ uint64_t nbh2;
322
+ int32_t ne0;
323
+ uint64_t nb1;
324
+ uint64_t nb2;
325
+ } lm_ggml_metal_kargs_mul_mm_id_map1;
326
+
327
+ typedef struct {
305
328
  int32_t ne00;
306
329
  int32_t ne02;
307
330
  uint64_t nb01;
308
331
  uint64_t nb02;
309
- int32_t ne11;
310
- int32_t ne12;
311
- int32_t ne13;
312
- uint64_t nb10;
313
- uint64_t nb11;
314
- uint64_t nb12;
315
- int32_t ne0;
316
- int32_t ne1;
332
+ uint64_t nb03;
333
+ int32_t neh12;
334
+ uint64_t nbh10;
335
+ uint64_t nbh11;
336
+ uint64_t nbh12;
337
+ uint64_t nbh13;
338
+ int32_t neh0;
339
+ int32_t neh1;
340
+ int16_t r2;
341
+ int16_t r3;
317
342
  } lm_ggml_metal_kargs_mul_mm_id;
318
343
 
319
344
  typedef struct {
@@ -37,13 +37,16 @@ extern "C" {
37
37
  // ====== Dataset ======
38
38
 
39
39
  LM_GGML_API lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
40
- int64_t ne_datapoint, // number of elements per datapoint
41
- int64_t ne_label, // number of elements per label
42
- int64_t ndata, // total number of datapoints/labels
43
- int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
40
+ enum lm_ggml_type type_data, // the type for the internal data tensor
41
+ enum lm_ggml_type type_label, // the type for the internal labels tensor
42
+ int64_t ne_datapoint, // number of elements per datapoint
43
+ int64_t ne_label, // number of elements per label
44
+ int64_t ndata, // total number of datapoints/labels
45
+ int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
44
46
  LM_GGML_API void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset);
45
47
 
46
48
  // get underlying tensors that store the data
49
+ LM_GGML_API int64_t lm_ggml_opt_dataset_ndata (lm_ggml_opt_dataset_t dataset);
47
50
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_data (lm_ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
48
51
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_labels(lm_ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
49
52
 
@@ -56,13 +59,19 @@ extern "C" {
56
59
  struct lm_ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
57
60
  struct lm_ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
58
61
  int64_t ibatch);
62
+ LM_GGML_API void lm_ggml_opt_dataset_get_batch_host(
63
+ lm_ggml_opt_dataset_t dataset,
64
+ void * data_batch,
65
+ size_t nb_data_batch,
66
+ void * labels_batch,
67
+ int64_t ibatch);
59
68
 
60
69
  // ====== Model / Context ======
61
70
 
62
71
  enum lm_ggml_opt_build_type {
63
- LM_GGML_OPT_BUILD_TYPE_FORWARD,
64
- LM_GGML_OPT_BUILD_TYPE_GRAD,
65
- LM_GGML_OPT_BUILD_TYPE_OPT,
72
+ LM_GGML_OPT_BUILD_TYPE_FORWARD = 10,
73
+ LM_GGML_OPT_BUILD_TYPE_GRAD = 20,
74
+ LM_GGML_OPT_BUILD_TYPE_OPT = 30,
66
75
  };
67
76
 
68
77
  // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
81
90
  // userdata can be used to pass arbitrary data
82
91
  typedef struct lm_ggml_opt_optimizer_params (*lm_ggml_opt_get_optimizer_params)(void * userdata);
83
92
 
84
- // returns the default optimizer params (constant)
93
+ // returns the default optimizer params (constant, hard-coded values)
85
94
  // userdata is not used
86
95
  LM_GGML_API struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata);
87
96
 
97
+ // casts userdata to lm_ggml_opt_optimizer_params and returns it
98
+ LM_GGML_API struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_constant_optimizer_params(void * userdata);
99
+
88
100
  // parameters for initializing a new optimization context
89
101
  struct lm_ggml_opt_params {
90
102
  lm_ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
91
103
 
92
- struct lm_ggml_context * ctx_compute; // created in user code, holds non-static tensors
93
-
94
- // the forward graph is defined by inputs and outputs
95
- // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
96
- struct lm_ggml_tensor * inputs;
97
- struct lm_ggml_tensor * outputs;
104
+ // by default the forward graph needs to be reconstructed for each eval
105
+ // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
106
+ struct lm_ggml_context * ctx_compute;
107
+ struct lm_ggml_tensor * inputs;
108
+ struct lm_ggml_tensor * outputs;
98
109
 
99
110
  enum lm_ggml_opt_loss_type loss_type;
100
111
  enum lm_ggml_opt_build_type build_type;
@@ -107,12 +118,9 @@ extern "C" {
107
118
 
108
119
  // get parameters for an optimization context with defaults set where possible
109
120
  // parameters for which no sensible defaults exist are supplied as arguments to this function
110
- LM_GGML_API lm_ggml_opt_params lm_ggml_opt_default_params(
111
- lm_ggml_backend_sched_t backend_sched,
112
- struct lm_ggml_context * ctx_compute,
113
- struct lm_ggml_tensor * inputs,
114
- struct lm_ggml_tensor * outputs,
115
- enum lm_ggml_opt_loss_type loss_type);
121
+ LM_GGML_API struct lm_ggml_opt_params lm_ggml_opt_default_params(
122
+ lm_ggml_backend_sched_t backend_sched,
123
+ enum lm_ggml_opt_loss_type loss_type);
116
124
 
117
125
  LM_GGML_API lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params);
118
126
  LM_GGML_API void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx);
@@ -120,7 +128,10 @@ extern "C" {
120
128
  // set gradients to zero, initilize loss, and optionally reset the optimizer
121
129
  LM_GGML_API void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer);
122
130
 
131
+ LM_GGML_API bool lm_ggml_opt_static_graphs(lm_ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
132
+
123
133
  // get underlying tensors that store data
134
+ // if not using static graphs these pointers become invalid with the next call to lm_ggml_opt_alloc
124
135
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_inputs( lm_ggml_opt_context_t opt_ctx); // forward graph input tensor
125
136
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_outputs( lm_ggml_opt_context_t opt_ctx); // forward graph output tensor
126
137
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_labels( lm_ggml_opt_context_t opt_ctx); // labels to compare outputs against
@@ -128,11 +139,12 @@ extern "C" {
128
139
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_pred( lm_ggml_opt_context_t opt_ctx); // predictions made by outputs
129
140
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_ncorrect(lm_ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
130
141
 
142
+ // get the gradient accumulator for a node from the forward graph
131
143
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_grad_acc(lm_ggml_opt_context_t opt_ctx, struct lm_ggml_tensor * node);
132
144
 
133
145
  // ====== Optimization Result ======
134
146
 
135
- LM_GGML_API lm_ggml_opt_result_t lm_ggml_opt_result_init();
147
+ LM_GGML_API lm_ggml_opt_result_t lm_ggml_opt_result_init(void);
136
148
  LM_GGML_API void lm_ggml_opt_result_free(lm_ggml_opt_result_t result);
137
149
  LM_GGML_API void lm_ggml_opt_result_reset(lm_ggml_opt_result_t result);
138
150
 
@@ -144,11 +156,20 @@ extern "C" {
144
156
 
145
157
  // ====== Computation ======
146
158
 
147
- // do forward pass, increment result if not NULL
148
- LM_GGML_API void lm_ggml_opt_forward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result);
159
+ // if not using static graphs, this function must be called prior to lm_ggml_opt_alloc
160
+ LM_GGML_API void lm_ggml_opt_prepare_alloc(
161
+ lm_ggml_opt_context_t opt_ctx,
162
+ struct lm_ggml_context * ctx_compute,
163
+ struct lm_ggml_cgraph * gf,
164
+ struct lm_ggml_tensor * inputs,
165
+ struct lm_ggml_tensor * outputs);
166
+
167
+ // allocate the next graph for evaluation, either forward or forward + backward
168
+ // must be called exactly once prior to calling lm_ggml_opt_eval
169
+ LM_GGML_API void lm_ggml_opt_alloc(lm_ggml_opt_context_t opt_ctx, bool backward);
149
170
 
150
- // do forward pass, increment result if not NULL, do backward pass
151
- LM_GGML_API void lm_ggml_opt_forward_backward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result);
171
+ // do forward pass, increment result if not NULL, do backward pass if allocated
172
+ LM_GGML_API void lm_ggml_opt_eval(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result);
152
173
 
153
174
  // ############################################################################
154
175
  // ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +221,9 @@ extern "C" {
200
221
  // fit model defined by inputs and outputs to dataset
201
222
  LM_GGML_API void lm_ggml_opt_fit(
202
223
  lm_ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
203
- lm_ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
204
- lm_ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
205
- lm_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
224
+ struct lm_ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
225
+ struct lm_ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
226
+ struct lm_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
206
227
  lm_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
207
228
  enum lm_ggml_opt_loss_type loss_type, // loss to minimize
208
229
  lm_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
@@ -537,6 +537,7 @@ extern "C" {
537
537
  LM_GGML_UNARY_OP_HARDSWISH,
538
538
  LM_GGML_UNARY_OP_HARDSIGMOID,
539
539
  LM_GGML_UNARY_OP_EXP,
540
+ LM_GGML_UNARY_OP_GELU_ERF,
540
541
 
541
542
  LM_GGML_UNARY_OP_COUNT,
542
543
  };
@@ -674,11 +675,15 @@ extern "C" {
674
675
  LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor);
675
676
  LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars
676
677
 
678
+ // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
677
679
  LM_GGML_API bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor);
678
680
  LM_GGML_API bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous()
679
681
  LM_GGML_API bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1
680
682
  LM_GGML_API bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2
681
683
 
684
+ // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
685
+ LM_GGML_API bool lm_ggml_is_contiguously_allocated(const struct lm_ggml_tensor * tensor);
686
+
682
687
  // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
683
688
  LM_GGML_API bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor);
684
689
 
@@ -765,7 +770,7 @@ extern "C" {
765
770
  // Tensor flags
766
771
  LM_GGML_API void lm_ggml_set_input(struct lm_ggml_tensor * tensor);
767
772
  LM_GGML_API void lm_ggml_set_output(struct lm_ggml_tensor * tensor);
768
- LM_GGML_API void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor);
773
+ LM_GGML_API void lm_ggml_set_param(struct lm_ggml_tensor * tensor);
769
774
  LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor);
770
775
 
771
776
  //
@@ -935,7 +940,7 @@ extern "C" {
935
940
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_repeat_back(
936
941
  struct lm_ggml_context * ctx,
937
942
  struct lm_ggml_tensor * a,
938
- struct lm_ggml_tensor * b);
943
+ struct lm_ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
939
944
 
940
945
  // concat a and b along dim
941
946
  // used in stable-diffusion
@@ -1021,6 +1026,16 @@ extern "C" {
1021
1026
  struct lm_ggml_context * ctx,
1022
1027
  struct lm_ggml_tensor * a);
1023
1028
 
1029
+ // GELU using erf (error function) when possible
1030
+ // some backends may fallback to approximation based on Abramowitz and Stegun formula
1031
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf(
1032
+ struct lm_ggml_context * ctx,
1033
+ struct lm_ggml_tensor * a);
1034
+
1035
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf_inplace(
1036
+ struct lm_ggml_context * ctx,
1037
+ struct lm_ggml_tensor * a);
1038
+
1024
1039
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_quick(
1025
1040
  struct lm_ggml_context * ctx,
1026
1041
  struct lm_ggml_tensor * a);
@@ -2046,15 +2061,14 @@ extern "C" {
2046
2061
 
2047
2062
  LM_GGML_API void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor);
2048
2063
  LM_GGML_API void lm_ggml_build_backward_expand(
2049
- struct lm_ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
2050
- struct lm_ggml_context * ctx_compute, // context for gradient computation
2051
- struct lm_ggml_cgraph * cgraph,
2052
- bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
2064
+ struct lm_ggml_context * ctx, // context for gradient computation
2065
+ struct lm_ggml_cgraph * cgraph,
2066
+ struct lm_ggml_tensor ** grad_accs);
2053
2067
 
2054
2068
  // graph allocation in a context
2055
2069
  LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false
2056
2070
  LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads);
2057
- LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph);
2071
+ LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, bool force_grads);
2058
2072
  LM_GGML_API void lm_ggml_graph_cpy (struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst);
2059
2073
  LM_GGML_API void lm_ggml_graph_reset (struct lm_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
2060
2074
  LM_GGML_API void lm_ggml_graph_clear (struct lm_ggml_cgraph * cgraph);
@@ -14,6 +14,7 @@ enum llm_chat_template {
14
14
  LLM_CHAT_TEMPLATE_MISTRAL_V3,
15
15
  LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
16
16
  LLM_CHAT_TEMPLATE_MISTRAL_V7,
17
+ LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
17
18
  LLM_CHAT_TEMPLATE_PHI_3,
18
19
  LLM_CHAT_TEMPLATE_PHI_4,
19
20
  LLM_CHAT_TEMPLATE_FALCON_3,
@@ -7,6 +7,7 @@
7
7
  #include "llama-adapter.h"
8
8
 
9
9
  #include "ggml-cpp.h"
10
+ #include "ggml-opt.h"
10
11
 
11
12
  #include <map>
12
13
  #include <vector>
@@ -133,6 +134,32 @@ struct llama_context {
133
134
  llama_perf_context_data perf_get_data() const;
134
135
  void perf_reset();
135
136
 
137
+ //
138
+ // training
139
+ //
140
+
141
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
142
+
143
+ void opt_epoch(
144
+ lm_ggml_opt_dataset_t dataset,
145
+ lm_ggml_opt_result_t result_train,
146
+ lm_ggml_opt_result_t result_eval,
147
+ int64_t idata_split,
148
+ lm_ggml_opt_epoch_callback callback_train,
149
+ lm_ggml_opt_epoch_callback callback_eval);
150
+
151
+ void opt_epoch_iter(
152
+ lm_ggml_opt_dataset_t dataset,
153
+ lm_ggml_opt_result_t result,
154
+ const std::vector<llama_token> & tokens,
155
+ const std::vector<llama_token> & labels_sparse,
156
+ llama_batch & batch,
157
+ lm_ggml_opt_epoch_callback callback,
158
+ bool train,
159
+ int64_t idata_in_loop,
160
+ int64_t ndata_in_loop,
161
+ int64_t t_loop_start);
162
+
136
163
  private:
137
164
  //
138
165
  // output
@@ -187,9 +214,6 @@ private:
187
214
 
188
215
  std::unique_ptr<llama_memory_i> memory;
189
216
 
190
- // TODO: remove
191
- bool logits_all = false;
192
-
193
217
  // decode output (2-dimensional array: [n_outputs][n_vocab])
194
218
  size_t logits_size = 0; // capacity (of floats) for logits
195
219
  float * logits = nullptr;
@@ -215,6 +239,9 @@ private:
215
239
 
216
240
  lm_ggml_context_ptr ctx_compute;
217
241
 
242
+ // training
243
+ lm_ggml_opt_context_t opt_ctx = nullptr;
244
+
218
245
  lm_ggml_threadpool_t threadpool = nullptr;
219
246
  lm_ggml_threadpool_t threadpool_batch = nullptr;
220
247
 
@@ -30,6 +30,7 @@ struct llama_cparams {
30
30
  bool flash_attn;
31
31
  bool no_perf;
32
32
  bool warmup;
33
+ bool op_offload;
33
34
 
34
35
  enum llama_pooling_type pooling_type;
35
36