@fugood/llama.node 0.2.0 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +9 -0
- package/README.md +1 -1
- package/bin/darwin/arm64/default.metallib +0 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/default.metallib +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +1 -1
- package/package.json +2 -1
- package/patches/llama.patch +22 -0
- package/src/LlamaContext.cpp +2 -2
- package/src/TokenizeWorker.cpp +1 -1
- package/src/llama.cpp/CMakeLists.txt +82 -54
- package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +16 -0
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +6 -0
- package/src/llama.cpp/common/common.cpp +748 -754
- package/src/llama.cpp/common/common.h +49 -41
- package/src/llama.cpp/common/grammar-parser.cpp +10 -1
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +6 -6
- package/src/llama.cpp/common/log.h +5 -5
- package/src/llama.cpp/common/sampling.cpp +92 -10
- package/src/llama.cpp/common/sampling.h +6 -1
- package/src/llama.cpp/common/train.cpp +2 -2
- package/src/llama.cpp/examples/CMakeLists.txt +3 -0
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +13 -4
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +2 -2
- package/src/llama.cpp/examples/finetune/finetune.cpp +4 -3
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -2
- package/src/llama.cpp/examples/infill/infill.cpp +8 -8
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +57 -8
- package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +55 -0
- package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/CMakeLists.txt +7 -8
- package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/llama-android.cpp +14 -14
- package/src/llama.cpp/examples/llava/clip.h +1 -1
- package/src/llama.cpp/examples/llava/llava-cli.cpp +27 -7
- package/src/llama.cpp/examples/llava/llava.cpp +0 -15
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +29 -17
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +9 -9
- package/src/llama.cpp/examples/quantize/quantize.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +2 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +2 -0
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +134 -0
- package/src/llama.cpp/examples/server/server.cpp +33 -25
- package/src/llama.cpp/examples/server/utils.hpp +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +359 -9
- package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +4 -3
- package/src/llama.cpp/ggml-backend.c +2 -3
- package/src/llama.cpp/ggml-common.h +0 -54
- package/src/llama.cpp/ggml-cuda.h +1 -0
- package/src/llama.cpp/ggml-impl.h +51 -0
- package/src/llama.cpp/ggml-kompute.cpp +13 -3
- package/src/llama.cpp/ggml-opencl.cpp +4 -1
- package/src/llama.cpp/ggml-quants.c +3715 -2050
- package/src/llama.cpp/ggml-rpc.cpp +1155 -0
- package/src/llama.cpp/ggml-rpc.h +24 -0
- package/src/llama.cpp/ggml-sycl.cpp +119 -673
- package/src/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- package/src/llama.cpp/ggml-vulkan.cpp +203 -224
- package/src/llama.cpp/ggml.c +1208 -1483
- package/src/llama.cpp/ggml.h +71 -46
- package/src/llama.cpp/llama.cpp +1374 -938
- package/src/llama.cpp/llama.h +22 -6
- package/src/llama.cpp/requirements.txt +0 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -1
- package/src/llama.cpp/tests/test-backend-ops.cpp +120 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +16 -4
- package/src/llama.cpp/tests/test-grad0.cpp +43 -83
- package/src/llama.cpp/tests/test-grammar-integration.cpp +46 -0
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +27 -3
- package/src/llama.cpp/unicode-data.cpp +6969 -2169
- package/src/llama.cpp/unicode-data.h +15 -12
- package/src/llama.cpp/unicode.cpp +89 -111
- package/src/llama.cpp/unicode.h +44 -12
- package/src/llama.cpp/build.zig +0 -172
- package/src/llama.cpp/ggml-mpi.c +0 -216
- package/src/llama.cpp/ggml-mpi.h +0 -39
- package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +0 -2
- package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +0 -2
package/src/llama.cpp/llama.h
CHANGED
@@ -81,9 +81,11 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_GPT2       = 7,
         LLAMA_VOCAB_PRE_TYPE_REFACT     = 8,
         LLAMA_VOCAB_PRE_TYPE_COMMAND_R  = 9,
-
-
-
+        LLAMA_VOCAB_PRE_TYPE_STABLELM2  = 10,
+        LLAMA_VOCAB_PRE_TYPE_QWEN2      = 11,
+        LLAMA_VOCAB_PRE_TYPE_OLMO       = 12,
+        LLAMA_VOCAB_PRE_TYPE_DBRX       = 13,
+        LLAMA_VOCAB_PRE_TYPE_SMAUG      = 14,
     };

     // note: these values should be synchronized with ggml_rope
@@ -242,6 +244,9 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;

+        // comma separated list of RPC servers to use for offloading
+        const char * rpc_servers;
+
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.
@@ -260,6 +265,8 @@ extern "C" {
         bool check_tensors; // validate model tensor data
     };

+    // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
+    // https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
         uint32_t seed;  // RNG seed, -1 for random
         uint32_t n_ctx; // text context, 0 = from model
@@ -286,14 +293,14 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;

-        enum ggml_type type_k; // data type for K cache
-        enum ggml_type type_v; // data type for V cache
+        enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
+        enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention
+        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -755,6 +762,12 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);

+    // Get the number of threads used for generation of a single token.
+    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+
+    // Get the number of threads used for prompt and batch processing (multiple token).
+    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -813,6 +826,9 @@ extern "C" {
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);

+    // Identify if Token Id is a control token or a render-able token
+    LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
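For orientation, the llama.h additions above (the rpc_servers model parameter, the llama_n_threads()/llama_n_threads_batch() getters, and llama_token_is_control()) can be exercised directly against the C API. The following is a minimal sketch and is not code shipped in this package; the model path, the RPC endpoint, and the lack of error handling are illustration-only assumptions.

#include "llama.h"
#include <stdio.h>

int main(int argc, char ** argv) {
    const char * model_path = argc > 1 ? argv[1] : "model.gguf"; // placeholder path

    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    // New in this update: comma separated list of RPC servers ("host:port,...") to offload to.
    mparams.rpc_servers = NULL; // e.g. "127.0.0.1:50052"

    struct llama_model   * model = llama_load_model_from_file(model_path, mparams);
    struct llama_context * ctx   = llama_new_context_with_model(model, llama_context_default_params());

    // New getters: thread counts could previously only be set via llama_set_n_threads().
    printf("n_threads = %u, n_threads_batch = %u\n", llama_n_threads(ctx), llama_n_threads_batch(ctx));

    // New check: distinguishes control tokens (BOS/EOS/...) from render-able tokens.
    printf("BOS is a control token: %d\n", llama_token_is_control(model, llama_token_bos(model)));

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}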
package/src/llama.cpp/requirements.txt
CHANGED
@@ -9,5 +9,3 @@
 -r ./requirements/requirements-convert-hf-to-gguf.txt
 -r ./requirements/requirements-convert-hf-to-gguf-update.txt
 -r ./requirements/requirements-convert-llama-ggml-to-gguf.txt
--r ./requirements/requirements-convert-lora-to-ggml.txt
--r ./requirements/requirements-convert-persimmon-to-gguf.txt
package/src/llama.cpp/tests/CMakeLists.txt
CHANGED
@@ -92,7 +92,7 @@ target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
 install(TARGETS test-tokenizer-1-bpe RUNTIME)

 # TODO: disabled due to slowness
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
 #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
package/src/llama.cpp/tests/test-backend-ops.cpp
CHANGED
@@ -2,6 +2,7 @@
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
 #include <ggml-backend-impl.h>
+
 #include <algorithm>
 #include <array>
 #include <cfloat>
@@ -15,6 +16,7 @@
 #include <thread>
 #include <vector>

+
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     // static RNG initialization (revisit if n_threads stops being constant)
     static const size_t n_threads = std::thread::hardware_concurrency();
@@ -48,6 +50,22 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }

+#if 0
+    const char * val_str = getenv("GGML_TEST_EPS");
+    float val = 1e-9f;
+    if (val_str != nullptr) {
+        val = std::stof(val_str);
+        printf("GGML_TEST_EPS=%e\n", val);
+    }
+
+    // test quantization with very small values that may result in nan scales due to division by zero
+    if (ggml_is_quantized(tensor->type)) {
+        for (int i = 0; i < 256; i++) {
+            data[i] = val;
+        }
+    }
+#endif
+
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
@@ -63,6 +81,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
             }
         }
         ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
+        GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -1111,11 +1130,7 @@ struct test_soft_max : public test_case {
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
-        ggml_tensor *
-        if (max_bias > 0.0f) {
-            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
-        }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
         return out;
     }
 };
@@ -1127,20 +1142,22 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    bool ff;

     std::string vars() override {
-        return
+        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
     }

     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
-        ggml_tensor *
+        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
         return out;
     }

@@ -1154,7 +1171,12 @@ struct test_rope : public test_case {
             }
             ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
         } else {
-
+            if (t->ne[0] == n_dims/2) {
+                // frequency factors in the range [0.9f, 1.1f]
+                init_tensor_uniform(t, 0.9f, 1.1f);
+            } else {
+                init_tensor_uniform(t);
+            }
         }
     }
 }
@@ -1237,22 +1259,26 @@ struct test_im2col : public test_case {
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 4>
-    const int64_t
+    const std::array<int64_t, 4> ne_a;
+    const int64_t ne_b_d;
+    const int dim;

     std::string vars() override {
-        return
+        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
     }

     test_concat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4>
-            int64_t
-
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            int64_t ne_b_d = 10,
+            int dim = 2)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-
-
-        ggml_tensor *
+        auto ne_b = ne_a;
+        ne_b[dim] = ne_b_d;
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_concat(ctx, a, b, dim);
         return out;
     }
 };
@@ -1332,23 +1358,47 @@ struct test_upscale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
+    const bool transpose;

     std::string vars() override {
-        return
+        return VARS_TO_STR4(type, ne, scale_factor, transpose);
     }

     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2)
-        : type(type), ne(ne), scale_factor(scale_factor) {}
+            int32_t scale_factor = 2, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        if (transpose) a = ggml_transpose(ctx, a);
         ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
         return out;
     }
 };

+// GGML_OP_UPSCALE (ext)
+struct test_upscale_ext : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_tgt;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, ne_tgt);
+    }
+
+    test_upscale_ext(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne     = {2, 5,  7, 11},
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
+        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        return out;
+    }
+};
+
 // GGML_OP_GROUP_NORM
 struct test_group_norm : public test_case {
     const ggml_type type;
@@ -1490,23 +1540,27 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size

+    const bool mask; // use mask
+
+    const float max_bias; // ALiBi
+
     std::string vars() override {
-        return
+        return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor *
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v,
+        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -1611,7 +1665,7 @@ public:

     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

-    kq = ggml_soft_max_ext(ctx, kq, kq_mask,
+    kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);

     // split cached v into n_head heads
     struct ggml_tensor * v =
@@ -1720,14 +1774,14 @@ struct test_llama : public test_llm {
         struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
         struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);

-        Qcur =
-            ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
+        Qcur = ggml_rope_ext(
+            ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
             hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
             ext_factor, attn_factor, beta_fast, beta_slow
         );

-        Kcur =
-            ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+        Kcur = ggml_rope_ext(
+            ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
             hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
             ext_factor, attn_factor, beta_fast, beta_slow
         );
@@ -1846,13 +1900,13 @@ struct test_falcon : public test_llm {
         Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);

         // using mode = 2 for neox mode
-        Qcur =
-            ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+        Qcur = ggml_rope_ext(
+            ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
         );

-        Kcur =
-            ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+        Kcur = ggml_rope_ext(
+            ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
         );

@@ -2128,6 +2182,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2141,24 +2196,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));

     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-
-        test_cases.emplace_back(new test_rope(type, {128,
-        test_cases.emplace_back(new test_rope(type, {128,
-        test_cases.emplace_back(new test_rope(type, {128,
-        test_cases.emplace_back(new test_rope(type, {
-
-
-
-
-
+        // TODO: ff not supported yet for !neox
+        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+        for (bool ff : {false, true}) { // freq_factors
+            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+        }
     }

-
-
+    for (int dim : { 0, 1, 2, 3, }) {
+        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
+        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
+    }

     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
@@ -2168,6 +2228,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_upscale());
+    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
+    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm());
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
@@ -2175,15 +2237,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());

-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    for (int hs : { 64, 128, }) { // other head sizes not implemented
-#else
     for (int hs : { 64, 80, 128, 256, }) {
-
-
-
-        for (int
-
+        for (bool mask : { true, false } ) {
+            for (float max_bias : { 0.0f, 8.0f }) {
+                if (!mask && max_bias > 0.0f) continue;
+                for (int nh : { 32, }) {
+                    for (int kv : { 512, 1024, }) {
+                        for (int nb : { 1, 2, 4, 8, }) {
+                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
+                        }
+                    }
                 }
             }
         }
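As a quick reference for the signature changes the updated test-backend-ops.cpp exercises above (ggml_rope_ext() gaining an optional frequency-factor tensor, ggml_soft_max_ext() dropping the separate positions tensor in favour of direct scale/max_bias arguments, and ggml_flash_attn_ext() taking an optional mask plus an ALiBi max_bias), here is a small stand-alone sketch assembled from the calls visible in the hunks. The tensor shapes and the 16 MB context size are arbitrary illustration values and are not taken from the package.

#include "ggml.h"
#include <math.h>

int main(void) {
    // CPU-only scratch context, large enough for the few tensors created below.
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int head_dim = 64, n_head = 8, n_tokens = 10, kv = 512;

    // rope: ggml_rope_ext() now takes an optional per-dimension frequency-factor tensor (nullptr to disable).
    struct ggml_tensor * cur  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
    struct ggml_tensor * pos  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    struct ggml_tensor * freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, head_dim/2);
    cur = ggml_rope_ext(ctx, cur, pos, freq, head_dim, 2 /* neox */, 512, 0,
                        10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);

    // soft_max: the separate ALiBi positions tensor is gone; scale and max_bias are passed directly.
    struct ggml_tensor * kq   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, n_tokens);
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, n_tokens);
    kq = ggml_soft_max_ext(ctx, kq, mask, 1.0f/sqrtf((float) head_dim), 0.0f);

    // flash_attn_ext: optional mask (padded to GGML_KQ_MASK_PAD) and an ALiBi max_bias parameter.
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, kv, n_head, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, kv, n_head, 1);
    struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf((float) head_dim), 0.0f);
    (void) out; // only the graph nodes are built here; nothing is computed

    ggml_free(ctx);
    return 0;
}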
package/src/llama.cpp/tests/test-chat-template.cpp
CHANGED
@@ -49,8 +49,14 @@ int main(void) {
         "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
         // Llama-3
         "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
-        //
-        "{{ bos_token }}{% for message in messages %}{
+        //Phi-3-mini
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-small
+        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+        //Phi-3-medium
+        "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-vision
+        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -79,8 +85,14 @@ int main(void) {
         "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
         // Llama 3
         "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        //
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\
+        //Phi-3-mini
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-small
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-medium
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-vision
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
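The Phi-3 templates added above are rendered through the generic llama_chat_apply_template() entry point, the same way the test formats its other templates into the formatted_chat buffer shown in the trailing context lines. A minimal sketch of that call pattern follows; it assumes the public C API from llama.h (a null model together with an explicit template string, as the test itself does), and the helper name format_chat is made up for illustration.

#include "llama.h"
#include <string>
#include <vector>

// Render a short conversation with an explicit (e.g. Phi-3) template string.
static std::string format_chat(const char * tmpl) {
    llama_chat_message msgs[] = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
    };
    const size_t n_msg = sizeof(msgs) / sizeof(msgs[0]);

    std::vector<char> buf(1024);
    // Returns the number of bytes required; grow the buffer and retry if it did not fit.
    int32_t res = llama_chat_apply_template(nullptr, tmpl, msgs, n_msg, /*add_ass=*/true,
                                            buf.data(), (int32_t) buf.size());
    if (res > (int32_t) buf.size()) {
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl, msgs, n_msg, true, buf.data(), (int32_t) buf.size());
    }
    return res < 0 ? std::string() : std::string(buf.data(), (size_t) res);
}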
package/src/llama.cpp/tests/test-grad0.cpp
CHANGED
@@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
         }

         // flash_attn f32
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        }
-
-        // flash_attn f16, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
-
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+        // TODO: adapt to ggml_flash_attn_ext() changes
+        //{
+        //    srand(seed);
+        //    const int nargs = 3;
+
+        //    int64_t ne2[4];
+
+        //    get_random_dims(ne2, 4);
+        //    int64_t D = ne2[0];
+        //    int64_t N = ne2[1];
+        //    int64_t M = ne2[2] + N;
+        //    int64_t B = ne2[3];
+
+        //    for (int masked = 0; masked <= 1; ++masked) {
+        //        for (int ndims = 2; ndims <= 4; ++ndims) {
+        //            int max_nrep = (ndims >= 3) ? 2 : 1;
+        //            for (int nrep = 1; nrep < max_nrep; ++nrep) {
+        //                int64_t neq[4] = { D, N, B*nrep, ne[3] };
+        //                int64_t nek[4] = { D, M, B, ne[3] };
+        //                int64_t nev[4] = { M, D, B, ne[3] };
+        //                if (ndims == 2) {
+        //                    neq[2] = 1; neq[3] = 1;
+        //                    nek[2] = 1; nek[3] = 1;
+        //                    nev[2] = 1; nev[3] = 1;
+        //                } else if (ndims == 3) {
+        //                    neq[3] = 1;
+        //                    nek[3] = 1;
+        //                    nev[3] = 1;
+        //                }
+        //                x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+        //                x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+        //                x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+        //                ggml_set_param(ctx0, x[0]);
+        //                ggml_set_param(ctx0, x[1]);
+        //                ggml_set_param(ctx0, x[2]);
+
+        //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+        //            }
+        //        }
+        //    }
+        //}

-                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
-                }
-            }
-        }
         ggml_free(ctx0);
     }
