cui-llama.rn 1.4.3 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  4. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  5. package/android/src/main/jni-utils.h +6 -0
  6. package/android/src/main/jni.cpp +289 -31
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  16. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  17. package/cpp/chat-template.hpp +529 -0
  18. package/cpp/chat.cpp +1779 -0
  19. package/cpp/chat.h +135 -0
  20. package/cpp/common.cpp +2064 -1873
  21. package/cpp/common.h +700 -699
  22. package/cpp/ggml-alloc.c +1039 -1042
  23. package/cpp/ggml-alloc.h +1 -1
  24. package/cpp/ggml-backend-impl.h +255 -255
  25. package/cpp/ggml-backend-reg.cpp +586 -582
  26. package/cpp/ggml-backend.cpp +2004 -2002
  27. package/cpp/ggml-backend.h +354 -354
  28. package/cpp/ggml-common.h +1851 -1853
  29. package/cpp/ggml-cpp.h +39 -39
  30. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  31. package/cpp/ggml-cpu-aarch64.h +8 -8
  32. package/cpp/ggml-cpu-impl.h +531 -386
  33. package/cpp/ggml-cpu-quants.c +12527 -10920
  34. package/cpp/ggml-cpu-traits.cpp +36 -36
  35. package/cpp/ggml-cpu-traits.h +38 -38
  36. package/cpp/ggml-cpu.c +15766 -14391
  37. package/cpp/ggml-cpu.cpp +655 -635
  38. package/cpp/ggml-cpu.h +138 -135
  39. package/cpp/ggml-impl.h +567 -567
  40. package/cpp/ggml-metal-impl.h +235 -0
  41. package/cpp/ggml-metal.h +1 -1
  42. package/cpp/ggml-metal.m +5146 -4884
  43. package/cpp/ggml-opt.cpp +854 -854
  44. package/cpp/ggml-opt.h +216 -216
  45. package/cpp/ggml-quants.c +5238 -5238
  46. package/cpp/ggml-threading.h +14 -14
  47. package/cpp/ggml.c +6529 -6514
  48. package/cpp/ggml.h +2198 -2194
  49. package/cpp/gguf.cpp +1329 -1329
  50. package/cpp/gguf.h +202 -202
  51. package/cpp/json-schema-to-grammar.cpp +1024 -1045
  52. package/cpp/json-schema-to-grammar.h +21 -8
  53. package/cpp/json.hpp +24766 -24766
  54. package/cpp/llama-adapter.cpp +347 -347
  55. package/cpp/llama-adapter.h +74 -74
  56. package/cpp/llama-arch.cpp +1513 -1487
  57. package/cpp/llama-arch.h +403 -400
  58. package/cpp/llama-batch.cpp +368 -368
  59. package/cpp/llama-batch.h +88 -88
  60. package/cpp/llama-chat.cpp +588 -578
  61. package/cpp/llama-chat.h +53 -52
  62. package/cpp/llama-context.cpp +1775 -1775
  63. package/cpp/llama-context.h +128 -128
  64. package/cpp/llama-cparams.cpp +1 -1
  65. package/cpp/llama-cparams.h +37 -37
  66. package/cpp/llama-cpp.h +30 -30
  67. package/cpp/llama-grammar.cpp +1219 -1139
  68. package/cpp/llama-grammar.h +173 -143
  69. package/cpp/llama-hparams.cpp +71 -71
  70. package/cpp/llama-hparams.h +139 -139
  71. package/cpp/llama-impl.cpp +167 -167
  72. package/cpp/llama-impl.h +61 -61
  73. package/cpp/llama-kv-cache.cpp +718 -718
  74. package/cpp/llama-kv-cache.h +219 -218
  75. package/cpp/llama-mmap.cpp +600 -590
  76. package/cpp/llama-mmap.h +68 -67
  77. package/cpp/llama-model-loader.cpp +1124 -1124
  78. package/cpp/llama-model-loader.h +167 -167
  79. package/cpp/llama-model.cpp +4087 -3997
  80. package/cpp/llama-model.h +370 -370
  81. package/cpp/llama-sampling.cpp +2558 -2408
  82. package/cpp/llama-sampling.h +32 -32
  83. package/cpp/llama-vocab.cpp +3264 -3247
  84. package/cpp/llama-vocab.h +125 -125
  85. package/cpp/llama.cpp +10284 -10077
  86. package/cpp/llama.h +1354 -1323
  87. package/cpp/log.cpp +393 -401
  88. package/cpp/log.h +132 -121
  89. package/cpp/minja/chat-template.hpp +529 -0
  90. package/cpp/minja/minja.hpp +2915 -0
  91. package/cpp/minja.hpp +2915 -0
  92. package/cpp/rn-llama.cpp +66 -6
  93. package/cpp/rn-llama.h +26 -1
  94. package/cpp/sampling.cpp +570 -505
  95. package/cpp/sampling.h +3 -0
  96. package/cpp/sgemm.cpp +2598 -2597
  97. package/cpp/sgemm.h +14 -14
  98. package/cpp/speculative.cpp +278 -277
  99. package/cpp/speculative.h +28 -28
  100. package/cpp/unicode.cpp +9 -2
  101. package/ios/CMakeLists.txt +6 -0
  102. package/ios/RNLlama.h +0 -8
  103. package/ios/RNLlama.mm +27 -3
  104. package/ios/RNLlamaContext.h +10 -1
  105. package/ios/RNLlamaContext.mm +269 -57
  106. package/jest/mock.js +21 -2
  107. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  108. package/lib/commonjs/grammar.js +3 -0
  109. package/lib/commonjs/grammar.js.map +1 -1
  110. package/lib/commonjs/index.js +87 -13
  111. package/lib/commonjs/index.js.map +1 -1
  112. package/lib/module/NativeRNLlama.js.map +1 -1
  113. package/lib/module/grammar.js +3 -0
  114. package/lib/module/grammar.js.map +1 -1
  115. package/lib/module/index.js +86 -13
  116. package/lib/module/index.js.map +1 -1
  117. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  118. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  119. package/lib/typescript/grammar.d.ts.map +1 -1
  120. package/lib/typescript/index.d.ts +32 -7
  121. package/lib/typescript/index.d.ts.map +1 -1
  122. package/llama-rn.podspec +1 -1
  123. package/package.json +3 -2
  124. package/src/NativeRNLlama.ts +115 -3
  125. package/src/grammar.ts +3 -0
  126. package/src/index.ts +138 -21
  127. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  128. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  129. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  130. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  132. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -55
  134. package/cpp/rn-llama.hpp +0 -913
package/cpp/llama-model.h CHANGED
@@ -1,370 +1,370 @@
1
- #pragma once
2
-
3
- #include "llama.h"
4
- #include "llama-arch.h"
5
- #include "llama-hparams.h"
6
- #include "llama-vocab.h"
7
-
8
- #include <memory>
9
- #include <string>
10
- #include <unordered_map>
11
- #include <vector>
12
-
13
- struct llama_model_loader;
14
-
15
- // available models
16
- enum llm_type {
17
- LLM_TYPE_UNKNOWN,
18
- LLM_TYPE_14M,
19
- LLM_TYPE_17M,
20
- LLM_TYPE_22M,
21
- LLM_TYPE_33M,
22
- LLM_TYPE_60M,
23
- LLM_TYPE_70M,
24
- LLM_TYPE_80M,
25
- LLM_TYPE_109M,
26
- LLM_TYPE_137M,
27
- LLM_TYPE_160M,
28
- LLM_TYPE_220M,
29
- LLM_TYPE_250M,
30
- LLM_TYPE_270M,
31
- LLM_TYPE_335M,
32
- LLM_TYPE_410M,
33
- LLM_TYPE_450M,
34
- LLM_TYPE_770M,
35
- LLM_TYPE_780M,
36
- LLM_TYPE_0_5B,
37
- LLM_TYPE_1B,
38
- LLM_TYPE_1_3B,
39
- LLM_TYPE_1_4B,
40
- LLM_TYPE_1_5B,
41
- LLM_TYPE_1_6B,
42
- LLM_TYPE_2B,
43
- LLM_TYPE_2_8B,
44
- LLM_TYPE_3B,
45
- LLM_TYPE_4B,
46
- LLM_TYPE_6B,
47
- LLM_TYPE_6_9B,
48
- LLM_TYPE_7B,
49
- LLM_TYPE_8B,
50
- LLM_TYPE_9B,
51
- LLM_TYPE_11B,
52
- LLM_TYPE_12B,
53
- LLM_TYPE_13B,
54
- LLM_TYPE_14B,
55
- LLM_TYPE_15B,
56
- LLM_TYPE_16B,
57
- LLM_TYPE_20B,
58
- LLM_TYPE_30B,
59
- LLM_TYPE_32B,
60
- LLM_TYPE_34B,
61
- LLM_TYPE_35B,
62
- LLM_TYPE_40B,
63
- LLM_TYPE_65B,
64
- LLM_TYPE_70B,
65
- LLM_TYPE_236B,
66
- LLM_TYPE_314B,
67
- LLM_TYPE_671B,
68
- LLM_TYPE_SMALL,
69
- LLM_TYPE_MEDIUM,
70
- LLM_TYPE_LARGE,
71
- LLM_TYPE_XL,
72
- LLM_TYPE_A1_7B,
73
- LLM_TYPE_A2_7B,
74
- LLM_TYPE_8x7B,
75
- LLM_TYPE_8x22B,
76
- LLM_TYPE_16x12B,
77
- LLM_TYPE_16x3_8B,
78
- LLM_TYPE_10B_128x3_66B,
79
- LLM_TYPE_57B_A14B,
80
- LLM_TYPE_27B,
81
- };
82
-
83
- struct llama_layer_posnet {
84
- // resnet
85
- struct lm_ggml_tensor * norm1 = nullptr;
86
- struct lm_ggml_tensor * norm1_b = nullptr;
87
-
88
- struct lm_ggml_tensor * conv1 = nullptr;
89
- struct lm_ggml_tensor * conv1_b = nullptr;
90
-
91
- struct lm_ggml_tensor * norm2 = nullptr;
92
- struct lm_ggml_tensor * norm2_b = nullptr;
93
-
94
- struct lm_ggml_tensor * conv2 = nullptr;
95
- struct lm_ggml_tensor * conv2_b = nullptr;
96
-
97
- // attention
98
- struct lm_ggml_tensor * attn_norm = nullptr;
99
- struct lm_ggml_tensor * attn_norm_b = nullptr;
100
-
101
- struct lm_ggml_tensor * attn_q = nullptr;
102
- struct lm_ggml_tensor * attn_q_b = nullptr;
103
-
104
- struct lm_ggml_tensor * attn_k = nullptr;
105
- struct lm_ggml_tensor * attn_k_b = nullptr;
106
-
107
- struct lm_ggml_tensor * attn_v = nullptr;
108
- struct lm_ggml_tensor * attn_v_b = nullptr;
109
-
110
- struct lm_ggml_tensor * attn_o = nullptr;
111
- struct lm_ggml_tensor * attn_o_b = nullptr;
112
-
113
- // normalize
114
- struct lm_ggml_tensor * norm = nullptr;
115
- struct lm_ggml_tensor * norm_b = nullptr;
116
- };
117
-
118
- struct llama_layer_convnext {
119
- struct lm_ggml_tensor * dw = nullptr;
120
- struct lm_ggml_tensor * dw_b = nullptr;
121
-
122
- struct lm_ggml_tensor * norm = nullptr;
123
- struct lm_ggml_tensor * norm_b = nullptr;
124
-
125
- struct lm_ggml_tensor * pw1 = nullptr;
126
- struct lm_ggml_tensor * pw1_b = nullptr;
127
-
128
- struct lm_ggml_tensor * pw2 = nullptr;
129
- struct lm_ggml_tensor * pw2_b = nullptr;
130
-
131
- struct lm_ggml_tensor * gamma = nullptr;
132
- };
133
-
134
- struct llama_layer {
135
- // normalization
136
- struct lm_ggml_tensor * attn_norm = nullptr;
137
- struct lm_ggml_tensor * attn_norm_b = nullptr;
138
- struct lm_ggml_tensor * attn_norm_2 = nullptr;
139
- struct lm_ggml_tensor * attn_norm_2_b = nullptr;
140
- struct lm_ggml_tensor * attn_q_norm = nullptr;
141
- struct lm_ggml_tensor * attn_q_norm_b = nullptr;
142
- struct lm_ggml_tensor * attn_k_norm = nullptr;
143
- struct lm_ggml_tensor * attn_k_norm_b = nullptr;
144
- struct lm_ggml_tensor * attn_out_norm = nullptr;
145
- struct lm_ggml_tensor * attn_out_norm_b = nullptr;
146
- struct lm_ggml_tensor * attn_q_a_norm = nullptr;
147
- struct lm_ggml_tensor * attn_kv_a_norm = nullptr;
148
- struct lm_ggml_tensor * attn_sub_norm = nullptr;
149
- struct lm_ggml_tensor * attn_post_norm = nullptr;
150
- struct lm_ggml_tensor * ffn_sub_norm = nullptr;
151
- struct lm_ggml_tensor * attn_norm_cross = nullptr;
152
- struct lm_ggml_tensor * attn_norm_enc = nullptr;
153
-
154
- // attention
155
- struct lm_ggml_tensor * wq = nullptr;
156
- struct lm_ggml_tensor * wk = nullptr;
157
- struct lm_ggml_tensor * wv = nullptr;
158
- struct lm_ggml_tensor * wo = nullptr;
159
- struct lm_ggml_tensor * wqkv = nullptr;
160
- struct lm_ggml_tensor * wq_a = nullptr;
161
- struct lm_ggml_tensor * wq_b = nullptr;
162
- struct lm_ggml_tensor * wkv_a_mqa = nullptr;
163
- struct lm_ggml_tensor * wkv_b = nullptr;
164
- struct lm_ggml_tensor * wq_cross = nullptr;
165
- struct lm_ggml_tensor * wk_cross = nullptr;
166
- struct lm_ggml_tensor * wv_cross = nullptr;
167
- struct lm_ggml_tensor * wo_cross = nullptr;
168
- struct lm_ggml_tensor * wq_enc = nullptr;
169
- struct lm_ggml_tensor * wk_enc = nullptr;
170
- struct lm_ggml_tensor * wv_enc = nullptr;
171
- struct lm_ggml_tensor * wo_enc = nullptr;
172
-
173
- // attention bias
174
- struct lm_ggml_tensor * bq = nullptr;
175
- struct lm_ggml_tensor * bk = nullptr;
176
- struct lm_ggml_tensor * bv = nullptr;
177
- struct lm_ggml_tensor * bo = nullptr;
178
- struct lm_ggml_tensor * bqkv = nullptr;
179
-
180
- // relative position bias
181
- struct lm_ggml_tensor * attn_rel_b = nullptr;
182
- struct lm_ggml_tensor * attn_rel_b_enc = nullptr;
183
- struct lm_ggml_tensor * attn_rel_b_cross = nullptr;
184
-
185
- // normalization
186
- struct lm_ggml_tensor * ffn_norm = nullptr;
187
- struct lm_ggml_tensor * ffn_norm_b = nullptr;
188
- struct lm_ggml_tensor * ffn_post_norm = nullptr;
189
- struct lm_ggml_tensor * layer_out_norm = nullptr;
190
- struct lm_ggml_tensor * layer_out_norm_b = nullptr;
191
- struct lm_ggml_tensor * ffn_norm_exps = nullptr;
192
- struct lm_ggml_tensor * ffn_norm_enc = nullptr;
193
-
194
- // ff
195
- struct lm_ggml_tensor * ffn_gate = nullptr; // w1
196
- struct lm_ggml_tensor * ffn_down = nullptr; // w2
197
- struct lm_ggml_tensor * ffn_up = nullptr; // w3
198
- struct lm_ggml_tensor * ffn_gate_enc = nullptr;
199
- struct lm_ggml_tensor * ffn_down_enc = nullptr;
200
- struct lm_ggml_tensor * ffn_up_enc = nullptr;
201
-
202
- // ff MoE
203
- struct lm_ggml_tensor * ffn_gate_inp = nullptr;
204
- struct lm_ggml_tensor * ffn_gate_exps = nullptr;
205
- struct lm_ggml_tensor * ffn_down_exps = nullptr;
206
- struct lm_ggml_tensor * ffn_up_exps = nullptr;
207
-
208
- // ff shared expert (shexp)
209
- struct lm_ggml_tensor * ffn_gate_inp_shexp = nullptr;
210
- struct lm_ggml_tensor * ffn_gate_shexp = nullptr;
211
- struct lm_ggml_tensor * ffn_down_shexp = nullptr;
212
- struct lm_ggml_tensor * ffn_up_shexp = nullptr;
213
-
214
- // ff bias
215
- struct lm_ggml_tensor * ffn_gate_b = nullptr;
216
- struct lm_ggml_tensor * ffn_down_b = nullptr; // b2
217
- struct lm_ggml_tensor * ffn_up_b = nullptr; // b3
218
- struct lm_ggml_tensor * ffn_act = nullptr;
219
- struct lm_ggml_tensor * ffn_exp_probs_b = nullptr;
220
-
221
- // mamba proj
222
- struct lm_ggml_tensor * ssm_in = nullptr;
223
- struct lm_ggml_tensor * ssm_x = nullptr;
224
- struct lm_ggml_tensor * ssm_dt = nullptr;
225
- struct lm_ggml_tensor * ssm_out = nullptr;
226
-
227
- // mamba
228
- struct lm_ggml_tensor * ssm_conv1d = nullptr;
229
- struct lm_ggml_tensor * ssm_a = nullptr;
230
- struct lm_ggml_tensor * ssm_d = nullptr;
231
-
232
- // mamba bias
233
- struct lm_ggml_tensor * ssm_conv1d_b = nullptr;
234
- struct lm_ggml_tensor * ssm_dt_b = nullptr;
235
-
236
- // rwkv
237
- struct lm_ggml_tensor * time_mix_w1 = nullptr;
238
- struct lm_ggml_tensor * time_mix_w2 = nullptr;
239
- struct lm_ggml_tensor * time_mix_lerp_x = nullptr;
240
- struct lm_ggml_tensor * time_mix_lerp_w = nullptr;
241
- struct lm_ggml_tensor * time_mix_lerp_k = nullptr;
242
- struct lm_ggml_tensor * time_mix_lerp_v = nullptr;
243
- struct lm_ggml_tensor * time_mix_lerp_r = nullptr;
244
- struct lm_ggml_tensor * time_mix_lerp_g = nullptr;
245
- struct lm_ggml_tensor * time_mix_lerp_fused = nullptr;
246
-
247
- struct lm_ggml_tensor * time_mix_first = nullptr;
248
- struct lm_ggml_tensor * time_mix_decay = nullptr;
249
- struct lm_ggml_tensor * time_mix_decay_w1 = nullptr;
250
- struct lm_ggml_tensor * time_mix_decay_w2 = nullptr;
251
- struct lm_ggml_tensor * time_mix_key = nullptr;
252
- struct lm_ggml_tensor * time_mix_key_b = nullptr;
253
- struct lm_ggml_tensor * time_mix_value = nullptr;
254
- struct lm_ggml_tensor * time_mix_value_b = nullptr;
255
- struct lm_ggml_tensor * time_mix_receptance = nullptr;
256
- struct lm_ggml_tensor * time_mix_receptance_b = nullptr;
257
- struct lm_ggml_tensor * time_mix_gate = nullptr;
258
-
259
- struct lm_ggml_tensor * time_mix_ln = nullptr;
260
- struct lm_ggml_tensor * time_mix_ln_b = nullptr;
261
- struct lm_ggml_tensor * time_mix_output = nullptr;
262
-
263
- struct lm_ggml_tensor * channel_mix_lerp_k = nullptr;
264
- struct lm_ggml_tensor * channel_mix_lerp_r = nullptr;
265
-
266
- struct lm_ggml_tensor * channel_mix_key = nullptr;
267
- struct lm_ggml_tensor * channel_mix_receptance = nullptr;
268
- struct lm_ggml_tensor * channel_mix_value = nullptr;
269
-
270
- // long rope factors
271
- struct lm_ggml_tensor * rope_long = nullptr;
272
- struct lm_ggml_tensor * rope_short = nullptr;
273
- struct lm_ggml_tensor * rope_freqs = nullptr;
274
-
275
- // bitnet scale
276
- struct lm_ggml_tensor * wq_scale = nullptr;
277
- struct lm_ggml_tensor * wk_scale = nullptr;
278
- struct lm_ggml_tensor * wv_scale = nullptr;
279
- struct lm_ggml_tensor * wo_scale = nullptr;
280
- struct lm_ggml_tensor * ffn_gate_scale = nullptr;
281
- struct lm_ggml_tensor * ffn_up_scale = nullptr;
282
- struct lm_ggml_tensor * ffn_down_scale = nullptr;
283
-
284
- struct llama_layer_posnet posnet;
285
-
286
- struct llama_layer_convnext convnext;
287
- };
288
-
289
- struct llama_model {
290
- llm_type type = LLM_TYPE_UNKNOWN;
291
- llm_arch arch = LLM_ARCH_UNKNOWN;
292
-
293
- std::string name = "n/a";
294
-
295
- llama_hparams hparams = {};
296
- llama_vocab vocab;
297
-
298
- struct lm_ggml_tensor * tok_embd = nullptr;
299
- struct lm_ggml_tensor * type_embd = nullptr;
300
- struct lm_ggml_tensor * pos_embd = nullptr;
301
- struct lm_ggml_tensor * tok_norm = nullptr;
302
- struct lm_ggml_tensor * tok_norm_b = nullptr;
303
-
304
- struct lm_ggml_tensor * output_norm = nullptr;
305
- struct lm_ggml_tensor * output_norm_b = nullptr;
306
- struct lm_ggml_tensor * output = nullptr;
307
- struct lm_ggml_tensor * output_b = nullptr;
308
- struct lm_ggml_tensor * output_norm_enc = nullptr;
309
-
310
- // classifier
311
- struct lm_ggml_tensor * cls = nullptr;
312
- struct lm_ggml_tensor * cls_b = nullptr;
313
- struct lm_ggml_tensor * cls_out = nullptr;
314
- struct lm_ggml_tensor * cls_out_b = nullptr;
315
-
316
- struct lm_ggml_tensor * conv1d = nullptr;
317
- struct lm_ggml_tensor * conv1d_b = nullptr;
318
-
319
- std::vector<llama_layer> layers;
320
-
321
- llama_model_params params;
322
-
323
- // gguf metadata
324
- std::unordered_map<std::string, std::string> lm_gguf_kv;
325
-
326
- // list of devices used in this model
327
- std::vector<lm_ggml_backend_dev_t> devices;
328
-
329
- // for quantize-stats only
330
- std::vector<std::pair<std::string, struct lm_ggml_tensor *>> tensors_by_name;
331
-
332
- int64_t t_load_us = 0;
333
- int64_t t_start_us = 0;
334
-
335
- explicit llama_model(const struct llama_model_params & params);
336
- ~llama_model();
337
-
338
- void load_stats (llama_model_loader & ml);
339
- void load_arch (llama_model_loader & ml);
340
- void load_hparams(llama_model_loader & ml);
341
- void load_vocab (llama_model_loader & ml);
342
- bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
343
-
344
- std::string arch_name() const;
345
- std::string type_name() const;
346
-
347
- std::string desc() const;
348
-
349
- size_t size() const;
350
- size_t max_nodes() const;
351
- size_t n_devices() const;
352
-
353
- // total number of parameters in the model
354
- uint64_t n_elements() const;
355
-
356
- void print_info() const;
357
-
358
- lm_ggml_backend_dev_t dev_layer(int il) const;
359
- lm_ggml_backend_dev_t dev_output() const;
360
-
361
- lm_ggml_backend_buffer_type_t select_buft(int il) const;
362
-
363
- const struct lm_ggml_tensor * get_tensor(const char * name) const;
364
-
365
- private:
366
- struct impl;
367
- std::unique_ptr<impl> pimpl;
368
- };
369
-
370
- const char * llm_type_name(llm_type type);
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-arch.h"
5
+ #include "llama-hparams.h"
6
+ #include "llama-vocab.h"
7
+
8
+ #include <memory>
9
+ #include <string>
10
+ #include <unordered_map>
11
+ #include <vector>
12
+
13
+ struct llama_model_loader;
14
+
15
+ // available models
16
+ enum llm_type {
17
+ LLM_TYPE_UNKNOWN,
18
+ LLM_TYPE_14M,
19
+ LLM_TYPE_17M,
20
+ LLM_TYPE_22M,
21
+ LLM_TYPE_33M,
22
+ LLM_TYPE_60M,
23
+ LLM_TYPE_70M,
24
+ LLM_TYPE_80M,
25
+ LLM_TYPE_109M,
26
+ LLM_TYPE_137M,
27
+ LLM_TYPE_160M,
28
+ LLM_TYPE_220M,
29
+ LLM_TYPE_250M,
30
+ LLM_TYPE_270M,
31
+ LLM_TYPE_335M,
32
+ LLM_TYPE_410M,
33
+ LLM_TYPE_450M,
34
+ LLM_TYPE_770M,
35
+ LLM_TYPE_780M,
36
+ LLM_TYPE_0_5B,
37
+ LLM_TYPE_1B,
38
+ LLM_TYPE_1_3B,
39
+ LLM_TYPE_1_4B,
40
+ LLM_TYPE_1_5B,
41
+ LLM_TYPE_1_6B,
42
+ LLM_TYPE_2B,
43
+ LLM_TYPE_2_8B,
44
+ LLM_TYPE_3B,
45
+ LLM_TYPE_4B,
46
+ LLM_TYPE_6B,
47
+ LLM_TYPE_6_9B,
48
+ LLM_TYPE_7B,
49
+ LLM_TYPE_8B,
50
+ LLM_TYPE_9B,
51
+ LLM_TYPE_11B,
52
+ LLM_TYPE_12B,
53
+ LLM_TYPE_13B,
54
+ LLM_TYPE_14B,
55
+ LLM_TYPE_15B,
56
+ LLM_TYPE_16B,
57
+ LLM_TYPE_20B,
58
+ LLM_TYPE_30B,
59
+ LLM_TYPE_32B,
60
+ LLM_TYPE_34B,
61
+ LLM_TYPE_35B,
62
+ LLM_TYPE_40B,
63
+ LLM_TYPE_65B,
64
+ LLM_TYPE_70B,
65
+ LLM_TYPE_236B,
66
+ LLM_TYPE_314B,
67
+ LLM_TYPE_671B,
68
+ LLM_TYPE_SMALL,
69
+ LLM_TYPE_MEDIUM,
70
+ LLM_TYPE_LARGE,
71
+ LLM_TYPE_XL,
72
+ LLM_TYPE_A1_7B,
73
+ LLM_TYPE_A2_7B,
74
+ LLM_TYPE_8x7B,
75
+ LLM_TYPE_8x22B,
76
+ LLM_TYPE_16x12B,
77
+ LLM_TYPE_16x3_8B,
78
+ LLM_TYPE_10B_128x3_66B,
79
+ LLM_TYPE_57B_A14B,
80
+ LLM_TYPE_27B,
81
+ };
82
+
83
+ struct llama_layer_posnet {
84
+ // resnet
85
+ struct lm_ggml_tensor * norm1 = nullptr;
86
+ struct lm_ggml_tensor * norm1_b = nullptr;
87
+
88
+ struct lm_ggml_tensor * conv1 = nullptr;
89
+ struct lm_ggml_tensor * conv1_b = nullptr;
90
+
91
+ struct lm_ggml_tensor * norm2 = nullptr;
92
+ struct lm_ggml_tensor * norm2_b = nullptr;
93
+
94
+ struct lm_ggml_tensor * conv2 = nullptr;
95
+ struct lm_ggml_tensor * conv2_b = nullptr;
96
+
97
+ // attention
98
+ struct lm_ggml_tensor * attn_norm = nullptr;
99
+ struct lm_ggml_tensor * attn_norm_b = nullptr;
100
+
101
+ struct lm_ggml_tensor * attn_q = nullptr;
102
+ struct lm_ggml_tensor * attn_q_b = nullptr;
103
+
104
+ struct lm_ggml_tensor * attn_k = nullptr;
105
+ struct lm_ggml_tensor * attn_k_b = nullptr;
106
+
107
+ struct lm_ggml_tensor * attn_v = nullptr;
108
+ struct lm_ggml_tensor * attn_v_b = nullptr;
109
+
110
+ struct lm_ggml_tensor * attn_o = nullptr;
111
+ struct lm_ggml_tensor * attn_o_b = nullptr;
112
+
113
+ // normalize
114
+ struct lm_ggml_tensor * norm = nullptr;
115
+ struct lm_ggml_tensor * norm_b = nullptr;
116
+ };
117
+
118
+ struct llama_layer_convnext {
119
+ struct lm_ggml_tensor * dw = nullptr;
120
+ struct lm_ggml_tensor * dw_b = nullptr;
121
+
122
+ struct lm_ggml_tensor * norm = nullptr;
123
+ struct lm_ggml_tensor * norm_b = nullptr;
124
+
125
+ struct lm_ggml_tensor * pw1 = nullptr;
126
+ struct lm_ggml_tensor * pw1_b = nullptr;
127
+
128
+ struct lm_ggml_tensor * pw2 = nullptr;
129
+ struct lm_ggml_tensor * pw2_b = nullptr;
130
+
131
+ struct lm_ggml_tensor * gamma = nullptr;
132
+ };
133
+
134
+ struct llama_layer {
135
+ // normalization
136
+ struct lm_ggml_tensor * attn_norm = nullptr;
137
+ struct lm_ggml_tensor * attn_norm_b = nullptr;
138
+ struct lm_ggml_tensor * attn_norm_2 = nullptr;
139
+ struct lm_ggml_tensor * attn_norm_2_b = nullptr;
140
+ struct lm_ggml_tensor * attn_q_norm = nullptr;
141
+ struct lm_ggml_tensor * attn_q_norm_b = nullptr;
142
+ struct lm_ggml_tensor * attn_k_norm = nullptr;
143
+ struct lm_ggml_tensor * attn_k_norm_b = nullptr;
144
+ struct lm_ggml_tensor * attn_out_norm = nullptr;
145
+ struct lm_ggml_tensor * attn_out_norm_b = nullptr;
146
+ struct lm_ggml_tensor * attn_q_a_norm = nullptr;
147
+ struct lm_ggml_tensor * attn_kv_a_norm = nullptr;
148
+ struct lm_ggml_tensor * attn_sub_norm = nullptr;
149
+ struct lm_ggml_tensor * attn_post_norm = nullptr;
150
+ struct lm_ggml_tensor * ffn_sub_norm = nullptr;
151
+ struct lm_ggml_tensor * attn_norm_cross = nullptr;
152
+ struct lm_ggml_tensor * attn_norm_enc = nullptr;
153
+
154
+ // attention
155
+ struct lm_ggml_tensor * wq = nullptr;
156
+ struct lm_ggml_tensor * wk = nullptr;
157
+ struct lm_ggml_tensor * wv = nullptr;
158
+ struct lm_ggml_tensor * wo = nullptr;
159
+ struct lm_ggml_tensor * wqkv = nullptr;
160
+ struct lm_ggml_tensor * wq_a = nullptr;
161
+ struct lm_ggml_tensor * wq_b = nullptr;
162
+ struct lm_ggml_tensor * wkv_a_mqa = nullptr;
163
+ struct lm_ggml_tensor * wkv_b = nullptr;
164
+ struct lm_ggml_tensor * wq_cross = nullptr;
165
+ struct lm_ggml_tensor * wk_cross = nullptr;
166
+ struct lm_ggml_tensor * wv_cross = nullptr;
167
+ struct lm_ggml_tensor * wo_cross = nullptr;
168
+ struct lm_ggml_tensor * wq_enc = nullptr;
169
+ struct lm_ggml_tensor * wk_enc = nullptr;
170
+ struct lm_ggml_tensor * wv_enc = nullptr;
171
+ struct lm_ggml_tensor * wo_enc = nullptr;
172
+
173
+ // attention bias
174
+ struct lm_ggml_tensor * bq = nullptr;
175
+ struct lm_ggml_tensor * bk = nullptr;
176
+ struct lm_ggml_tensor * bv = nullptr;
177
+ struct lm_ggml_tensor * bo = nullptr;
178
+ struct lm_ggml_tensor * bqkv = nullptr;
179
+
180
+ // relative position bias
181
+ struct lm_ggml_tensor * attn_rel_b = nullptr;
182
+ struct lm_ggml_tensor * attn_rel_b_enc = nullptr;
183
+ struct lm_ggml_tensor * attn_rel_b_cross = nullptr;
184
+
185
+ // normalization
186
+ struct lm_ggml_tensor * ffn_norm = nullptr;
187
+ struct lm_ggml_tensor * ffn_norm_b = nullptr;
188
+ struct lm_ggml_tensor * ffn_post_norm = nullptr;
189
+ struct lm_ggml_tensor * layer_out_norm = nullptr;
190
+ struct lm_ggml_tensor * layer_out_norm_b = nullptr;
191
+ struct lm_ggml_tensor * ffn_norm_exps = nullptr;
192
+ struct lm_ggml_tensor * ffn_norm_enc = nullptr;
193
+
194
+ // ff
195
+ struct lm_ggml_tensor * ffn_gate = nullptr; // w1
196
+ struct lm_ggml_tensor * ffn_down = nullptr; // w2
197
+ struct lm_ggml_tensor * ffn_up = nullptr; // w3
198
+ struct lm_ggml_tensor * ffn_gate_enc = nullptr;
199
+ struct lm_ggml_tensor * ffn_down_enc = nullptr;
200
+ struct lm_ggml_tensor * ffn_up_enc = nullptr;
201
+
202
+ // ff MoE
203
+ struct lm_ggml_tensor * ffn_gate_inp = nullptr;
204
+ struct lm_ggml_tensor * ffn_gate_exps = nullptr;
205
+ struct lm_ggml_tensor * ffn_down_exps = nullptr;
206
+ struct lm_ggml_tensor * ffn_up_exps = nullptr;
207
+
208
+ // ff shared expert (shexp)
209
+ struct lm_ggml_tensor * ffn_gate_inp_shexp = nullptr;
210
+ struct lm_ggml_tensor * ffn_gate_shexp = nullptr;
211
+ struct lm_ggml_tensor * ffn_down_shexp = nullptr;
212
+ struct lm_ggml_tensor * ffn_up_shexp = nullptr;
213
+
214
+ // ff bias
215
+ struct lm_ggml_tensor * ffn_gate_b = nullptr;
216
+ struct lm_ggml_tensor * ffn_down_b = nullptr; // b2
217
+ struct lm_ggml_tensor * ffn_up_b = nullptr; // b3
218
+ struct lm_ggml_tensor * ffn_act = nullptr;
219
+ struct lm_ggml_tensor * ffn_exp_probs_b = nullptr;
220
+
221
+ // mamba proj
222
+ struct lm_ggml_tensor * ssm_in = nullptr;
223
+ struct lm_ggml_tensor * ssm_x = nullptr;
224
+ struct lm_ggml_tensor * ssm_dt = nullptr;
225
+ struct lm_ggml_tensor * ssm_out = nullptr;
226
+
227
+ // mamba
228
+ struct lm_ggml_tensor * ssm_conv1d = nullptr;
229
+ struct lm_ggml_tensor * ssm_a = nullptr;
230
+ struct lm_ggml_tensor * ssm_d = nullptr;
231
+
232
+ // mamba bias
233
+ struct lm_ggml_tensor * ssm_conv1d_b = nullptr;
234
+ struct lm_ggml_tensor * ssm_dt_b = nullptr;
235
+
236
+ // rwkv
237
+ struct lm_ggml_tensor * time_mix_w1 = nullptr;
238
+ struct lm_ggml_tensor * time_mix_w2 = nullptr;
239
+ struct lm_ggml_tensor * time_mix_lerp_x = nullptr;
240
+ struct lm_ggml_tensor * time_mix_lerp_w = nullptr;
241
+ struct lm_ggml_tensor * time_mix_lerp_k = nullptr;
242
+ struct lm_ggml_tensor * time_mix_lerp_v = nullptr;
243
+ struct lm_ggml_tensor * time_mix_lerp_r = nullptr;
244
+ struct lm_ggml_tensor * time_mix_lerp_g = nullptr;
245
+ struct lm_ggml_tensor * time_mix_lerp_fused = nullptr;
246
+
247
+ struct lm_ggml_tensor * time_mix_first = nullptr;
248
+ struct lm_ggml_tensor * time_mix_decay = nullptr;
249
+ struct lm_ggml_tensor * time_mix_decay_w1 = nullptr;
250
+ struct lm_ggml_tensor * time_mix_decay_w2 = nullptr;
251
+ struct lm_ggml_tensor * time_mix_key = nullptr;
252
+ struct lm_ggml_tensor * time_mix_key_b = nullptr;
253
+ struct lm_ggml_tensor * time_mix_value = nullptr;
254
+ struct lm_ggml_tensor * time_mix_value_b = nullptr;
255
+ struct lm_ggml_tensor * time_mix_receptance = nullptr;
256
+ struct lm_ggml_tensor * time_mix_receptance_b = nullptr;
257
+ struct lm_ggml_tensor * time_mix_gate = nullptr;
258
+
259
+ struct lm_ggml_tensor * time_mix_ln = nullptr;
260
+ struct lm_ggml_tensor * time_mix_ln_b = nullptr;
261
+ struct lm_ggml_tensor * time_mix_output = nullptr;
262
+
263
+ struct lm_ggml_tensor * channel_mix_lerp_k = nullptr;
264
+ struct lm_ggml_tensor * channel_mix_lerp_r = nullptr;
265
+
266
+ struct lm_ggml_tensor * channel_mix_key = nullptr;
267
+ struct lm_ggml_tensor * channel_mix_receptance = nullptr;
268
+ struct lm_ggml_tensor * channel_mix_value = nullptr;
269
+
270
+ // long rope factors
271
+ struct lm_ggml_tensor * rope_long = nullptr;
272
+ struct lm_ggml_tensor * rope_short = nullptr;
273
+ struct lm_ggml_tensor * rope_freqs = nullptr;
274
+
275
+ // bitnet scale
276
+ struct lm_ggml_tensor * wq_scale = nullptr;
277
+ struct lm_ggml_tensor * wk_scale = nullptr;
278
+ struct lm_ggml_tensor * wv_scale = nullptr;
279
+ struct lm_ggml_tensor * wo_scale = nullptr;
280
+ struct lm_ggml_tensor * ffn_gate_scale = nullptr;
281
+ struct lm_ggml_tensor * ffn_up_scale = nullptr;
282
+ struct lm_ggml_tensor * ffn_down_scale = nullptr;
283
+
284
+ struct llama_layer_posnet posnet;
285
+
286
+ struct llama_layer_convnext convnext;
287
+ };
288
+
289
+ struct llama_model {
290
+ llm_type type = LLM_TYPE_UNKNOWN;
291
+ llm_arch arch = LLM_ARCH_UNKNOWN;
292
+
293
+ std::string name = "n/a";
294
+
295
+ llama_hparams hparams = {};
296
+ llama_vocab vocab;
297
+
298
+ struct lm_ggml_tensor * tok_embd = nullptr;
299
+ struct lm_ggml_tensor * type_embd = nullptr;
300
+ struct lm_ggml_tensor * pos_embd = nullptr;
301
+ struct lm_ggml_tensor * tok_norm = nullptr;
302
+ struct lm_ggml_tensor * tok_norm_b = nullptr;
303
+
304
+ struct lm_ggml_tensor * output_norm = nullptr;
305
+ struct lm_ggml_tensor * output_norm_b = nullptr;
306
+ struct lm_ggml_tensor * output = nullptr;
307
+ struct lm_ggml_tensor * output_b = nullptr;
308
+ struct lm_ggml_tensor * output_norm_enc = nullptr;
309
+
310
+ // classifier
311
+ struct lm_ggml_tensor * cls = nullptr;
312
+ struct lm_ggml_tensor * cls_b = nullptr;
313
+ struct lm_ggml_tensor * cls_out = nullptr;
314
+ struct lm_ggml_tensor * cls_out_b = nullptr;
315
+
316
+ struct lm_ggml_tensor * conv1d = nullptr;
317
+ struct lm_ggml_tensor * conv1d_b = nullptr;
318
+
319
+ std::vector<llama_layer> layers;
320
+
321
+ llama_model_params params;
322
+
323
+ // gguf metadata
324
+ std::unordered_map<std::string, std::string> lm_gguf_kv;
325
+
326
+ // list of devices used in this model
327
+ std::vector<lm_ggml_backend_dev_t> devices;
328
+
329
+ // for quantize-stats only
330
+ std::vector<std::pair<std::string, struct lm_ggml_tensor *>> tensors_by_name;
331
+
332
+ int64_t t_load_us = 0;
333
+ int64_t t_start_us = 0;
334
+
335
+ explicit llama_model(const struct llama_model_params & params);
336
+ ~llama_model();
337
+
338
+ void load_stats (llama_model_loader & ml);
339
+ void load_arch (llama_model_loader & ml);
340
+ void load_hparams(llama_model_loader & ml);
341
+ void load_vocab (llama_model_loader & ml);
342
+ bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
343
+
344
+ std::string arch_name() const;
345
+ std::string type_name() const;
346
+
347
+ std::string desc() const;
348
+
349
+ size_t size() const;
350
+ size_t max_nodes() const;
351
+ size_t n_devices() const;
352
+
353
+ // total number of parameters in the model
354
+ uint64_t n_elements() const;
355
+
356
+ void print_info() const;
357
+
358
+ lm_ggml_backend_dev_t dev_layer(int il) const;
359
+ lm_ggml_backend_dev_t dev_output() const;
360
+
361
+ lm_ggml_backend_buffer_type_t select_buft(int il) const;
362
+
363
+ const struct lm_ggml_tensor * get_tensor(const char * name) const;
364
+
365
+ private:
366
+ struct impl;
367
+ std::unique_ptr<impl> pimpl;
368
+ };
369
+
370
+ const char * llm_type_name(llm_type type);