cactus-react-native 0.0.1 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE.txt +20 -0
- package/README.md +3 -1
- package/android/src/main/CMakeLists.txt +58 -23
- package/android/src/main/java/com/cactus/Cactus.java +484 -16
- package/android/src/main/java/com/cactus/LlamaContext.java +199 -0
- package/android/src/main/jni.cpp +325 -10
- package/android/src/main/jniLibs/arm64-v8a/libcactus.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/libcactus.so +0 -0
- package/android/src/main/jniLibs/x86_64/libcactus_x86_64.so +0 -0
- package/android/src/newarch/java/com/cactus/CactusModule.java +79 -7
- package/android/src/oldarch/java/com/cactus/CactusModule.java +70 -0
- package/cactus-react-native.podspec +0 -3
- package/ios/CMakeLists.txt +58 -36
- package/ios/Cactus.mm +243 -2
- package/ios/CactusContext.h +22 -0
- package/ios/CactusContext.mm +176 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +92 -5
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +268 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/chat.h +2 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/common.h +42 -51
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-common.h +12 -6
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml.h +87 -106
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-arch.h +16 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-batch.h +2 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-chat.h +7 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-context.h +44 -33
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-graph.h +83 -17
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-memory.h +13 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-model.h +24 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/llama.h +102 -142
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/ggml-llama.metallib +0 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +268 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus.h +92 -5
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus_ffi.h +268 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/chat.h +2 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/common.h +42 -51
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-backend.h +4 -4
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-common.h +12 -6
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpp.h +1 -1
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu.h +5 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-impl.h +52 -18
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-metal-impl.h +106 -14
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-opt.h +49 -28
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml.h +87 -106
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-arch.h +16 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-batch.h +2 -1
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-chat.h +7 -2
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-context.h +44 -33
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-cparams.h +1 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-graph.h +83 -17
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-hparams.h +44 -2
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-kv-cache.h +407 -179
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-memory.h +13 -2
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-loader.h +5 -3
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model-saver.h +37 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-model.h +24 -2
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama-vocab.h +6 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/llama.h +102 -142
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/chat-template.hpp +23 -11
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/minja/minja.hpp +186 -127
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/ggml-llama.metallib +0 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus.h +92 -5
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +268 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/chat.h +2 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/common.h +42 -51
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-backend.h +4 -4
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-common.h +12 -6
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpp.h +1 -1
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu.h +5 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-impl.h +52 -18
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-metal-impl.h +106 -14
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-opt.h +49 -28
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml.h +87 -106
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-arch.h +16 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-batch.h +2 -1
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-chat.h +7 -2
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h +44 -33
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-cparams.h +1 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h +83 -17
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h +44 -2
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-kv-cache.h +407 -179
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-memory.h +13 -2
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-loader.h +5 -3
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model-saver.h +37 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-model.h +24 -2
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-vocab.h +6 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama.h +102 -142
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/chat-template.hpp +23 -11
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/minja/minja.hpp +186 -127
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/ggml-llama-sim.metallib +0 -0
- package/lib/commonjs/NativeCactus.js +1 -0
- package/lib/commonjs/NativeCactus.js.map +1 -1
- package/lib/commonjs/index.js +112 -0
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/tools.js +118 -0
- package/lib/commonjs/tools.js.map +1 -0
- package/lib/module/NativeCactus.js +3 -0
- package/lib/module/NativeCactus.js.map +1 -1
- package/lib/module/index.js +87 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/tools.js +110 -0
- package/lib/module/tools.js.map +1 -0
- package/lib/typescript/NativeCactus.d.ts +30 -1
- package/lib/typescript/NativeCactus.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +21 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/tools.d.ts +38 -0
- package/lib/typescript/tools.d.ts.map +1 -0
- package/package.json +6 -3
- package/src/NativeCactus.ts +62 -1
- package/src/index.ts +113 -2
- package/src/tools.ts +127 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/sgemm.h +0 -14
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/sgemm.h +0 -14
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-impl.h +0 -531
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/sgemm.h +0 -14
package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-context.h
CHANGED

@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
 
     void synchronize();
 
-    const llama_model
+    const llama_model & get_model() const;
+    const llama_cparams & get_cparams() const;
+
+    lm_ggml_backend_sched_t get_sched() const;
+
+    lm_ggml_context * get_ctx_compute() const;
 
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            lm_ggml_opt_dataset_t dataset,
+            lm_ggml_opt_result_t result_train,
+            lm_ggml_opt_result_t result_eval,
+            int64_t idata_split,
+            lm_ggml_opt_epoch_callback callback_train,
+            lm_ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            lm_ggml_opt_dataset_t dataset,
+            lm_ggml_opt_result_t result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch & batch,
+            lm_ggml_opt_epoch_callback callback,
+            bool train,
+            int64_t idata_in_loop,
+            int64_t ndata_in_loop,
+            int64_t t_loop_start);
+
 private:
     //
     // output
@@ -137,50 +169,30 @@ private:
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
 
+public:
     int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
     lm_ggml_cgraph * graph_init();
 
+    // returns the result of lm_ggml_backend_sched_graph_compute_async execution
+    lm_ggml_status graph_compute(
+            lm_ggml_cgraph * gf,
+            bool batched);
+
+private:
     llm_graph_result_ptr graph_build(
            lm_ggml_context * ctx,
            lm_ggml_cgraph * gf,
            const llama_ubatch & ubatch,
            llm_graph_type gtype);
 
-    // returns the result of lm_ggml_backend_sched_graph_compute_async execution
-    lm_ggml_status graph_compute(
-            lm_ggml_cgraph * gf,
-            bool batched);
-
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    lm_ggml_tensor * build_rope_shift(
-            lm_ggml_context * ctx0,
-            lm_ggml_tensor * cur,
-            lm_ggml_tensor * shift,
-            lm_ggml_tensor * factors,
-            float freq_base,
-            float freq_scale,
-            lm_ggml_backend_buffer * bbuf) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            lm_ggml_context * ctx0,
-            lm_ggml_cgraph * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            lm_ggml_context * ctx0,
-            lm_ggml_cgraph * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
@@ -197,14 +209,10 @@ private:
     llama_cparams cparams;
     llama_adapter_cvec cvec;
     llama_adapter_loras loras;
-    llama_sbatch sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<
-
-    // TODO: remove
-    bool logits_all = false;
+    std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
@@ -231,6 +239,9 @@ private:
 
     lm_ggml_context_ptr ctx_compute;
 
+    // training
+    lm_ggml_opt_context_t opt_ctx = nullptr;
+
     lm_ggml_threadpool_t threadpool = nullptr;
     lm_ggml_threadpool_t threadpool_batch = nullptr;
 
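The headline change here is the new training surface on `llama_context` (`opt_init`, `opt_epoch`, `opt_epoch_iter`), backed by the newly included `ggml-opt.h` and the `opt_ctx` member. Below is a minimal sketch of how a caller inside the vendored framework might drive these declarations; it is illustrative only, not part of the package. The `lm_ggml_opt_result_init`/`lm_ggml_opt_result_free` helpers are assumed to follow the framework's `lm_ggml_` prefix convention, and the dataset, optimizer parameters, and split index are assumed to be prepared elsewhere.

```cpp
// Sketch only (not part of the package): exercises the new llama_context training API
// declared in this diff. Assumes a fully initialized ctx/model and a tokenized dataset.
#include "llama-context.h"
#include "ggml-opt.h"

void run_one_epoch(llama_context * ctx,
                   llama_model * model,
                   llama_opt_params lopt_params,        // assumed: configured by the caller
                   lm_ggml_opt_dataset_t dataset,       // assumed: built with the lm_ggml_opt_dataset_* helpers
                   int64_t idata_split) {               // train on [0, idata_split), evaluate on the rest
    ctx->opt_init(model, lopt_params);                  // new in 0.1.1

    // Result accumulators; names assume the lm_ggml_ prefixing used throughout this framework.
    lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
    lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

    // Null callbacks are assumed to mean "no progress reporting".
    ctx->opt_epoch(dataset, result_train, result_eval, idata_split,
                   /*callback_train=*/nullptr, /*callback_eval=*/nullptr);

    lm_ggml_opt_result_free(result_train);
    lm_ggml_opt_result_free(result_eval);
}
```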
package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-graph.h
CHANGED

@@ -19,6 +19,8 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
+class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -90,14 +92,29 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     lm_ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t
+    const int64_t n_pos_per_embd = 1;
+};
+
+// temperature tuning, used by llama4
+class llm_graph_input_attn_temp : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    virtual ~llm_graph_input_attn_temp() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    lm_ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+    const uint32_t n_attn_temp_floor_scale;
+    const float f_attn_temp_scale;
 };
 
 class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -171,26 +188,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     lm_ggml_tensor * s_copy; // I32 [kv_size]
 
-    const
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     lm_ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -239,6 +256,31 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    lm_ggml_tensor * self_kq_mask = nullptr;     // F32 [n_kv, n_batch]
+    lm_ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     lm_ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -250,7 +292,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -282,6 +324,7 @@ class llm_graph_result_i {
 public:
     virtual ~llm_graph_result_i() = default;
 
+    virtual lm_ggml_tensor * get_tokens() = 0;
     virtual lm_ggml_tensor * get_logits() = 0;
     virtual lm_ggml_tensor * get_embd() = 0;
     virtual lm_ggml_tensor * get_embd_pooled() = 0;
@@ -296,6 +339,7 @@ class llm_graph_result : public llm_graph_result_i {
 public:
     virtual ~llm_graph_result() = default;
 
+    lm_ggml_tensor * get_tokens() override { return t_tokens; }
     lm_ggml_tensor * get_logits() override { return t_logits; }
     lm_ggml_tensor * get_embd() override { return t_embd; }
     lm_ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -312,6 +356,7 @@ public:
     }
 
     // important graph nodes
+    lm_ggml_tensor * t_tokens = nullptr;
     lm_ggml_tensor * t_logits = nullptr;
     lm_ggml_tensor * t_embd = nullptr;
     lm_ggml_tensor * t_embd_pooled = nullptr;
@@ -335,8 +380,8 @@ struct llm_graph_params {
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
 
-
-
+    lm_ggml_backend_sched_t sched;
+    lm_ggml_backend_t backend_cpu;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -359,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
     const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -387,9 +431,9 @@ struct llm_graph_context {
 
     lm_ggml_context * ctx0 = nullptr;
 
-
+    lm_ggml_backend_sched_t sched;
 
-
+    lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -402,7 +446,7 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t
+    int64_t n_pos_per_embd() const;
 
     void cb(lm_ggml_tensor * cur, const char * name, int il) const;
 
@@ -470,6 +514,7 @@ struct llm_graph_context {
 
     lm_ggml_tensor * build_inp_embd(lm_ggml_tensor * tok_embd) const;
     lm_ggml_tensor * build_inp_pos() const;
+    lm_ggml_tensor * build_inp_attn_scale() const;
     lm_ggml_tensor * build_inp_out_ids() const;
     lm_ggml_tensor * build_inp_mean() const;
     lm_ggml_tensor * build_inp_cls() const;
@@ -487,12 +532,12 @@ struct llm_graph_context {
 
     lm_ggml_tensor * build_attn_mha(
             lm_ggml_cgraph * gf,
-            lm_ggml_tensor * q,
-            lm_ggml_tensor * k,
-            lm_ggml_tensor * v,
+            lm_ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+            lm_ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+            lm_ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             lm_ggml_tensor * kq_b,
             lm_ggml_tensor * kq_mask,
-
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -506,6 +551,7 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -520,6 +566,22 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    lm_ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            lm_ggml_cgraph * gf,
+            lm_ggml_tensor * wo,
+            lm_ggml_tensor * wo_b,
+            lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -534,6 +596,7 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -572,3 +635,6 @@ struct llm_graph_context {
             lm_ggml_tensor * cls_out,
             lm_ggml_tensor * cls_out_b) const;
 };
+
+// TODO: better name
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
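Much of the churn above threads a new `v_mla` tensor through `build_attn_mha` and the `build_attn` overloads; per the shape comments it carries per-head weights of shape [n_embd_head_v_mla, n_embd_head_v, n_head_v] that project the MLA-sized attention output back to the regular value head width. A shape-only sketch (made-up dimensions, assuming the `lm_ggml_`-prefixed ggml API shipped in this framework) of the batched matmul such a tensor implies:

```cpp
// Shape sketch only: how a v_mla tensor of shape [n_embd_head_v_mla, n_embd_head_v, n_head]
// maps compressed MLA attention output back to n_embd_head_v per head.
// Dimensions below are invented for illustration.
#include "ggml.h"

int main() {
    lm_ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    lm_ggml_context * ctx = lm_ggml_init(params);

    const int64_t n_embd_head_v_mla = 512, n_embd_head_v = 128, n_head = 16, n_tokens = 4;

    // compressed attention output per head: [n_embd_head_v_mla, n_tokens, n_head]
    lm_ggml_tensor * kqv   = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, n_embd_head_v_mla, n_tokens, n_head);
    // decompression weights: [n_embd_head_v_mla, n_embd_head_v, n_head]
    lm_ggml_tensor * v_mla = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, n_embd_head_v_mla, n_embd_head_v, n_head);

    // batched per-head matmul: result has shape [n_embd_head_v, n_tokens, n_head]
    lm_ggml_tensor * out = lm_ggml_mul_mat(ctx, v_mla, kqv);
    (void) out;

    lm_ggml_free(ctx);
    return 0;
}
```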
package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/llama-hparams.h
CHANGED

@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,14 +41,16 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
 
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet posnet;
     struct llama_hparams_convnext convnext;
@@ -62,6 +70,7 @@ struct llama_hparams {
     float expert_weights_scale = 0.0;
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -91,6 +100,15 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
@@ -111,6 +129,13 @@ struct llama_hparams {
     bool causal_attn = true;
     bool use_alibi = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm = true;
+
+    // llama4
+    uint32_t n_moe_layer_step = 0;
+    uint32_t n_no_rope_layer_step = 4;
+    uint32_t n_attn_temp_floor_scale = 8192;
+    float f_attn_temp_scale = 0.1;
 
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
@@ -120,6 +145,23 @@ struct llama_hparams {
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    // if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    // il == 0: swa
+    // il == 1: swa
+    // il == 2: dense
+    // il == 3: swa
+    // il == 4: swa
+    // il == 5: dense
+    // il == 6: swa
+    // etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;