cactus-react-native 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. package/LICENSE.txt +20 -0
  2. package/README.md +3 -1
  3. package/android/src/main/CMakeLists.txt +60 -21
  4. package/android/src/main/java/com/cactus/Cactus.java +465 -0
  5. package/android/src/main/java/com/cactus/LlamaContext.java +199 -0
  6. package/android/src/main/jni.cpp +325 -10
  7. package/android/src/main/jniLibs/arm64-v8a/libcactus.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libcactus.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libcactus_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/cactus/CactusModule.java +79 -7
  16. package/android/src/oldarch/java/com/cactus/CactusModule.java +70 -0
  17. package/cactus-react-native.podspec +0 -3
  18. package/ios/CMakeLists.txt +56 -36
  19. package/ios/Cactus.mm +243 -2
  20. package/ios/CactusContext.h +22 -0
  21. package/ios/CactusContext.mm +176 -1
  22. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +92 -5
  23. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +229 -0
  24. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/chat.h +2 -0
  25. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/common.h +42 -51
  26. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
  27. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-common.h +12 -6
  28. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
  29. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
  30. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
  31. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  32. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
  33. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml.h +87 -106
  34. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-arch.h +16 -0
  35. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-batch.h +2 -1
  36. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-chat.h +7 -2
  37. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-context.h +44 -33
  38. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
  39. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-graph.h +83 -17
  40. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
  41. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
  42. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-memory.h +13 -2
  43. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
  44. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
  45. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model.h +24 -2
  46. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
  47. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama.h +102 -142
  48. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  49. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
  50. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
  51. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  52. package/ios/cactus.xcframework/ios-arm64/cactus.framework/ggml-llama.metallib +0 -0
  53. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
  54. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +229 -0
  55. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
  56. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
  57. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
  58. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
  59. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
  60. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
  61. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
  62. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  63. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
  64. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
  65. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
  66. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
  67. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
  68. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
  69. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
  70. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
  71. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
  72. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
  73. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
  74. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
  75. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
  76. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
  77. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
  78. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
  79. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  80. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
  81. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
  82. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  83. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  84. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
  85. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus.h +92 -5
  86. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus_ffi.h +229 -0
  87. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/chat.h +2 -0
  88. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/common.h +42 -51
  89. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
  90. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-common.h +12 -6
  91. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
  92. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
  93. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
  94. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  95. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
  96. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml.h +87 -106
  97. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-arch.h +16 -0
  98. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-batch.h +2 -1
  99. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-chat.h +7 -2
  100. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-context.h +44 -33
  101. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
  102. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-graph.h +83 -17
  103. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
  104. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
  105. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-memory.h +13 -2
  106. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
  107. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
  108. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model.h +24 -2
  109. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
  110. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama.h +102 -142
  111. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  112. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
  113. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Info.plist +0 -0
  114. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/cactus +0 -0
  115. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/ggml-llama.metallib +0 -0
  116. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
  117. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +229 -0
  118. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
  119. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
  120. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
  121. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
  122. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
  125. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
  126. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
  127. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
  128. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
  129. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
  130. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
  131. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
  132. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
  133. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
  134. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
  135. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
  136. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
  137. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
  138. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
  139. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
  140. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
  141. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
  142. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
  143. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
  144. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
  145. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  146. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  147. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
  148. package/lib/commonjs/NativeCactus.js +1 -0
  149. package/lib/commonjs/NativeCactus.js.map +1 -1
  150. package/lib/commonjs/index.js +112 -0
  151. package/lib/commonjs/index.js.map +1 -1
  152. package/lib/commonjs/tools.js +118 -0
  153. package/lib/commonjs/tools.js.map +1 -0
  154. package/lib/module/NativeCactus.js +3 -0
  155. package/lib/module/NativeCactus.js.map +1 -1
  156. package/lib/module/index.js +87 -1
  157. package/lib/module/index.js.map +1 -1
  158. package/lib/module/tools.js +110 -0
  159. package/lib/module/tools.js.map +1 -0
  160. package/lib/typescript/NativeCactus.d.ts +30 -1
  161. package/lib/typescript/NativeCactus.d.ts.map +1 -1
  162. package/lib/typescript/index.d.ts +21 -2
  163. package/lib/typescript/index.d.ts.map +1 -1
  164. package/lib/typescript/tools.d.ts +38 -0
  165. package/lib/typescript/tools.d.ts.map +1 -0
  166. package/package.json +6 -3
  167. package/src/NativeCactus.ts +62 -1
  168. package/src/index.ts +113 -2
  169. package/src/tools.ts +127 -0
  170. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  171. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  172. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  173. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  174. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/sgemm.h +0 -14
  175. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  176. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  177. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  178. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  179. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
  180. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  181. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  182. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  183. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  184. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/sgemm.h +0 -14
  185. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
  186. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
  187. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
  188. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
  189. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
@@ -393,8 +393,8 @@ extern "C" {
393
393
 
394
394
  // precision
395
395
  enum lm_ggml_prec {
396
- LM_GGML_PREC_DEFAULT,
397
- LM_GGML_PREC_F32,
396
+ LM_GGML_PREC_DEFAULT = 0, // stored as lm_ggml_tensor.op_params, 0 by default
397
+ LM_GGML_PREC_F32 = 10,
398
398
  };
399
399
 
400
400
  // model file types
@@ -481,6 +481,7 @@ extern "C" {
481
481
  LM_GGML_OP_CONV_TRANSPOSE_1D,
482
482
  LM_GGML_OP_IM2COL,
483
483
  LM_GGML_OP_IM2COL_BACK,
484
+ LM_GGML_OP_CONV_2D_DW,
484
485
  LM_GGML_OP_CONV_TRANSPOSE_2D,
485
486
  LM_GGML_OP_POOL_1D,
486
487
  LM_GGML_OP_POOL_2D,
@@ -507,17 +508,12 @@ extern "C" {
507
508
 
508
509
  LM_GGML_OP_UNARY,
509
510
 
510
- LM_GGML_OP_MAP_UNARY,
511
- LM_GGML_OP_MAP_BINARY,
512
-
513
- LM_GGML_OP_MAP_CUSTOM1_F32,
514
- LM_GGML_OP_MAP_CUSTOM2_F32,
515
- LM_GGML_OP_MAP_CUSTOM3_F32,
516
-
517
511
  LM_GGML_OP_MAP_CUSTOM1,
518
512
  LM_GGML_OP_MAP_CUSTOM2,
519
513
  LM_GGML_OP_MAP_CUSTOM3,
520
514
 
515
+ LM_GGML_OP_CUSTOM,
516
+
521
517
  LM_GGML_OP_CROSS_ENTROPY_LOSS,
522
518
  LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
523
519
  LM_GGML_OP_OPT_STEP_ADAMW,
@@ -540,6 +536,7 @@ extern "C" {
540
536
  LM_GGML_UNARY_OP_HARDSWISH,
541
537
  LM_GGML_UNARY_OP_HARDSIGMOID,
542
538
  LM_GGML_UNARY_OP_EXP,
539
+ LM_GGML_UNARY_OP_GELU_ERF,
543
540
 
544
541
  LM_GGML_UNARY_OP_COUNT,
545
542
  };
@@ -677,11 +674,18 @@ extern "C" {
677
674
  LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor);
678
675
  LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars
679
676
 
677
+ // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
680
678
  LM_GGML_API bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor);
681
679
  LM_GGML_API bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous()
682
680
  LM_GGML_API bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1
683
681
  LM_GGML_API bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2
684
682
 
683
+ // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
684
+ LM_GGML_API bool lm_ggml_is_contiguously_allocated(const struct lm_ggml_tensor * tensor);
685
+
686
+ // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
687
+ LM_GGML_API bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor);
688
+
685
689
  LM_GGML_API bool lm_ggml_are_same_shape (const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1);
686
690
  LM_GGML_API bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1);
687
691
 
@@ -765,7 +769,7 @@ extern "C" {
765
769
  // Tensor flags
766
770
  LM_GGML_API void lm_ggml_set_input(struct lm_ggml_tensor * tensor);
767
771
  LM_GGML_API void lm_ggml_set_output(struct lm_ggml_tensor * tensor);
768
- LM_GGML_API void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor);
772
+ LM_GGML_API void lm_ggml_set_param(struct lm_ggml_tensor * tensor);
769
773
  LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor);
770
774
 
771
775
  //
@@ -935,7 +939,7 @@ extern "C" {
935
939
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_repeat_back(
936
940
  struct lm_ggml_context * ctx,
937
941
  struct lm_ggml_tensor * a,
938
- struct lm_ggml_tensor * b);
942
+ struct lm_ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
939
943
 
940
944
  // concat a and b along dim
941
945
  // used in stable-diffusion
@@ -1021,6 +1025,16 @@ extern "C" {
1021
1025
  struct lm_ggml_context * ctx,
1022
1026
  struct lm_ggml_tensor * a);
1023
1027
 
1028
+ // GELU using erf (error function) when possible
1029
+ // some backends may fallback to approximation based on Abramowitz and Stegun formula
1030
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf(
1031
+ struct lm_ggml_context * ctx,
1032
+ struct lm_ggml_tensor * a);
1033
+
1034
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf_inplace(
1035
+ struct lm_ggml_context * ctx,
1036
+ struct lm_ggml_tensor * a);
1037
+
1024
1038
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_quick(
1025
1039
  struct lm_ggml_context * ctx,
1026
1040
  struct lm_ggml_tensor * a);
@@ -1665,7 +1679,7 @@ extern "C" {
1665
1679
  struct lm_ggml_tensor * a,
1666
1680
  struct lm_ggml_tensor * b);
1667
1681
 
1668
- // depthwise
1682
+ // depthwise (via im2col and mul_mat)
1669
1683
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d_dw(
1670
1684
  struct lm_ggml_context * ctx,
1671
1685
  struct lm_ggml_tensor * a, // convolution kernel
@@ -1677,6 +1691,22 @@ extern "C" {
1677
1691
  int d0, // dilation dimension 0
1678
1692
  int d1); // dilation dimension 1
1679
1693
 
1694
+ // Depthwise 2D convolution
1695
+ // may be faster than lm_ggml_conv_2d_dw, but not available in all backends
1696
+ // a: KW KH 1 C convolution kernel
1697
+ // b: W H C N input data
1698
+ // res: W_out H_out C N
1699
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d_dw_direct(
1700
+ struct lm_ggml_context * ctx,
1701
+ struct lm_ggml_tensor * a,
1702
+ struct lm_ggml_tensor * b,
1703
+ int stride0,
1704
+ int stride1,
1705
+ int pad0,
1706
+ int pad1,
1707
+ int dilation0,
1708
+ int dilation1);
1709
+
1680
1710
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0(
1681
1711
  struct lm_ggml_context * ctx,
1682
1712
  struct lm_ggml_tensor * a,
@@ -1722,24 +1752,29 @@ extern "C" {
1722
1752
  float p0,
1723
1753
  float p1);
1724
1754
 
1725
- // nearest interpolate
1755
+ enum lm_ggml_scale_mode {
1756
+ LM_GGML_SCALE_MODE_NEAREST = 0,
1757
+ LM_GGML_SCALE_MODE_BILINEAR = 1,
1758
+ };
1759
+
1760
+ // interpolate
1726
1761
  // multiplies ne0 and ne1 by scale factor
1727
- // used in stable-diffusion
1728
1762
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_upscale(
1729
1763
  struct lm_ggml_context * ctx,
1730
1764
  struct lm_ggml_tensor * a,
1731
- int scale_factor);
1765
+ int scale_factor,
1766
+ enum lm_ggml_scale_mode mode);
1732
1767
 
1733
- // nearest interpolate
1734
- // nearest interpolate to specified dimensions
1735
- // used in tortoise.cpp
1768
+ // interpolate
1769
+ // interpolate scale to specified dimensions
1736
1770
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_upscale_ext(
1737
1771
  struct lm_ggml_context * ctx,
1738
1772
  struct lm_ggml_tensor * a,
1739
1773
  int ne0,
1740
1774
  int ne1,
1741
1775
  int ne2,
1742
- int ne3);
1776
+ int ne3,
1777
+ enum lm_ggml_scale_mode mode);
1743
1778
 
1744
1779
  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
1745
1780
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_pad(
@@ -1791,11 +1826,11 @@ extern "C" {
1791
1826
 
1792
1827
  #define LM_GGML_KQ_MASK_PAD 64
1793
1828
 
1794
- // q: [n_embd, n_batch, n_head, 1]
1795
- // k: [n_embd, n_kv, n_head_kv, 1]
1796
- // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1797
- // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = LM_GGML_PAD(n_batch, LM_GGML_KQ_MASK_PAD) !!
1798
- // res: [n_embd, n_head, n_batch, 1] !! permuted !!
1829
+ // q: [n_embd_k, n_batch, n_head, 1]
1830
+ // k: [n_embd_k, n_kv, n_head_kv, 1]
1831
+ // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1832
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = LM_GGML_PAD(n_batch, LM_GGML_KQ_MASK_PAD) !!
1833
+ // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
1799
1834
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
1800
1835
  struct lm_ggml_context * ctx,
1801
1836
  struct lm_ggml_tensor * q,
@@ -1916,83 +1951,6 @@ extern "C" {
1916
1951
 
1917
1952
  // custom operators
1918
1953
 
1919
- typedef void (*lm_ggml_unary_op_f32_t) (const int, float *, const float *);
1920
- typedef void (*lm_ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
1921
-
1922
- typedef void (*lm_ggml_custom1_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
1923
- typedef void (*lm_ggml_custom2_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
1924
- typedef void (*lm_ggml_custom3_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
1925
-
1926
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_unary_f32(
1927
- struct lm_ggml_context * ctx,
1928
- struct lm_ggml_tensor * a,
1929
- lm_ggml_unary_op_f32_t fun),
1930
- "use lm_ggml_map_custom1 instead");
1931
-
1932
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32(
1933
- struct lm_ggml_context * ctx,
1934
- struct lm_ggml_tensor * a,
1935
- lm_ggml_unary_op_f32_t fun),
1936
- "use lm_ggml_map_custom1_inplace instead");
1937
-
1938
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_binary_f32(
1939
- struct lm_ggml_context * ctx,
1940
- struct lm_ggml_tensor * a,
1941
- struct lm_ggml_tensor * b,
1942
- lm_ggml_binary_op_f32_t fun),
1943
- "use lm_ggml_map_custom2 instead");
1944
-
1945
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32(
1946
- struct lm_ggml_context * ctx,
1947
- struct lm_ggml_tensor * a,
1948
- struct lm_ggml_tensor * b,
1949
- lm_ggml_binary_op_f32_t fun),
1950
- "use lm_ggml_map_custom2_inplace instead");
1951
-
1952
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom1_f32(
1953
- struct lm_ggml_context * ctx,
1954
- struct lm_ggml_tensor * a,
1955
- lm_ggml_custom1_op_f32_t fun),
1956
- "use lm_ggml_map_custom1 instead");
1957
-
1958
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32(
1959
- struct lm_ggml_context * ctx,
1960
- struct lm_ggml_tensor * a,
1961
- lm_ggml_custom1_op_f32_t fun),
1962
- "use lm_ggml_map_custom1_inplace instead");
1963
-
1964
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom2_f32(
1965
- struct lm_ggml_context * ctx,
1966
- struct lm_ggml_tensor * a,
1967
- struct lm_ggml_tensor * b,
1968
- lm_ggml_custom2_op_f32_t fun),
1969
- "use lm_ggml_map_custom2 instead");
1970
-
1971
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32(
1972
- struct lm_ggml_context * ctx,
1973
- struct lm_ggml_tensor * a,
1974
- struct lm_ggml_tensor * b,
1975
- lm_ggml_custom2_op_f32_t fun),
1976
- "use lm_ggml_map_custom2_inplace instead");
1977
-
1978
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom3_f32(
1979
- struct lm_ggml_context * ctx,
1980
- struct lm_ggml_tensor * a,
1981
- struct lm_ggml_tensor * b,
1982
- struct lm_ggml_tensor * c,
1983
- lm_ggml_custom3_op_f32_t fun),
1984
- "use lm_ggml_map_custom3 instead");
1985
-
1986
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32(
1987
- struct lm_ggml_context * ctx,
1988
- struct lm_ggml_tensor * a,
1989
- struct lm_ggml_tensor * b,
1990
- struct lm_ggml_tensor * c,
1991
- lm_ggml_custom3_op_f32_t fun),
1992
- "use lm_ggml_map_custom3_inplace instead");
1993
-
1994
- // custom operators v2
1995
-
1996
1954
  typedef void (*lm_ggml_custom1_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, int ith, int nth, void * userdata);
1997
1955
  typedef void (*lm_ggml_custom2_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, int ith, int nth, void * userdata);
1998
1956
  typedef void (*lm_ggml_custom3_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, const struct lm_ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +2006,30 @@ extern "C" {
2048
2006
  int n_tasks,
2049
2007
  void * userdata);
2050
2008
 
2009
+ typedef void (*lm_ggml_custom_op_t)(struct lm_ggml_tensor * dst , int ith, int nth, void * userdata);
2010
+
2011
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_custom_4d(
2012
+ struct lm_ggml_context * ctx,
2013
+ enum lm_ggml_type type,
2014
+ int64_t ne0,
2015
+ int64_t ne1,
2016
+ int64_t ne2,
2017
+ int64_t ne3,
2018
+ struct lm_ggml_tensor ** args,
2019
+ int n_args,
2020
+ lm_ggml_custom_op_t fun,
2021
+ int n_tasks,
2022
+ void * userdata);
2023
+
2024
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_custom_inplace(
2025
+ struct lm_ggml_context * ctx,
2026
+ struct lm_ggml_tensor * a,
2027
+ struct lm_ggml_tensor ** args,
2028
+ int n_args,
2029
+ lm_ggml_custom_op_t fun,
2030
+ int n_tasks,
2031
+ void * userdata);
2032
+
2051
2033
  // loss function
2052
2034
 
2053
2035
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
@@ -2078,15 +2060,14 @@ extern "C" {
2078
2060
 
2079
2061
  LM_GGML_API void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor);
2080
2062
  LM_GGML_API void lm_ggml_build_backward_expand(
2081
- struct lm_ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
2082
- struct lm_ggml_context * ctx_compute, // context for gradient computation
2083
- struct lm_ggml_cgraph * cgraph,
2084
- bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
2063
+ struct lm_ggml_context * ctx, // context for gradient computation
2064
+ struct lm_ggml_cgraph * cgraph,
2065
+ struct lm_ggml_tensor ** grad_accs);
2085
2066
 
2086
2067
  // graph allocation in a context
2087
2068
  LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false
2088
2069
  LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads);
2089
- LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph);
2070
+ LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, bool force_grads);
2090
2071
  LM_GGML_API void lm_ggml_graph_cpy (struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst);
2091
2072
  LM_GGML_API void lm_ggml_graph_reset (struct lm_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
2092
2073
  LM_GGML_API void lm_ggml_graph_clear (struct lm_ggml_cgraph * cgraph);
@@ -10,6 +10,7 @@
10
10
 
11
11
  enum llm_arch {
12
12
  LLM_ARCH_LLAMA,
13
+ LLM_ARCH_LLAMA4,
13
14
  LLM_ARCH_DECI,
14
15
  LLM_ARCH_FALCON,
15
16
  LLM_ARCH_BAICHUAN,
@@ -22,6 +23,7 @@ enum llm_arch {
22
23
  LLM_ARCH_REFACT,
23
24
  LLM_ARCH_BERT,
24
25
  LLM_ARCH_NOMIC_BERT,
26
+ LLM_ARCH_NOMIC_BERT_MOE,
25
27
  LLM_ARCH_JINA_BERT_V2,
26
28
  LLM_ARCH_BLOOM,
27
29
  LLM_ARCH_STABLELM,
@@ -29,6 +31,8 @@ enum llm_arch {
29
31
  LLM_ARCH_QWEN2,
30
32
  LLM_ARCH_QWEN2MOE,
31
33
  LLM_ARCH_QWEN2VL,
34
+ LLM_ARCH_QWEN3,
35
+ LLM_ARCH_QWEN3MOE,
32
36
  LLM_ARCH_PHI2,
33
37
  LLM_ARCH_PHI3,
34
38
  LLM_ARCH_PHIMOE,
@@ -55,6 +59,7 @@ enum llm_arch {
55
59
  LLM_ARCH_DEEPSEEK,
56
60
  LLM_ARCH_DEEPSEEK2,
57
61
  LLM_ARCH_CHATGLM,
62
+ LLM_ARCH_GLM4,
58
63
  LLM_ARCH_BITNET,
59
64
  LLM_ARCH_T5,
60
65
  LLM_ARCH_T5ENCODER,
@@ -69,6 +74,8 @@ enum llm_arch {
69
74
  LLM_ARCH_GRANITE_MOE,
70
75
  LLM_ARCH_CHAMELEON,
71
76
  LLM_ARCH_WAVTOKENIZER_DEC,
77
+ LLM_ARCH_PLM,
78
+ LLM_ARCH_BAILINGMOE,
72
79
  LLM_ARCH_UNKNOWN,
73
80
  };
74
81
 
@@ -77,6 +84,7 @@ enum llm_kv {
77
84
  LLM_KV_GENERAL_ARCHITECTURE,
78
85
  LLM_KV_GENERAL_QUANTIZATION_VERSION,
79
86
  LLM_KV_GENERAL_ALIGNMENT,
87
+ LLM_KV_GENERAL_FILE_TYPE,
80
88
  LLM_KV_GENERAL_NAME,
81
89
  LLM_KV_GENERAL_AUTHOR,
82
90
  LLM_KV_GENERAL_VERSION,
@@ -103,6 +111,7 @@ enum llm_kv {
103
111
  LLM_KV_EXPERT_WEIGHTS_SCALE,
104
112
  LLM_KV_EXPERT_WEIGHTS_NORM,
105
113
  LLM_KV_EXPERT_GATING_FUNC,
114
+ LLM_KV_MOE_EVERY_N_LAYERS,
106
115
  LLM_KV_POOLING_TYPE,
107
116
  LLM_KV_LOGIT_SCALE,
108
117
  LLM_KV_DECODER_START_TOKEN_ID,
@@ -115,6 +124,7 @@ enum llm_kv {
115
124
  LLM_KV_RESIDUAL_SCALE,
116
125
  LLM_KV_EMBEDDING_SCALE,
117
126
  LLM_KV_TOKEN_SHIFT_COUNT,
127
+ LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
118
128
 
119
129
  LLM_KV_ATTENTION_HEAD_COUNT,
120
130
  LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -136,6 +146,8 @@ enum llm_kv {
136
146
  LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
137
147
  LLM_KV_ATTENTION_SLIDING_WINDOW,
138
148
  LLM_KV_ATTENTION_SCALE,
149
+ LLM_KV_ATTENTION_KEY_LENGTH_MLA,
150
+ LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
139
151
 
140
152
  LLM_KV_ROPE_DIMENSION_COUNT,
141
153
  LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -249,6 +261,8 @@ enum llm_tensor {
249
261
  LLM_TENSOR_ATTN_Q_NORM,
250
262
  LLM_TENSOR_ATTN_K_NORM,
251
263
  LLM_TENSOR_LAYER_OUT_NORM,
264
+ LLM_TENSOR_POST_ATTN_NORM,
265
+ LLM_TENSOR_POST_MLP_NORM,
252
266
  LLM_TENSOR_SSM_IN,
253
267
  LLM_TENSOR_SSM_CONV1D,
254
268
  LLM_TENSOR_SSM_X,
@@ -296,6 +310,8 @@ enum llm_tensor {
296
310
  LLM_TENSOR_ATTN_Q_B,
297
311
  LLM_TENSOR_ATTN_KV_A_MQA,
298
312
  LLM_TENSOR_ATTN_KV_B,
313
+ LLM_TENSOR_ATTN_K_B,
314
+ LLM_TENSOR_ATTN_V_B,
299
315
  LLM_TENSOR_ATTN_Q_A_NORM,
300
316
  LLM_TENSOR_ATTN_KV_A_NORM,
301
317
  LLM_TENSOR_ATTN_SUB_NORM,
@@ -70,7 +70,8 @@ struct llama_sbatch {
70
70
  // sequence-wise split
71
71
  llama_ubatch split_seq(size_t n_ubatch);
72
72
 
73
- void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
73
+ llama_sbatch() = default;
74
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
74
75
  };
75
76
 
76
77
  // temporary allocate memory for the input batch if needed
@@ -14,6 +14,7 @@ enum llm_chat_template {
14
14
  LLM_CHAT_TEMPLATE_MISTRAL_V3,
15
15
  LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
16
16
  LLM_CHAT_TEMPLATE_MISTRAL_V7,
17
+ LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
17
18
  LLM_CHAT_TEMPLATE_PHI_3,
18
19
  LLM_CHAT_TEMPLATE_PHI_4,
19
20
  LLM_CHAT_TEMPLATE_FALCON_3,
@@ -29,8 +30,8 @@ enum llm_chat_template {
29
30
  LLM_CHAT_TEMPLATE_DEEPSEEK_3,
30
31
  LLM_CHAT_TEMPLATE_COMMAND_R,
31
32
  LLM_CHAT_TEMPLATE_LLAMA_3,
32
- LLM_CHAT_TEMPLATE_CHATGML_3,
33
- LLM_CHAT_TEMPLATE_CHATGML_4,
33
+ LLM_CHAT_TEMPLATE_CHATGLM_3,
34
+ LLM_CHAT_TEMPLATE_CHATGLM_4,
34
35
  LLM_CHAT_TEMPLATE_GLMEDGE,
35
36
  LLM_CHAT_TEMPLATE_MINICPM,
36
37
  LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -38,6 +39,10 @@ enum llm_chat_template {
38
39
  LLM_CHAT_TEMPLATE_GRANITE,
39
40
  LLM_CHAT_TEMPLATE_GIGACHAT,
40
41
  LLM_CHAT_TEMPLATE_MEGREZ,
42
+ LLM_CHAT_TEMPLATE_YANDEX,
43
+ LLM_CHAT_TEMPLATE_BAILING,
44
+ LLM_CHAT_TEMPLATE_LLAMA4,
45
+ LLM_CHAT_TEMPLATE_SMOLVLM,
41
46
  LLM_CHAT_TEMPLATE_UNKNOWN,
42
47
  };
43
48
 
@@ -7,6 +7,7 @@
7
7
  #include "llama-adapter.h"
8
8
 
9
9
  #include "ggml-cpp.h"
10
+ #include "ggml-opt.h"
10
11
 
11
12
  #include <map>
12
13
  #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
27
28
 
28
29
  void synchronize();
29
30
 
30
- const llama_model & get_model() const;
31
+ const llama_model & get_model() const;
32
+ const llama_cparams & get_cparams() const;
33
+
34
+ lm_ggml_backend_sched_t get_sched() const;
35
+
36
+ lm_ggml_context * get_ctx_compute() const;
31
37
 
32
38
  uint32_t n_ctx() const;
33
39
  uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
128
134
  llama_perf_context_data perf_get_data() const;
129
135
  void perf_reset();
130
136
 
137
+ //
138
+ // training
139
+ //
140
+
141
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
142
+
143
+ void opt_epoch(
144
+ lm_ggml_opt_dataset_t dataset,
145
+ lm_ggml_opt_result_t result_train,
146
+ lm_ggml_opt_result_t result_eval,
147
+ int64_t idata_split,
148
+ lm_ggml_opt_epoch_callback callback_train,
149
+ lm_ggml_opt_epoch_callback callback_eval);
150
+
151
+ void opt_epoch_iter(
152
+ lm_ggml_opt_dataset_t dataset,
153
+ lm_ggml_opt_result_t result,
154
+ const std::vector<llama_token> & tokens,
155
+ const std::vector<llama_token> & labels_sparse,
156
+ llama_batch & batch,
157
+ lm_ggml_opt_epoch_callback callback,
158
+ bool train,
159
+ int64_t idata_in_loop,
160
+ int64_t ndata_in_loop,
161
+ int64_t t_loop_start);
162
+
131
163
  private:
132
164
  //
133
165
  // output
@@ -137,50 +169,30 @@ private:
137
169
  // Returns max number of outputs for which space was reserved.
138
170
  int32_t output_reserve(int32_t n_outputs);
139
171
 
140
- // make the outputs have the same order they had in the user-provided batch
141
- // TODO: maybe remove this
142
- void output_reorder();
143
-
144
172
  //
145
173
  // graph
146
174
  //
147
175
 
176
+ public:
148
177
  int32_t graph_max_nodes() const;
149
178
 
150
179
  // zero-out inputs and create the ctx_compute for the compute graph
151
180
  lm_ggml_cgraph * graph_init();
152
181
 
182
+ // returns the result of lm_ggml_backend_sched_graph_compute_async execution
183
+ lm_ggml_status graph_compute(
184
+ lm_ggml_cgraph * gf,
185
+ bool batched);
186
+
187
+ private:
153
188
  llm_graph_result_ptr graph_build(
154
189
  lm_ggml_context * ctx,
155
190
  lm_ggml_cgraph * gf,
156
191
  const llama_ubatch & ubatch,
157
192
  llm_graph_type gtype);
158
193
 
159
- // returns the result of lm_ggml_backend_sched_graph_compute_async execution
160
- lm_ggml_status graph_compute(
161
- lm_ggml_cgraph * gf,
162
- bool batched);
163
-
164
194
  llm_graph_cb graph_get_cb() const;
165
195
 
166
- // used by kv_self_update()
167
- lm_ggml_tensor * build_rope_shift(
168
- lm_ggml_context * ctx0,
169
- lm_ggml_tensor * cur,
170
- lm_ggml_tensor * shift,
171
- lm_ggml_tensor * factors,
172
- float freq_base,
173
- float freq_scale,
174
- lm_ggml_backend_buffer * bbuf) const;
175
-
176
- llm_graph_result_ptr build_kv_self_shift(
177
- lm_ggml_context * ctx0,
178
- lm_ggml_cgraph * gf) const;
179
-
180
- llm_graph_result_ptr build_kv_self_defrag(
181
- lm_ggml_context * ctx0,
182
- lm_ggml_cgraph * gf) const;
183
-
184
196
  // TODO: read/write lora adapters and cvec
185
197
  size_t state_write_data(llama_io_write_i & io);
186
198
  size_t state_read_data (llama_io_read_i & io);
@@ -197,14 +209,10 @@ private:
197
209
  llama_cparams cparams;
198
210
  llama_adapter_cvec cvec;
199
211
  llama_adapter_loras loras;
200
- llama_sbatch sbatch;
201
212
 
202
213
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
203
214
 
204
- std::unique_ptr<llama_kv_cache_unified> kv_self;
205
-
206
- // TODO: remove
207
- bool logits_all = false;
215
+ std::unique_ptr<llama_memory_i> memory;
208
216
 
209
217
  // decode output (2-dimensional array: [n_outputs][n_vocab])
210
218
  size_t logits_size = 0; // capacity (of floats) for logits
@@ -231,6 +239,9 @@ private:
231
239
 
232
240
  lm_ggml_context_ptr ctx_compute;
233
241
 
242
+ // training
243
+ lm_ggml_opt_context_t opt_ctx = nullptr;
244
+
234
245
  lm_ggml_threadpool_t threadpool = nullptr;
235
246
  lm_ggml_threadpool_t threadpool_batch = nullptr;
236
247
 
@@ -30,6 +30,7 @@ struct llama_cparams {
30
30
  bool flash_attn;
31
31
  bool no_perf;
32
32
  bool warmup;
33
+ bool op_offload;
33
34
 
34
35
  enum llama_pooling_type pooling_type;
35
36