cactus-react-native 0.0.1 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. package/LICENSE.txt +20 -0
  2. package/README.md +3 -1
  3. package/android/src/main/CMakeLists.txt +58 -23
  4. package/android/src/main/java/com/cactus/Cactus.java +484 -16
  5. package/android/src/main/java/com/cactus/LlamaContext.java +199 -0
  6. package/android/src/main/jni.cpp +325 -10
  7. package/android/src/main/jniLibs/arm64-v8a/libcactus.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libcactus.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libcactus_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/cactus/CactusModule.java +79 -7
  16. package/android/src/oldarch/java/com/cactus/CactusModule.java +70 -0
  17. package/cactus-react-native.podspec +0 -3
  18. package/ios/CMakeLists.txt +58 -36
  19. package/ios/Cactus.mm +243 -2
  20. package/ios/CactusContext.h +22 -0
  21. package/ios/CactusContext.mm +176 -1
  22. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +92 -5
  23. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +268 -0
  24. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/chat.h +2 -0
  25. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/common.h +42 -51
  26. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
  27. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-common.h +12 -6
  28. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
  29. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
  30. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
  31. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  32. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
  33. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml.h +87 -106
  34. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-arch.h +16 -0
  35. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-batch.h +2 -1
  36. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-chat.h +7 -2
  37. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-context.h +44 -33
  38. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
  39. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-graph.h +83 -17
  40. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
  41. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
  42. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-memory.h +13 -2
  43. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
  44. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
  45. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model.h +24 -2
  46. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
  47. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama.h +102 -142
  48. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  49. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
  50. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
  51. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  52. package/ios/cactus.xcframework/ios-arm64/cactus.framework/ggml-llama.metallib +0 -0
  53. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
  54. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +268 -0
  55. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
  56. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
  57. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
  58. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
  59. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
  60. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
  61. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
  62. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  63. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
  64. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
  65. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
  66. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
  67. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
  68. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
  69. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
  70. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
  71. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
  72. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
  73. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
  74. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
  75. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
  76. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
  77. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
  78. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
  79. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  80. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
  81. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
  82. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  83. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  84. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
  85. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus.h +92 -5
  86. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus_ffi.h +268 -0
  87. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/chat.h +2 -0
  88. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/common.h +42 -51
  89. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
  90. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-common.h +12 -6
  91. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
  92. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
  93. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
  94. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  95. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
  96. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml.h +87 -106
  97. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-arch.h +16 -0
  98. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-batch.h +2 -1
  99. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-chat.h +7 -2
  100. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-context.h +44 -33
  101. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
  102. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-graph.h +83 -17
  103. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
  104. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
  105. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-memory.h +13 -2
  106. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
  107. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
  108. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model.h +24 -2
  109. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
  110. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama.h +102 -142
  111. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  112. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
  113. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Info.plist +0 -0
  114. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/cactus +0 -0
  115. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/ggml-llama.metallib +0 -0
  116. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
  117. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +268 -0
  118. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
  119. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
  120. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
  121. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
  122. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
  125. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  126. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
  127. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
  128. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
  129. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
  130. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
  131. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
  132. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
  133. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
  134. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
  135. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
  136. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
  137. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
  138. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
  139. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
  140. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
  141. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
  142. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  143. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
  144. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
  145. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  146. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  147. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
  148. package/lib/commonjs/NativeCactus.js +1 -0
  149. package/lib/commonjs/NativeCactus.js.map +1 -1
  150. package/lib/commonjs/index.js +112 -0
  151. package/lib/commonjs/index.js.map +1 -1
  152. package/lib/commonjs/tools.js +118 -0
  153. package/lib/commonjs/tools.js.map +1 -0
  154. package/lib/module/NativeCactus.js +3 -0
  155. package/lib/module/NativeCactus.js.map +1 -1
  156. package/lib/module/index.js +87 -1
  157. package/lib/module/index.js.map +1 -1
  158. package/lib/module/tools.js +110 -0
  159. package/lib/module/tools.js.map +1 -0
  160. package/lib/typescript/NativeCactus.d.ts +30 -1
  161. package/lib/typescript/NativeCactus.d.ts.map +1 -1
  162. package/lib/typescript/index.d.ts +21 -2
  163. package/lib/typescript/index.d.ts.map +1 -1
  164. package/lib/typescript/tools.d.ts +38 -0
  165. package/lib/typescript/tools.d.ts.map +1 -0
  166. package/package.json +6 -3
  167. package/src/NativeCactus.ts +62 -1
  168. package/src/index.ts +113 -2
  169. package/src/tools.ts +127 -0
  170. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  171. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  172. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  173. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  174. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/sgemm.h +0 -14
  175. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  176. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  177. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  178. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  179. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
  180. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  181. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  182. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  183. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  184. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/sgemm.h +0 -14
  185. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  186. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  187. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  188. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  189. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
llama-context.h
@@ -7,6 +7,7 @@
  #include "llama-adapter.h"
 
  #include "ggml-cpp.h"
+ #include "ggml-opt.h"
 
  #include <map>
  #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
 
  void synchronize();
 
- const llama_model & get_model() const;
+ const llama_model & get_model() const;
+ const llama_cparams & get_cparams() const;
+
+ lm_ggml_backend_sched_t get_sched() const;
+
+ lm_ggml_context * get_ctx_compute() const;
 
  uint32_t n_ctx() const;
  uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
  llama_perf_context_data perf_get_data() const;
  void perf_reset();
 
+ //
+ // training
+ //
+
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+ void opt_epoch(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result_train,
+ lm_ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ lm_ggml_opt_epoch_callback callback_train,
+ lm_ggml_opt_epoch_callback callback_eval);
+
+ void opt_epoch_iter(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ lm_ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start);
+
  private:
  //
  // output
@@ -137,50 +169,30 @@ private:
  // Returns max number of outputs for which space was reserved.
  int32_t output_reserve(int32_t n_outputs);
 
- // make the outputs have the same order they had in the user-provided batch
- // TODO: maybe remove this
- void output_reorder();
-
  //
  // graph
  //
 
+ public:
  int32_t graph_max_nodes() const;
 
  // zero-out inputs and create the ctx_compute for the compute graph
  lm_ggml_cgraph * graph_init();
 
+ // returns the result of lm_ggml_backend_sched_graph_compute_async execution
+ lm_ggml_status graph_compute(
+ lm_ggml_cgraph * gf,
+ bool batched);
+
+ private:
  llm_graph_result_ptr graph_build(
  lm_ggml_context * ctx,
  lm_ggml_cgraph * gf,
  const llama_ubatch & ubatch,
  llm_graph_type gtype);
 
- // returns the result of lm_ggml_backend_sched_graph_compute_async execution
- lm_ggml_status graph_compute(
- lm_ggml_cgraph * gf,
- bool batched);
-
  llm_graph_cb graph_get_cb() const;
 
- // used by kv_self_update()
- lm_ggml_tensor * build_rope_shift(
- lm_ggml_context * ctx0,
- lm_ggml_tensor * cur,
- lm_ggml_tensor * shift,
- lm_ggml_tensor * factors,
- float freq_base,
- float freq_scale,
- lm_ggml_backend_buffer * bbuf) const;
-
- llm_graph_result_ptr build_kv_self_shift(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const;
-
- llm_graph_result_ptr build_kv_self_defrag(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const;
-
  // TODO: read/write lora adapters and cvec
  size_t state_write_data(llama_io_write_i & io);
  size_t state_read_data (llama_io_read_i & io);
@@ -197,14 +209,10 @@ private:
  llama_cparams cparams;
  llama_adapter_cvec cvec;
  llama_adapter_loras loras;
- llama_sbatch sbatch;
 
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
- std::unique_ptr<llama_kv_cache_unified> kv_self;
-
- // TODO: remove
- bool logits_all = false;
+ std::unique_ptr<llama_memory_i> memory;
 
  // decode output (2-dimensional array: [n_outputs][n_vocab])
  size_t logits_size = 0; // capacity (of floats) for logits
@@ -231,6 +239,9 @@
 
  lm_ggml_context_ptr ctx_compute;
 
+ // training
+ lm_ggml_opt_context_t opt_ctx = nullptr;
+
  lm_ggml_threadpool_t threadpool = nullptr;
  lm_ggml_threadpool_t threadpool_batch = nullptr;
 
llama-cparams.h
@@ -30,6 +30,7 @@ struct llama_cparams {
  bool flash_attn;
  bool no_perf;
  bool warmup;
+ bool op_offload;
 
  enum llama_pooling_type pooling_type;
 
llama-graph.h
@@ -19,6 +19,8 @@ struct llama_cparams;
 
  class llama_memory_i;
  class llama_kv_cache_unified;
+ class llama_kv_cache_unified_iswa;
+ class llama_kv_cache_recurrent;
 
  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
@@ -90,14 +92,29 @@ public:
 
  class llm_graph_input_pos : public llm_graph_input_i {
  public:
- llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+ llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  virtual ~llm_graph_input_pos() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * pos = nullptr; // I32 [n_batch]
 
- const int64_t n_pos_per_token = 1;
+ const int64_t n_pos_per_embd = 1;
+ };
+
+ // temperature tuning, used by llama4
+ class llm_graph_input_attn_temp : public llm_graph_input_i {
+ public:
+ llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+ : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+ virtual ~llm_graph_input_attn_temp() = default;
+
+ void set_input(const llama_ubatch * ubatch) override;
+
+ lm_ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+ const uint32_t n_attn_temp_floor_scale;
+ const float f_attn_temp_scale;
  };
 
  class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -171,26 +188,26 @@ public:
 
  class llm_graph_input_s_copy : public llm_graph_input_i {
  public:
- llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_copy() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * s_copy; // I32 [kv_size]
 
- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };
 
  class llm_graph_input_s_mask : public llm_graph_input_i {
  public:
- llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_mask() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * s_mask; // F32 [1, n_kv]
 
- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };
 
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -239,6 +256,31 @@ public:
 
  void set_input(const llama_ubatch * ubatch) override;
 
+ lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+ lm_ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
+ lm_ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+
+ const llama_hparams & hparams;
+ const llama_cparams & cparams;
+
+ const llama_kv_cache_unified * kv_self;
+ };
+
+ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+ public:
+ llm_graph_input_attn_kv_unified_iswa(
+ const llama_hparams & hparams,
+ const llama_cparams & cparams,
+ const llama_kv_cache_unified_iswa * kv_self) :
+ hparams(hparams),
+ cparams(cparams),
+ kv_self(kv_self) {
+ }
+ ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+ void set_input(const llama_ubatch * ubatch) override;
+
  lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  lm_ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -250,7 +292,7 @@ public:
  const llama_hparams & hparams;
  const llama_cparams & cparams;
 
- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_unified_iswa * kv_self;
  };
 
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -282,6 +324,7 @@ class llm_graph_result_i {
  public:
  virtual ~llm_graph_result_i() = default;
 
+ virtual lm_ggml_tensor * get_tokens() = 0;
  virtual lm_ggml_tensor * get_logits() = 0;
  virtual lm_ggml_tensor * get_embd() = 0;
  virtual lm_ggml_tensor * get_embd_pooled() = 0;
@@ -296,6 +339,7 @@ class llm_graph_result : public llm_graph_result_i {
  public:
  virtual ~llm_graph_result() = default;
 
+ lm_ggml_tensor * get_tokens() override { return t_tokens; }
  lm_ggml_tensor * get_logits() override { return t_logits; }
  lm_ggml_tensor * get_embd() override { return t_embd; }
  lm_ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -312,6 +356,7 @@ public:
  }
 
  // important graph nodes
+ lm_ggml_tensor * t_tokens = nullptr;
  lm_ggml_tensor * t_logits = nullptr;
  lm_ggml_tensor * t_embd = nullptr;
  lm_ggml_tensor * t_embd_pooled = nullptr;
@@ -335,8 +380,8 @@ struct llm_graph_params {
  const llama_cparams & cparams;
  const llama_ubatch & ubatch;
 
- lm_ggml_backend_sched * sched;
- lm_ggml_backend * backend_cpu;
+ lm_ggml_backend_sched_t sched;
+ lm_ggml_backend_t backend_cpu;
 
  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;
@@ -359,7 +404,6 @@ struct llm_graph_context {
  const int64_t n_layer;
  const int64_t n_rot;
  const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
- const int64_t n_ctx_per_seq;
  const int64_t n_head;
  const int64_t n_head_kv;
  const int64_t n_embd_head_k;
@@ -387,9 +431,9 @@
 
  lm_ggml_context * ctx0 = nullptr;
 
- lm_ggml_backend_sched * sched;
+ lm_ggml_backend_sched_t sched;
 
- lm_ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+ lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;
@@ -402,7 +446,7 @@
 
  llm_graph_context(const llm_graph_params & params);
 
- int64_t n_pos_per_token() const;
+ int64_t n_pos_per_embd() const;
 
  void cb(lm_ggml_tensor * cur, const char * name, int il) const;
 
@@ -470,6 +514,7 @@ struct llm_graph_context {
 
  lm_ggml_tensor * build_inp_embd(lm_ggml_tensor * tok_embd) const;
  lm_ggml_tensor * build_inp_pos() const;
+ lm_ggml_tensor * build_inp_attn_scale() const;
  lm_ggml_tensor * build_inp_out_ids() const;
  lm_ggml_tensor * build_inp_mean() const;
  lm_ggml_tensor * build_inp_cls() const;
@@ -487,12 +532,12 @@
 
  lm_ggml_tensor * build_attn_mha(
  lm_ggml_cgraph * gf,
- lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
- lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
- lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+ lm_ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
+ lm_ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
+ lm_ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
  lm_ggml_tensor * kq_b,
  lm_ggml_tensor * kq_mask,
- bool v_trans,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale) const;
 
  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -506,6 +551,7 @@ struct llm_graph_context {
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
@@ -520,6 +566,22 @@
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+ float kq_scale,
+ int il) const;
+
+ llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+ lm_ggml_tensor * build_attn(
+ llm_graph_input_attn_kv_unified_iswa * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * wo,
+ lm_ggml_tensor * wo_b,
+ lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+ lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+ lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+ lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
@@ -534,6 +596,7 @@
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
@@ -572,3 +635,6 @@ struct llm_graph_context {
  lm_ggml_tensor * cls_out,
  lm_ggml_tensor * cls_out_b) const;
  };
+
+ // TODO: better name
+ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
llama-hparams.h
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
  };
 
+ enum llama_swa_type {
+ LLAMA_SWA_TYPE_NONE = 0,
+ LLAMA_SWA_TYPE_STANDARD = 1,
+ LLAMA_SWA_TYPE_CHUNKED = 2,
+ };
+
  struct llama_hparams_posnet {
  uint32_t n_embd;
  uint32_t n_layer;
@@ -35,14 +41,16 @@ struct llama_hparams {
  uint32_t n_embd_features = 0;
  uint32_t n_layer;
  uint32_t n_rot;
- uint32_t n_swa = 0; // sliding window attention (SWA)
- uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
  uint32_t n_expert = 0;
  uint32_t n_expert_used = 0;
  uint32_t n_rel_attn_bkts = 0;
 
+ // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+ uint32_t n_embd_head_k_mla = 0;
+ uint32_t n_embd_head_v_mla = 0;
+
  // for WavTokenizer
  struct llama_hparams_posnet posnet;
  struct llama_hparams_convnext convnext;
@@ -62,6 +70,7 @@ struct llama_hparams {
  float expert_weights_scale = 0.0;
  bool expert_weights_norm = false;
  uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+ uint32_t moe_every_n_layers = 0;
 
  float f_norm_eps;
  float f_norm_rms_eps;
@@ -91,6 +100,15 @@ struct llama_hparams {
 
  std::array<int, 4> rope_sections;
 
+ // Sliding Window Attention (SWA)
+ llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+ // the size of the sliding window (0 - no SWA)
+ uint32_t n_swa = 0;
+ // if swa_layers[il] == true, then layer il is SWA
+ // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+ // by default, all layers are dense
+ std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
  // for State Space Models
  uint32_t ssm_d_conv = 0;
  uint32_t ssm_d_inner = 0;
@@ -111,6 +129,13 @@ struct llama_hparams {
  bool causal_attn = true;
  bool use_alibi = false;
  bool attn_soft_cap = false;
+ bool use_kq_norm = true;
+
+ // llama4
+ uint32_t n_moe_layer_step = 0;
+ uint32_t n_no_rope_layer_step = 4;
+ uint32_t n_attn_temp_floor_scale = 8192;
+ float f_attn_temp_scale = 0.1;
 
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
@@ -120,6 +145,23 @@ struct llama_hparams {
 
  enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
  enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
+ // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+ // note that if n_pattern == 0, all layers are SWA
+ // if n_pattern == 1, all layers are dense
+ // example: n_pattern = 3
+ // il == 0: swa
+ // il == 1: swa
+ // il == 2: dense
+ // il == 3: swa
+ // il == 4: swa
+ // il == 5: dense
+ // il == 6: swa
+ // etc ...
+ void set_swa_pattern(uint32_t n_pattern);
+
+ // return true if one of the layers is SWA
+ bool is_swa_any() const;
+
  uint32_t n_head(uint32_t il = 0) const;
 
  uint32_t n_head_kv(uint32_t il = 0) const;