cui-llama.rn 1.6.1 → 1.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
- package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +38 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/llama-graph.h
CHANGED
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    lm_ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    lm_ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     lm_ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     lm_ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +292,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -298,6 +324,7 @@ class llm_graph_result_i {
 public:
     virtual ~llm_graph_result_i() = default;
 
+    virtual lm_ggml_tensor * get_tokens() = 0;
     virtual lm_ggml_tensor * get_logits() = 0;
     virtual lm_ggml_tensor * get_embd() = 0;
     virtual lm_ggml_tensor * get_embd_pooled() = 0;
@@ -312,6 +339,7 @@ class llm_graph_result : public llm_graph_result_i {
 public:
     virtual ~llm_graph_result() = default;
 
+    lm_ggml_tensor * get_tokens() override { return t_tokens; }
     lm_ggml_tensor * get_logits() override { return t_logits; }
     lm_ggml_tensor * get_embd() override { return t_embd; }
     lm_ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -328,6 +356,7 @@ public:
     }
 
     // important graph nodes
+    lm_ggml_tensor * t_tokens = nullptr;
     lm_ggml_tensor * t_logits = nullptr;
     lm_ggml_tensor * t_embd = nullptr;
     lm_ggml_tensor * t_embd_pooled = nullptr;
@@ -375,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
    const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -504,13 +532,12 @@ struct llm_graph_context {
 
     lm_ggml_tensor * build_attn_mha(
             lm_ggml_cgraph * gf,
-            lm_ggml_tensor * q,
-            lm_ggml_tensor * k,
-            lm_ggml_tensor * v,
+            lm_ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+            lm_ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+            lm_ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             lm_ggml_tensor * kq_b,
             lm_ggml_tensor * kq_mask,
-            lm_ggml_tensor * v_mla,
-            bool v_trans,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -543,6 +570,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    lm_ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            lm_ggml_cgraph * gf,
+            lm_ggml_tensor * wo,
+            lm_ggml_tensor * wo_b,
+            lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     lm_ggml_tensor * build_attn(
@@ -593,3 +635,6 @@ struct llm_graph_context {
             lm_ggml_tensor * cls_out,
             lm_ggml_tensor * cls_out_b) const;
 };
+
+// TODO: better name
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
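The last hunk exports llama_relative_position_bucket, previously internal to the graph code. The signature matches classic T5-style relative-attention bucketing: nearby key/query offsets each get their own bucket, while larger distances share logarithmically sized bins, with separate bucket halves per direction in the bidirectional case. A minimal sketch of that scheme, for orientation only; the constants (e.g. the 128-token cutoff) are illustrative assumptions, not values taken from this diff:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch of T5-style relative position bucketing; constants are illustrative.
int32_t relative_bucket_sketch(int32_t x, int32_t y, uint32_t n_buckets, bool bidirectional) {
    const int32_t max_distance = 128; // assumed cutoff of the logarithmic range

    int32_t rel    = x - y;
    int32_t bucket = 0;

    if (bidirectional) {
        n_buckets /= 2;                              // half the buckets per direction
        bucket   += (rel > 0) * (int32_t) n_buckets; // future offsets use the upper half
        rel       = std::abs(rel);
    } else {
        rel = -std::min(rel, 0);                     // causal: only offsets into the past
    }

    const int32_t max_exact = (int32_t) n_buckets / 2;
    if (rel < max_exact) {
        return bucket + rel;                         // small offsets: one bucket each
    }

    // large offsets: log-spaced bins between max_exact and max_distance
    const int32_t large = max_exact
        + (int32_t) (std::log((float) rel / max_exact)
                   / std::log((float) max_distance / max_exact)
                   * ((int32_t) n_buckets - max_exact));
    return bucket + std::min(large, (int32_t) n_buckets - 1);
}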
package/cpp/llama-hparams.cpp
CHANGED
@@ -2,6 +2,22 @@
 
 #include "ggml.h"
 
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return
+        return swa_layers[il];
     }
 
     LM_GGML_ABORT("fatal error");
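The predicate in set_swa_pattern is terse enough to misread: n_pattern == 0 marks every layer as SWA (and short-circuits past the modulo, avoiding a division by zero), n_pattern == 1 marks every layer dense, and otherwise every n_pattern-th layer stays dense. A throwaway snippet (not part of the package) that replays the predicate to show the resulting layouts:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
    const uint32_t n_layer = 8;
    for (uint32_t n_pattern : {0u, 1u, 3u}) {
        std::printf("n_pattern=%u:", n_pattern);
        for (uint32_t il = 0; il < n_layer; ++il) {
            // same predicate as llama_hparams::set_swa_pattern above
            const bool swa = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
            std::printf(" %s", swa ? "swa" : "dense");
        }
        std::printf("\n");
    }
    // n_pattern=3 prints: swa swa dense swa swa dense swa swa,
    // matching the worked example in llama-hparams.h below.
    return 0;
}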
package/cpp/llama-hparams.h
CHANGED
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,8 +41,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +100,15 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
@@ -116,11 +129,10 @@ struct llama_hparams {
     bool causal_attn = true;
     bool use_alibi = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm = true;
 
+    // llama4
     uint32_t n_moe_layer_step = 0;
-    bool use_kq_norm = true;
-    uint32_t n_attn_chunk = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float f_attn_temp_scale = 0.1;
@@ -133,6 +145,23 @@ struct llama_hparams {
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
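Taken together, the header swaps the old scalar n_swa_pattern for an explicit per-layer mask plus a swa_type discriminator; the llama4-specific n_attn_chunk is gone, presumably subsumed by LLAMA_SWA_TYPE_CHUNKED with n_swa carrying the chunk size. A hedged sketch of how a model loader might drive the new API; the concrete values are invented for illustration and do not come from this diff:

#include "llama-hparams.h"

// Illustrative only: rough shape of SWA configuration with the new fields.
static void configure_swa_sketch(llama_hparams & hp) {
    hp.n_layer  = 12;
    hp.swa_type = LLAMA_SWA_TYPE_STANDARD;
    hp.n_swa    = 4096;    // sliding window of 4096 tokens
    hp.set_swa_pattern(3); // every 3rd layer stays dense

    if (hp.is_swa_any()) {
        // a context built from these hparams needs the SWA-aware cache
        // (see llama_kv_cache_unified_iswa in llama-kv-cache.h)
    }

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        const bool swa = hp.is_swa(il); // per-layer: sliding window or dense
        (void) swa;                     // ... pick the mask type for layer il
    }
}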