cui-llama.rn 1.4.4 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +54 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1085
  14. package/cpp/chat.h +143 -0
  15. package/cpp/common.cpp +1562 -1996
  16. package/cpp/common.h +677 -744
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-alloc.c +1039 -1030
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +255 -255
  21. package/cpp/ggml-backend-reg.cpp +586 -582
  22. package/cpp/ggml-backend.cpp +2004 -2002
  23. package/cpp/ggml-backend.h +354 -354
  24. package/cpp/ggml-common.h +1857 -1851
  25. package/cpp/ggml-cpp.h +39 -39
  26. package/cpp/ggml-cpu-aarch64.cpp +5725 -4247
  27. package/cpp/ggml-cpu-aarch64.h +8 -8
  28. package/cpp/ggml-cpu-impl.h +512 -380
  29. package/cpp/ggml-cpu-quants.c +13026 -11517
  30. package/cpp/ggml-cpu-traits.cpp +36 -36
  31. package/cpp/ggml-cpu-traits.h +38 -38
  32. package/cpp/ggml-cpu.c +3438 -14485
  33. package/cpp/ggml-cpu.cpp +655 -633
  34. package/cpp/ggml-cpu.h +138 -135
  35. package/cpp/ggml-impl.h +594 -567
  36. package/cpp/ggml-metal-impl.h +312 -3
  37. package/cpp/ggml-metal.h +66 -66
  38. package/cpp/ggml-metal.m +5360 -5002
  39. package/cpp/ggml-opt.cpp +854 -854
  40. package/cpp/ggml-opt.h +216 -216
  41. package/cpp/ggml-quants.c +5238 -5238
  42. package/cpp/ggml-threading.h +14 -14
  43. package/cpp/ggml.c +6618 -6524
  44. package/cpp/ggml.h +2222 -2194
  45. package/cpp/gguf.cpp +1330 -1329
  46. package/cpp/gguf.h +202 -202
  47. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  48. package/cpp/json-schema-to-grammar.h +21 -22
  49. package/cpp/json.hpp +24766 -24766
  50. package/cpp/llama-adapter.cpp +382 -347
  51. package/cpp/llama-adapter.h +76 -74
  52. package/cpp/llama-arch.cpp +1714 -1492
  53. package/cpp/llama-arch.h +428 -402
  54. package/cpp/llama-batch.cpp +368 -368
  55. package/cpp/llama-batch.h +88 -88
  56. package/cpp/llama-chat.cpp +640 -587
  57. package/cpp/llama-chat.h +56 -53
  58. package/cpp/llama-context.cpp +2831 -1775
  59. package/cpp/llama-context.h +265 -128
  60. package/cpp/llama-cparams.cpp +1 -1
  61. package/cpp/llama-cparams.h +38 -37
  62. package/cpp/llama-cpp.h +30 -30
  63. package/cpp/llama-grammar.cpp +1219 -1219
  64. package/cpp/llama-grammar.h +173 -164
  65. package/cpp/llama-graph.cpp +1695 -0
  66. package/cpp/llama-graph.h +592 -0
  67. package/cpp/llama-hparams.cpp +79 -71
  68. package/cpp/llama-hparams.h +156 -139
  69. package/cpp/llama-impl.cpp +167 -167
  70. package/cpp/llama-impl.h +61 -61
  71. package/cpp/llama-io.cpp +15 -0
  72. package/cpp/llama-io.h +35 -0
  73. package/cpp/llama-kv-cache.cpp +1380 -718
  74. package/cpp/llama-kv-cache.h +213 -218
  75. package/cpp/llama-memory.cpp +1 -0
  76. package/cpp/llama-memory.h +21 -0
  77. package/cpp/llama-mmap.cpp +600 -590
  78. package/cpp/llama-mmap.h +68 -68
  79. package/cpp/llama-model-loader.cpp +1129 -1124
  80. package/cpp/llama-model-loader.h +169 -167
  81. package/cpp/llama-model.cpp +13080 -4023
  82. package/cpp/llama-model.h +409 -370
  83. package/cpp/llama-sampling.cpp +2563 -2525
  84. package/cpp/llama-sampling.h +32 -32
  85. package/cpp/llama-vocab.cpp +3295 -3252
  86. package/cpp/llama-vocab.h +125 -125
  87. package/cpp/llama.cpp +351 -10137
  88. package/cpp/llama.h +1434 -1340
  89. package/cpp/log.cpp +427 -423
  90. package/cpp/log.h +132 -132
  91. package/cpp/{chat-template.hpp → minja/chat-template.hpp} +537 -529
  92. package/cpp/{minja.hpp → minja/minja.hpp} +2941 -2883
  93. package/cpp/ops.cpp +8723 -0
  94. package/cpp/ops.h +128 -0
  95. package/cpp/rn-llama.cpp +45 -71
  96. package/cpp/rn-llama.h +3 -3
  97. package/cpp/sampling.cpp +573 -532
  98. package/cpp/sgemm.cpp +3043 -2598
  99. package/cpp/sgemm.h +14 -14
  100. package/cpp/simd-mappings.h +888 -0
  101. package/cpp/speculative.cpp +278 -277
  102. package/cpp/speculative.h +28 -28
  103. package/cpp/unary-ops.cpp +186 -0
  104. package/cpp/unary-ops.h +28 -0
  105. package/cpp/vec.cpp +258 -0
  106. package/cpp/vec.h +802 -0
  107. package/ios/CMakeLists.txt +5 -2
  108. package/ios/RNLlama.mm +2 -2
  109. package/ios/RNLlamaContext.mm +40 -24
  110. package/package.json +1 -1
  111. package/src/NativeRNLlama.ts +6 -4
  112. package/src/index.ts +3 -1
  113. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  114. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  115. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  116. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  117. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  118. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  119. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  120. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  122. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  124. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  125. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  126. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  127. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  128. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  129. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  130. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  131. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  132. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  133. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  134. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  135. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  136. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  194. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  195. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  196. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  197. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  198. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  199. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  200. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  201. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  202. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  203. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  204. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  205. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  206. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  207. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  208. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  209. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  210. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  211. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  212. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  213. package/android/src/main/build-arm64/Makefile +0 -1862
  214. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  215. package/cpp/chat.hpp +0 -55
  216. package/cpp/rn-llama.hpp +0 -913
package/cpp/llama-model.h CHANGED
@@ -1,370 +1,409 @@
1
- #pragma once
2
-
3
- #include "llama.h"
4
- #include "llama-arch.h"
5
- #include "llama-hparams.h"
6
- #include "llama-vocab.h"
7
-
8
- #include <memory>
9
- #include <string>
10
- #include <unordered_map>
11
- #include <vector>
12
-
13
- struct llama_model_loader;
14
-
15
- // available models
16
- enum llm_type {
17
- LLM_TYPE_UNKNOWN,
18
- LLM_TYPE_14M,
19
- LLM_TYPE_17M,
20
- LLM_TYPE_22M,
21
- LLM_TYPE_33M,
22
- LLM_TYPE_60M,
23
- LLM_TYPE_70M,
24
- LLM_TYPE_80M,
25
- LLM_TYPE_109M,
26
- LLM_TYPE_137M,
27
- LLM_TYPE_160M,
28
- LLM_TYPE_220M,
29
- LLM_TYPE_250M,
30
- LLM_TYPE_270M,
31
- LLM_TYPE_335M,
32
- LLM_TYPE_410M,
33
- LLM_TYPE_450M,
34
- LLM_TYPE_770M,
35
- LLM_TYPE_780M,
36
- LLM_TYPE_0_5B,
37
- LLM_TYPE_1B,
38
- LLM_TYPE_1_3B,
39
- LLM_TYPE_1_4B,
40
- LLM_TYPE_1_5B,
41
- LLM_TYPE_1_6B,
42
- LLM_TYPE_2B,
43
- LLM_TYPE_2_8B,
44
- LLM_TYPE_3B,
45
- LLM_TYPE_4B,
46
- LLM_TYPE_6B,
47
- LLM_TYPE_6_9B,
48
- LLM_TYPE_7B,
49
- LLM_TYPE_8B,
50
- LLM_TYPE_9B,
51
- LLM_TYPE_11B,
52
- LLM_TYPE_12B,
53
- LLM_TYPE_13B,
54
- LLM_TYPE_14B,
55
- LLM_TYPE_15B,
56
- LLM_TYPE_16B,
57
- LLM_TYPE_20B,
58
- LLM_TYPE_30B,
59
- LLM_TYPE_32B,
60
- LLM_TYPE_34B,
61
- LLM_TYPE_35B,
62
- LLM_TYPE_40B,
63
- LLM_TYPE_65B,
64
- LLM_TYPE_70B,
65
- LLM_TYPE_236B,
66
- LLM_TYPE_314B,
67
- LLM_TYPE_671B,
68
- LLM_TYPE_SMALL,
69
- LLM_TYPE_MEDIUM,
70
- LLM_TYPE_LARGE,
71
- LLM_TYPE_XL,
72
- LLM_TYPE_A1_7B,
73
- LLM_TYPE_A2_7B,
74
- LLM_TYPE_8x7B,
75
- LLM_TYPE_8x22B,
76
- LLM_TYPE_16x12B,
77
- LLM_TYPE_16x3_8B,
78
- LLM_TYPE_10B_128x3_66B,
79
- LLM_TYPE_57B_A14B,
80
- LLM_TYPE_27B,
81
- };
82
-
83
- struct llama_layer_posnet {
84
- // resnet
85
- struct lm_ggml_tensor * norm1 = nullptr;
86
- struct lm_ggml_tensor * norm1_b = nullptr;
87
-
88
- struct lm_ggml_tensor * conv1 = nullptr;
89
- struct lm_ggml_tensor * conv1_b = nullptr;
90
-
91
- struct lm_ggml_tensor * norm2 = nullptr;
92
- struct lm_ggml_tensor * norm2_b = nullptr;
93
-
94
- struct lm_ggml_tensor * conv2 = nullptr;
95
- struct lm_ggml_tensor * conv2_b = nullptr;
96
-
97
- // attention
98
- struct lm_ggml_tensor * attn_norm = nullptr;
99
- struct lm_ggml_tensor * attn_norm_b = nullptr;
100
-
101
- struct lm_ggml_tensor * attn_q = nullptr;
102
- struct lm_ggml_tensor * attn_q_b = nullptr;
103
-
104
- struct lm_ggml_tensor * attn_k = nullptr;
105
- struct lm_ggml_tensor * attn_k_b = nullptr;
106
-
107
- struct lm_ggml_tensor * attn_v = nullptr;
108
- struct lm_ggml_tensor * attn_v_b = nullptr;
109
-
110
- struct lm_ggml_tensor * attn_o = nullptr;
111
- struct lm_ggml_tensor * attn_o_b = nullptr;
112
-
113
- // normalize
114
- struct lm_ggml_tensor * norm = nullptr;
115
- struct lm_ggml_tensor * norm_b = nullptr;
116
- };
117
-
118
- struct llama_layer_convnext {
119
- struct lm_ggml_tensor * dw = nullptr;
120
- struct lm_ggml_tensor * dw_b = nullptr;
121
-
122
- struct lm_ggml_tensor * norm = nullptr;
123
- struct lm_ggml_tensor * norm_b = nullptr;
124
-
125
- struct lm_ggml_tensor * pw1 = nullptr;
126
- struct lm_ggml_tensor * pw1_b = nullptr;
127
-
128
- struct lm_ggml_tensor * pw2 = nullptr;
129
- struct lm_ggml_tensor * pw2_b = nullptr;
130
-
131
- struct lm_ggml_tensor * gamma = nullptr;
132
- };
133
-
134
- struct llama_layer {
135
- // normalization
136
- struct lm_ggml_tensor * attn_norm = nullptr;
137
- struct lm_ggml_tensor * attn_norm_b = nullptr;
138
- struct lm_ggml_tensor * attn_norm_2 = nullptr;
139
- struct lm_ggml_tensor * attn_norm_2_b = nullptr;
140
- struct lm_ggml_tensor * attn_q_norm = nullptr;
141
- struct lm_ggml_tensor * attn_q_norm_b = nullptr;
142
- struct lm_ggml_tensor * attn_k_norm = nullptr;
143
- struct lm_ggml_tensor * attn_k_norm_b = nullptr;
144
- struct lm_ggml_tensor * attn_out_norm = nullptr;
145
- struct lm_ggml_tensor * attn_out_norm_b = nullptr;
146
- struct lm_ggml_tensor * attn_q_a_norm = nullptr;
147
- struct lm_ggml_tensor * attn_kv_a_norm = nullptr;
148
- struct lm_ggml_tensor * attn_sub_norm = nullptr;
149
- struct lm_ggml_tensor * attn_post_norm = nullptr;
150
- struct lm_ggml_tensor * ffn_sub_norm = nullptr;
151
- struct lm_ggml_tensor * attn_norm_cross = nullptr;
152
- struct lm_ggml_tensor * attn_norm_enc = nullptr;
153
-
154
- // attention
155
- struct lm_ggml_tensor * wq = nullptr;
156
- struct lm_ggml_tensor * wk = nullptr;
157
- struct lm_ggml_tensor * wv = nullptr;
158
- struct lm_ggml_tensor * wo = nullptr;
159
- struct lm_ggml_tensor * wqkv = nullptr;
160
- struct lm_ggml_tensor * wq_a = nullptr;
161
- struct lm_ggml_tensor * wq_b = nullptr;
162
- struct lm_ggml_tensor * wkv_a_mqa = nullptr;
163
- struct lm_ggml_tensor * wkv_b = nullptr;
164
- struct lm_ggml_tensor * wq_cross = nullptr;
165
- struct lm_ggml_tensor * wk_cross = nullptr;
166
- struct lm_ggml_tensor * wv_cross = nullptr;
167
- struct lm_ggml_tensor * wo_cross = nullptr;
168
- struct lm_ggml_tensor * wq_enc = nullptr;
169
- struct lm_ggml_tensor * wk_enc = nullptr;
170
- struct lm_ggml_tensor * wv_enc = nullptr;
171
- struct lm_ggml_tensor * wo_enc = nullptr;
172
-
173
- // attention bias
174
- struct lm_ggml_tensor * bq = nullptr;
175
- struct lm_ggml_tensor * bk = nullptr;
176
- struct lm_ggml_tensor * bv = nullptr;
177
- struct lm_ggml_tensor * bo = nullptr;
178
- struct lm_ggml_tensor * bqkv = nullptr;
179
-
180
- // relative position bias
181
- struct lm_ggml_tensor * attn_rel_b = nullptr;
182
- struct lm_ggml_tensor * attn_rel_b_enc = nullptr;
183
- struct lm_ggml_tensor * attn_rel_b_cross = nullptr;
184
-
185
- // normalization
186
- struct lm_ggml_tensor * ffn_norm = nullptr;
187
- struct lm_ggml_tensor * ffn_norm_b = nullptr;
188
- struct lm_ggml_tensor * ffn_post_norm = nullptr;
189
- struct lm_ggml_tensor * layer_out_norm = nullptr;
190
- struct lm_ggml_tensor * layer_out_norm_b = nullptr;
191
- struct lm_ggml_tensor * ffn_norm_exps = nullptr;
192
- struct lm_ggml_tensor * ffn_norm_enc = nullptr;
193
-
194
- // ff
195
- struct lm_ggml_tensor * ffn_gate = nullptr; // w1
196
- struct lm_ggml_tensor * ffn_down = nullptr; // w2
197
- struct lm_ggml_tensor * ffn_up = nullptr; // w3
198
- struct lm_ggml_tensor * ffn_gate_enc = nullptr;
199
- struct lm_ggml_tensor * ffn_down_enc = nullptr;
200
- struct lm_ggml_tensor * ffn_up_enc = nullptr;
201
-
202
- // ff MoE
203
- struct lm_ggml_tensor * ffn_gate_inp = nullptr;
204
- struct lm_ggml_tensor * ffn_gate_exps = nullptr;
205
- struct lm_ggml_tensor * ffn_down_exps = nullptr;
206
- struct lm_ggml_tensor * ffn_up_exps = nullptr;
207
-
208
- // ff shared expert (shexp)
209
- struct lm_ggml_tensor * ffn_gate_inp_shexp = nullptr;
210
- struct lm_ggml_tensor * ffn_gate_shexp = nullptr;
211
- struct lm_ggml_tensor * ffn_down_shexp = nullptr;
212
- struct lm_ggml_tensor * ffn_up_shexp = nullptr;
213
-
214
- // ff bias
215
- struct lm_ggml_tensor * ffn_gate_b = nullptr;
216
- struct lm_ggml_tensor * ffn_down_b = nullptr; // b2
217
- struct lm_ggml_tensor * ffn_up_b = nullptr; // b3
218
- struct lm_ggml_tensor * ffn_act = nullptr;
219
- struct lm_ggml_tensor * ffn_exp_probs_b = nullptr;
220
-
221
- // mamba proj
222
- struct lm_ggml_tensor * ssm_in = nullptr;
223
- struct lm_ggml_tensor * ssm_x = nullptr;
224
- struct lm_ggml_tensor * ssm_dt = nullptr;
225
- struct lm_ggml_tensor * ssm_out = nullptr;
226
-
227
- // mamba
228
- struct lm_ggml_tensor * ssm_conv1d = nullptr;
229
- struct lm_ggml_tensor * ssm_a = nullptr;
230
- struct lm_ggml_tensor * ssm_d = nullptr;
231
-
232
- // mamba bias
233
- struct lm_ggml_tensor * ssm_conv1d_b = nullptr;
234
- struct lm_ggml_tensor * ssm_dt_b = nullptr;
235
-
236
- // rwkv
237
- struct lm_ggml_tensor * time_mix_w1 = nullptr;
238
- struct lm_ggml_tensor * time_mix_w2 = nullptr;
239
- struct lm_ggml_tensor * time_mix_lerp_x = nullptr;
240
- struct lm_ggml_tensor * time_mix_lerp_w = nullptr;
241
- struct lm_ggml_tensor * time_mix_lerp_k = nullptr;
242
- struct lm_ggml_tensor * time_mix_lerp_v = nullptr;
243
- struct lm_ggml_tensor * time_mix_lerp_r = nullptr;
244
- struct lm_ggml_tensor * time_mix_lerp_g = nullptr;
245
- struct lm_ggml_tensor * time_mix_lerp_fused = nullptr;
246
-
247
- struct lm_ggml_tensor * time_mix_first = nullptr;
248
- struct lm_ggml_tensor * time_mix_decay = nullptr;
249
- struct lm_ggml_tensor * time_mix_decay_w1 = nullptr;
250
- struct lm_ggml_tensor * time_mix_decay_w2 = nullptr;
251
- struct lm_ggml_tensor * time_mix_key = nullptr;
252
- struct lm_ggml_tensor * time_mix_key_b = nullptr;
253
- struct lm_ggml_tensor * time_mix_value = nullptr;
254
- struct lm_ggml_tensor * time_mix_value_b = nullptr;
255
- struct lm_ggml_tensor * time_mix_receptance = nullptr;
256
- struct lm_ggml_tensor * time_mix_receptance_b = nullptr;
257
- struct lm_ggml_tensor * time_mix_gate = nullptr;
258
-
259
- struct lm_ggml_tensor * time_mix_ln = nullptr;
260
- struct lm_ggml_tensor * time_mix_ln_b = nullptr;
261
- struct lm_ggml_tensor * time_mix_output = nullptr;
262
-
263
- struct lm_ggml_tensor * channel_mix_lerp_k = nullptr;
264
- struct lm_ggml_tensor * channel_mix_lerp_r = nullptr;
265
-
266
- struct lm_ggml_tensor * channel_mix_key = nullptr;
267
- struct lm_ggml_tensor * channel_mix_receptance = nullptr;
268
- struct lm_ggml_tensor * channel_mix_value = nullptr;
269
-
270
- // long rope factors
271
- struct lm_ggml_tensor * rope_long = nullptr;
272
- struct lm_ggml_tensor * rope_short = nullptr;
273
- struct lm_ggml_tensor * rope_freqs = nullptr;
274
-
275
- // bitnet scale
276
- struct lm_ggml_tensor * wq_scale = nullptr;
277
- struct lm_ggml_tensor * wk_scale = nullptr;
278
- struct lm_ggml_tensor * wv_scale = nullptr;
279
- struct lm_ggml_tensor * wo_scale = nullptr;
280
- struct lm_ggml_tensor * ffn_gate_scale = nullptr;
281
- struct lm_ggml_tensor * ffn_up_scale = nullptr;
282
- struct lm_ggml_tensor * ffn_down_scale = nullptr;
283
-
284
- struct llama_layer_posnet posnet;
285
-
286
- struct llama_layer_convnext convnext;
287
- };
288
-
289
- struct llama_model {
290
- llm_type type = LLM_TYPE_UNKNOWN;
291
- llm_arch arch = LLM_ARCH_UNKNOWN;
292
-
293
- std::string name = "n/a";
294
-
295
- llama_hparams hparams = {};
296
- llama_vocab vocab;
297
-
298
- struct lm_ggml_tensor * tok_embd = nullptr;
299
- struct lm_ggml_tensor * type_embd = nullptr;
300
- struct lm_ggml_tensor * pos_embd = nullptr;
301
- struct lm_ggml_tensor * tok_norm = nullptr;
302
- struct lm_ggml_tensor * tok_norm_b = nullptr;
303
-
304
- struct lm_ggml_tensor * output_norm = nullptr;
305
- struct lm_ggml_tensor * output_norm_b = nullptr;
306
- struct lm_ggml_tensor * output = nullptr;
307
- struct lm_ggml_tensor * output_b = nullptr;
308
- struct lm_ggml_tensor * output_norm_enc = nullptr;
309
-
310
- // classifier
311
- struct lm_ggml_tensor * cls = nullptr;
312
- struct lm_ggml_tensor * cls_b = nullptr;
313
- struct lm_ggml_tensor * cls_out = nullptr;
314
- struct lm_ggml_tensor * cls_out_b = nullptr;
315
-
316
- struct lm_ggml_tensor * conv1d = nullptr;
317
- struct lm_ggml_tensor * conv1d_b = nullptr;
318
-
319
- std::vector<llama_layer> layers;
320
-
321
- llama_model_params params;
322
-
323
- // gguf metadata
324
- std::unordered_map<std::string, std::string> lm_gguf_kv;
325
-
326
- // list of devices used in this model
327
- std::vector<lm_ggml_backend_dev_t> devices;
328
-
329
- // for quantize-stats only
330
- std::vector<std::pair<std::string, struct lm_ggml_tensor *>> tensors_by_name;
331
-
332
- int64_t t_load_us = 0;
333
- int64_t t_start_us = 0;
334
-
335
- explicit llama_model(const struct llama_model_params & params);
336
- ~llama_model();
337
-
338
- void load_stats (llama_model_loader & ml);
339
- void load_arch (llama_model_loader & ml);
340
- void load_hparams(llama_model_loader & ml);
341
- void load_vocab (llama_model_loader & ml);
342
- bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
343
-
344
- std::string arch_name() const;
345
- std::string type_name() const;
346
-
347
- std::string desc() const;
348
-
349
- size_t size() const;
350
- size_t max_nodes() const;
351
- size_t n_devices() const;
352
-
353
- // total number of parameters in the model
354
- uint64_t n_elements() const;
355
-
356
- void print_info() const;
357
-
358
- lm_ggml_backend_dev_t dev_layer(int il) const;
359
- lm_ggml_backend_dev_t dev_output() const;
360
-
361
- lm_ggml_backend_buffer_type_t select_buft(int il) const;
362
-
363
- const struct lm_ggml_tensor * get_tensor(const char * name) const;
364
-
365
- private:
366
- struct impl;
367
- std::unique_ptr<impl> pimpl;
368
- };
369
-
370
- const char * llm_type_name(llm_type type);
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-arch.h"
5
+ #include "llama-graph.h"
6
+ #include "llama-hparams.h"
7
+ #include "llama-memory.h"
8
+ #include "llama-vocab.h"
9
+
10
+ #include <memory>
11
+ #include <string>
12
+ #include <unordered_map>
13
+ #include <vector>
14
+
15
+ struct llama_cparams;
16
+ struct llama_ubatch;
17
+ struct llama_model_loader;
18
+
19
+ // available models
20
+ enum llm_type {
21
+ LLM_TYPE_UNKNOWN,
22
+ LLM_TYPE_14M,
23
+ LLM_TYPE_17M,
24
+ LLM_TYPE_22M,
25
+ LLM_TYPE_33M,
26
+ LLM_TYPE_60M,
27
+ LLM_TYPE_70M,
28
+ LLM_TYPE_80M,
29
+ LLM_TYPE_109M,
30
+ LLM_TYPE_137M,
31
+ LLM_TYPE_160M,
32
+ LLM_TYPE_190M,
33
+ LLM_TYPE_220M,
34
+ LLM_TYPE_250M,
35
+ LLM_TYPE_270M,
36
+ LLM_TYPE_335M,
37
+ LLM_TYPE_410M,
38
+ LLM_TYPE_450M,
39
+ LLM_TYPE_770M,
40
+ LLM_TYPE_780M,
41
+ LLM_TYPE_0_5B,
42
+ LLM_TYPE_1B,
43
+ LLM_TYPE_1_3B,
44
+ LLM_TYPE_1_4B,
45
+ LLM_TYPE_1_5B,
46
+ LLM_TYPE_1_6B,
47
+ LLM_TYPE_1_8B,
48
+ LLM_TYPE_2B,
49
+ LLM_TYPE_2_8B,
50
+ LLM_TYPE_2_9B,
51
+ LLM_TYPE_3B,
52
+ LLM_TYPE_4B,
53
+ LLM_TYPE_6B,
54
+ LLM_TYPE_6_9B,
55
+ LLM_TYPE_7B,
56
+ LLM_TYPE_8B,
57
+ LLM_TYPE_9B,
58
+ LLM_TYPE_11B,
59
+ LLM_TYPE_12B,
60
+ LLM_TYPE_13B,
61
+ LLM_TYPE_14B,
62
+ LLM_TYPE_15B,
63
+ LLM_TYPE_16B,
64
+ LLM_TYPE_20B,
65
+ LLM_TYPE_30B,
66
+ LLM_TYPE_32B,
67
+ LLM_TYPE_34B,
68
+ LLM_TYPE_35B,
69
+ LLM_TYPE_40B,
70
+ LLM_TYPE_65B,
71
+ LLM_TYPE_70B,
72
+ LLM_TYPE_236B,
73
+ LLM_TYPE_314B,
74
+ LLM_TYPE_671B,
75
+ LLM_TYPE_SMALL,
76
+ LLM_TYPE_MEDIUM,
77
+ LLM_TYPE_LARGE,
78
+ LLM_TYPE_XL,
79
+ LLM_TYPE_A1_7B,
80
+ LLM_TYPE_A2_7B,
81
+ LLM_TYPE_8x7B,
82
+ LLM_TYPE_8x22B,
83
+ LLM_TYPE_16x12B,
84
+ LLM_TYPE_16x3_8B,
85
+ LLM_TYPE_10B_128x3_66B,
86
+ LLM_TYPE_57B_A14B,
87
+ LLM_TYPE_27B,
88
+ LLM_TYPE_290B,
89
+ LLM_TYPE_17B_16E, // llama4 Scout
90
+ LLM_TYPE_17B_128E, // llama4 Maverick
91
+ };
92
+
93
+ struct llama_layer_posnet {
94
+ // resnet
95
+ struct lm_ggml_tensor * norm1 = nullptr;
96
+ struct lm_ggml_tensor * norm1_b = nullptr;
97
+
98
+ struct lm_ggml_tensor * conv1 = nullptr;
99
+ struct lm_ggml_tensor * conv1_b = nullptr;
100
+
101
+ struct lm_ggml_tensor * norm2 = nullptr;
102
+ struct lm_ggml_tensor * norm2_b = nullptr;
103
+
104
+ struct lm_ggml_tensor * conv2 = nullptr;
105
+ struct lm_ggml_tensor * conv2_b = nullptr;
106
+
107
+ // attention
108
+ struct lm_ggml_tensor * attn_norm = nullptr;
109
+ struct lm_ggml_tensor * attn_norm_b = nullptr;
110
+
111
+ struct lm_ggml_tensor * attn_q = nullptr;
112
+ struct lm_ggml_tensor * attn_q_b = nullptr;
113
+
114
+ struct lm_ggml_tensor * attn_k = nullptr;
115
+ struct lm_ggml_tensor * attn_k_b = nullptr;
116
+
117
+ struct lm_ggml_tensor * attn_v = nullptr;
118
+ struct lm_ggml_tensor * attn_v_b = nullptr;
119
+
120
+ struct lm_ggml_tensor * attn_o = nullptr;
121
+ struct lm_ggml_tensor * attn_o_b = nullptr;
122
+
123
+ // normalize
124
+ struct lm_ggml_tensor * norm = nullptr;
125
+ struct lm_ggml_tensor * norm_b = nullptr;
126
+ };
127
+
128
+ struct llama_layer_convnext {
129
+ struct lm_ggml_tensor * dw = nullptr;
130
+ struct lm_ggml_tensor * dw_b = nullptr;
131
+
132
+ struct lm_ggml_tensor * norm = nullptr;
133
+ struct lm_ggml_tensor * norm_b = nullptr;
134
+
135
+ struct lm_ggml_tensor * pw1 = nullptr;
136
+ struct lm_ggml_tensor * pw1_b = nullptr;
137
+
138
+ struct lm_ggml_tensor * pw2 = nullptr;
139
+ struct lm_ggml_tensor * pw2_b = nullptr;
140
+
141
+ struct lm_ggml_tensor * gamma = nullptr;
142
+ };
143
+
144
+ struct llama_layer { // flat bag of per-layer tensor pointers; every pointer starts as nullptr, grouped below by the feature that uses it
145
+ // normalization (attention side)
146
+ struct lm_ggml_tensor * attn_norm = nullptr;
147
+ struct lm_ggml_tensor * attn_norm_b = nullptr;
148
+ struct lm_ggml_tensor * attn_norm_2 = nullptr;
149
+ struct lm_ggml_tensor * attn_norm_2_b = nullptr;
150
+ struct lm_ggml_tensor * attn_q_norm = nullptr;
151
+ struct lm_ggml_tensor * attn_q_norm_b = nullptr;
152
+ struct lm_ggml_tensor * attn_k_norm = nullptr;
153
+ struct lm_ggml_tensor * attn_k_norm_b = nullptr;
154
+ struct lm_ggml_tensor * attn_out_norm = nullptr;
155
+ struct lm_ggml_tensor * attn_out_norm_b = nullptr;
156
+ struct lm_ggml_tensor * attn_q_a_norm = nullptr;
157
+ struct lm_ggml_tensor * attn_kv_a_norm = nullptr;
158
+ struct lm_ggml_tensor * attn_sub_norm = nullptr;
159
+ struct lm_ggml_tensor * attn_post_norm = nullptr;
160
+ struct lm_ggml_tensor * ffn_sub_norm = nullptr;
161
+ struct lm_ggml_tensor * attn_norm_cross = nullptr;
162
+ struct lm_ggml_tensor * attn_norm_enc = nullptr;
163
+
164
+ // attention weights; `_cross` and `_enc` variants are separate tensors from the plain wq/wk/wv/wo
165
+ struct lm_ggml_tensor * wq = nullptr;
166
+ struct lm_ggml_tensor * wk = nullptr;
167
+ struct lm_ggml_tensor * wv = nullptr;
168
+ struct lm_ggml_tensor * wo = nullptr;
169
+ struct lm_ggml_tensor * wqkv = nullptr;
170
+ struct lm_ggml_tensor * wq_a = nullptr;
171
+ struct lm_ggml_tensor * wq_b = nullptr;
172
+ struct lm_ggml_tensor * wkv_a_mqa = nullptr;
173
+ struct lm_ggml_tensor * wkv_b = nullptr;
174
+ struct lm_ggml_tensor * wq_cross = nullptr;
175
+ struct lm_ggml_tensor * wk_cross = nullptr;
176
+ struct lm_ggml_tensor * wv_cross = nullptr;
177
+ struct lm_ggml_tensor * wo_cross = nullptr;
178
+ struct lm_ggml_tensor * wq_enc = nullptr;
179
+ struct lm_ggml_tensor * wk_enc = nullptr;
180
+ struct lm_ggml_tensor * wv_enc = nullptr;
181
+ struct lm_ggml_tensor * wo_enc = nullptr;
182
+
183
+ // attention bias (one per wq/wk/wv/wo/wqkv above)
184
+ struct lm_ggml_tensor * bq = nullptr;
185
+ struct lm_ggml_tensor * bk = nullptr;
186
+ struct lm_ggml_tensor * bv = nullptr;
187
+ struct lm_ggml_tensor * bo = nullptr;
188
+ struct lm_ggml_tensor * bqkv = nullptr;
189
+
190
+ // relative position bias
191
+ struct lm_ggml_tensor * attn_rel_b = nullptr;
192
+ struct lm_ggml_tensor * attn_rel_b_enc = nullptr;
193
+ struct lm_ggml_tensor * attn_rel_b_cross = nullptr;
194
+
195
+ // normalization (feed-forward / layer-output side; distinct from the attention norms at the top)
196
+ struct lm_ggml_tensor * ffn_norm = nullptr;
197
+ struct lm_ggml_tensor * ffn_norm_b = nullptr;
198
+ struct lm_ggml_tensor * ffn_post_norm = nullptr;
199
+ struct lm_ggml_tensor * layer_out_norm = nullptr;
200
+ struct lm_ggml_tensor * layer_out_norm_b = nullptr;
201
+ struct lm_ggml_tensor * ffn_norm_exps = nullptr;
202
+ struct lm_ggml_tensor * ffn_norm_enc = nullptr;
203
+
204
+ // ff (feed-forward weights; w1/w2/w3 are the short names noted per member)
205
+ struct lm_ggml_tensor * ffn_gate = nullptr; // w1
206
+ struct lm_ggml_tensor * ffn_down = nullptr; // w2
207
+ struct lm_ggml_tensor * ffn_up = nullptr; // w3
208
+ struct lm_ggml_tensor * ffn_gate_enc = nullptr;
209
+ struct lm_ggml_tensor * ffn_down_enc = nullptr;
210
+ struct lm_ggml_tensor * ffn_up_enc = nullptr;
211
+
212
+ // ff MoE (`_exps` presumably means the stacked per-expert tensors - confirm)
213
+ struct lm_ggml_tensor * ffn_gate_inp = nullptr;
214
+ struct lm_ggml_tensor * ffn_gate_exps = nullptr;
215
+ struct lm_ggml_tensor * ffn_down_exps = nullptr;
216
+ struct lm_ggml_tensor * ffn_up_exps = nullptr;
217
+
218
+ // ff shared expert (shexp)
219
+ struct lm_ggml_tensor * ffn_gate_inp_shexp = nullptr;
220
+ struct lm_ggml_tensor * ffn_gate_shexp = nullptr;
221
+ struct lm_ggml_tensor * ffn_down_shexp = nullptr;
222
+ struct lm_ggml_tensor * ffn_up_shexp = nullptr;
223
+
224
+ // ff bias (this group also carries ffn_act and ffn_exp_probs_b)
225
+ struct lm_ggml_tensor * ffn_gate_b = nullptr; // presumably b1, by analogy with w1 above - confirm
226
+ struct lm_ggml_tensor * ffn_down_b = nullptr; // b2
227
+ struct lm_ggml_tensor * ffn_up_b = nullptr; // b3
228
+ struct lm_ggml_tensor * ffn_act = nullptr;
229
+ struct lm_ggml_tensor * ffn_exp_probs_b = nullptr;
230
+
231
+ // mamba proj
232
+ struct lm_ggml_tensor * ssm_in = nullptr;
233
+ struct lm_ggml_tensor * ssm_x = nullptr;
234
+ struct lm_ggml_tensor * ssm_dt = nullptr;
235
+ struct lm_ggml_tensor * ssm_out = nullptr;
236
+
237
+ // mamba
238
+ struct lm_ggml_tensor * ssm_conv1d = nullptr;
239
+ struct lm_ggml_tensor * ssm_a = nullptr;
240
+ struct lm_ggml_tensor * ssm_d = nullptr;
241
+
242
+ // mamba bias (for ssm_conv1d and ssm_dt above)
243
+ struct lm_ggml_tensor * ssm_conv1d_b = nullptr;
244
+ struct lm_ggml_tensor * ssm_dt_b = nullptr;
245
+
246
+ // rwkv (time_mix_* tensors)
247
+ struct lm_ggml_tensor * time_mix_w1 = nullptr;
248
+ struct lm_ggml_tensor * time_mix_w2 = nullptr;
249
+ struct lm_ggml_tensor * time_mix_lerp_x = nullptr;
250
+ struct lm_ggml_tensor * time_mix_lerp_w = nullptr;
251
+ struct lm_ggml_tensor * time_mix_lerp_k = nullptr;
252
+ struct lm_ggml_tensor * time_mix_lerp_v = nullptr;
253
+ struct lm_ggml_tensor * time_mix_lerp_r = nullptr;
254
+ struct lm_ggml_tensor * time_mix_lerp_g = nullptr;
255
+ struct lm_ggml_tensor * time_mix_lerp_fused = nullptr;
256
+
257
+ struct lm_ggml_tensor * time_mix_first = nullptr;
258
+ struct lm_ggml_tensor * time_mix_decay = nullptr;
259
+ struct lm_ggml_tensor * time_mix_decay_w1 = nullptr;
260
+ struct lm_ggml_tensor * time_mix_decay_w2 = nullptr;
261
+ struct lm_ggml_tensor * time_mix_key = nullptr;
262
+ struct lm_ggml_tensor * time_mix_key_b = nullptr;
263
+ struct lm_ggml_tensor * time_mix_value = nullptr;
264
+ struct lm_ggml_tensor * time_mix_value_b = nullptr;
265
+ struct lm_ggml_tensor * time_mix_receptance = nullptr;
266
+ struct lm_ggml_tensor * time_mix_receptance_b = nullptr;
267
+ struct lm_ggml_tensor * time_mix_gate = nullptr;
268
+
269
+ // rwkv7 (additional time_mix_* tensors)
270
+ struct lm_ggml_tensor * time_mix_w0 = nullptr;
271
+ struct lm_ggml_tensor * time_mix_a0 = nullptr;
272
+ struct lm_ggml_tensor * time_mix_a1 = nullptr;
273
+ struct lm_ggml_tensor * time_mix_a2 = nullptr;
274
+ struct lm_ggml_tensor * time_mix_v0 = nullptr;
275
+ struct lm_ggml_tensor * time_mix_v1 = nullptr;
276
+ struct lm_ggml_tensor * time_mix_v2 = nullptr;
277
+ struct lm_ggml_tensor * time_mix_g1 = nullptr;
278
+ struct lm_ggml_tensor * time_mix_g2 = nullptr;
279
+ struct lm_ggml_tensor * time_mix_k_k = nullptr;
280
+ struct lm_ggml_tensor * time_mix_k_a = nullptr;
281
+ struct lm_ggml_tensor * time_mix_r_k = nullptr;
282
+
283
+ struct lm_ggml_tensor * time_mix_ln = nullptr;
284
+ struct lm_ggml_tensor * time_mix_ln_b = nullptr;
285
+ struct lm_ggml_tensor * time_mix_output = nullptr;
286
+
287
+ struct lm_ggml_tensor * channel_mix_lerp_k = nullptr;
288
+ struct lm_ggml_tensor * channel_mix_lerp_r = nullptr;
289
+
290
+ struct lm_ggml_tensor * channel_mix_key = nullptr;
291
+ struct lm_ggml_tensor * channel_mix_receptance = nullptr;
292
+ struct lm_ggml_tensor * channel_mix_value = nullptr;
293
+
294
+ // long rope factors (long / short variants plus rope_freqs)
295
+ struct lm_ggml_tensor * rope_long = nullptr;
296
+ struct lm_ggml_tensor * rope_short = nullptr;
297
+ struct lm_ggml_tensor * rope_freqs = nullptr;
298
+
299
+ // bitnet scale (one scale tensor per attention / ff weight named in the prefix)
300
+ struct lm_ggml_tensor * wq_scale = nullptr;
301
+ struct lm_ggml_tensor * wk_scale = nullptr;
302
+ struct lm_ggml_tensor * wv_scale = nullptr;
303
+ struct lm_ggml_tensor * wo_scale = nullptr;
304
+ struct lm_ggml_tensor * ffn_gate_scale = nullptr;
305
+ struct lm_ggml_tensor * ffn_up_scale = nullptr;
306
+ struct lm_ggml_tensor * ffn_down_scale = nullptr;
307
+
308
+ struct llama_layer_posnet posnet; // nested by value, not by pointer
309
+
310
+ struct llama_layer_convnext convnext; // nested by value, not by pointer
311
+ };
312
+
313
+ struct llama_model { // model state: type/arch tags, hparams, vocab, model-level tensors, the per-layer list, and loader/query methods
314
+ llm_type type = LLM_TYPE_UNKNOWN;
315
+ llm_arch arch = LLM_ARCH_UNKNOWN;
316
+
317
+ std::string name = "n/a";
318
+
319
+ llama_hparams hparams = {};
320
+ llama_vocab vocab;
321
+
322
+ struct lm_ggml_tensor * tok_embd = nullptr; // model-level (non per-layer) tensors follow; all start as nullptr
323
+ struct lm_ggml_tensor * type_embd = nullptr;
324
+ struct lm_ggml_tensor * pos_embd = nullptr;
325
+ struct lm_ggml_tensor * tok_norm = nullptr;
326
+ struct lm_ggml_tensor * tok_norm_b = nullptr;
327
+
328
+ struct lm_ggml_tensor * output_norm = nullptr;
329
+ struct lm_ggml_tensor * output_norm_b = nullptr;
330
+ struct lm_ggml_tensor * output = nullptr;
331
+ struct lm_ggml_tensor * output_b = nullptr;
332
+ struct lm_ggml_tensor * output_norm_enc = nullptr;
333
+
334
+ // classifier tensors (cls / cls_out, each with a `_b` companion)
335
+ struct lm_ggml_tensor * cls = nullptr;
336
+ struct lm_ggml_tensor * cls_b = nullptr;
337
+ struct lm_ggml_tensor * cls_out = nullptr;
338
+ struct lm_ggml_tensor * cls_out_b = nullptr;
339
+
340
+ struct lm_ggml_tensor * conv1d = nullptr;
341
+ struct lm_ggml_tensor * conv1d_b = nullptr;
342
+
343
+ std::vector<llama_layer> layers; // one llama_layer entry per element; stored by value
344
+
345
+ llama_model_params params; // no default initializer here; presumably set from the constructor argument - confirm
346
+
347
+ // gguf metadata, kept as string key -> string value pairs
348
+ std::unordered_map<std::string, std::string> lm_gguf_kv;
349
+
350
+ // list of devices used in this model
351
+ std::vector<lm_ggml_backend_dev_t> devices;
352
+
353
+ // for quantize-stats only: (name, tensor pointer) pairs
354
+ std::vector<std::pair<std::string, struct lm_ggml_tensor *>> tensors_by_name;
355
+
356
+ int64_t t_load_us = 0; // NOTE(review): `_us` suffix suggests microseconds - confirm
357
+ int64_t t_start_us = 0;
358
+
359
+ explicit llama_model(const struct llama_model_params & params); // explicit: no implicit conversion from llama_model_params
360
+ ~llama_model();
361
+
362
+ void load_stats (llama_model_loader & ml); // the load_* methods all take the loader by mutable reference
363
+ void load_arch (llama_model_loader & ml);
364
+ void load_hparams(llama_model_loader & ml);
365
+ void load_vocab (llama_model_loader & ml);
366
+ bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
367
+
368
+ std::string arch_name() const;
369
+ std::string type_name() const;
370
+
371
+ std::string desc() const;
372
+
373
+ size_t size() const;
374
+ size_t n_tensors() const;
375
+ size_t n_devices() const;
376
+
377
+ // total number of parameters in the model
378
+ uint64_t n_elements() const;
379
+
380
+ void print_info() const;
381
+
382
+ lm_ggml_backend_dev_t dev_layer(int il) const; // `il` is presumably the layer index - confirm
383
+ lm_ggml_backend_dev_t dev_output() const;
384
+
385
+ lm_ggml_backend_buffer_type_t select_buft(int il) const;
386
+
387
+ bool has_tensor_overrides() const;
388
+
389
+ const struct lm_ggml_tensor * get_tensor(const char * name) const; // read-only pointer; behavior for an unknown name is not visible here
390
+
391
+ // TODO: move this to new llm_arch_model_i interface
392
+ llama_memory_i * create_memory() const; // TODO: params
393
+
394
+ // TODO: move this to new llm_arch_model_i interface
395
+ llm_graph_result_ptr build_graph(
396
+ const llm_graph_params & params,
397
+ lm_ggml_cgraph * gf,
398
+ llm_graph_type type) const;
399
+
400
+ private:
401
+ struct impl; // forward declaration only; the definition is not in this struct
402
+ std::unique_ptr<impl> pimpl; // owning pointer to the hidden implementation
403
+ };
404
+
405
+ const char * llm_type_name(llm_type type); // maps an llm_type value to a C string (presumably a human-readable label - confirm)
406
+
407
+ // For internal test use only: returns, by const reference, a list of (name, tensor pointer) pairs for `model`
408
+ // TODO: remove (NOTE(review): presumably exposes llama_model::tensors_by_name - confirm)
409
+ const std::vector<std::pair<std::string, lm_ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);