@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275

package/src/llama.cpp/examples/export-lora/export-lora.cpp

@@ -1,3 +1,4 @@
+#include "arg.h"
 #include "common.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -10,6 +11,12 @@

 static bool g_verbose = false;

+struct tensor_transformation {
+    struct ggml_tensor * in;
+    struct ggml_tensor * out;
+    bool is_copy;
+};
+
 static std::string get_kv_str(struct gguf_context * ctx_gguf, const std::string & key){
     int id = gguf_find_key(ctx_gguf, key.c_str());
     return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
@@ -50,20 +57,6 @@ static struct gguf_context * load_gguf(std::string & fname, struct ggml_context
     return ctx_gguf;
 }

-static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    std::string result;
-    for (size_t pos = 0; ; pos += search.length()) {
-        auto new_pos = s.find(search, pos);
-        if (new_pos == std::string::npos) {
-            result += s.substr(pos, s.size() - pos);
-            break;
-        }
-        result += s.substr(pos, new_pos - pos) + replace;
-        pos = new_pos;
-    }
-    s = std::move(result);
-}
-
 struct file_input {
     struct ggml_context * ctx_meta = nullptr;
     struct gguf_context * ctx_gguf = nullptr;
@@ -135,7 +128,7 @@ struct lora_merge_ctx {

     lora_merge_ctx(
             std::string & base_fname,
-            std::vector<
+            std::vector<llama_lora_adapter_info> & lora_files,
             std::string & outfile,
             int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
         fout.exceptions(std::ofstream::failbit); // fail fast on write errors
@@ -144,9 +137,9 @@ struct lora_merge_ctx {
             throw std::runtime_error("split model is not yet supported");
         }

-        for (auto lora_inp : lora_files) {
-            auto fname =
-            auto scale =
+        for (auto & lora_inp : lora_files) {
+            auto fname = lora_inp.path;
+            auto scale = lora_inp.scale;
             std::unique_ptr<file_input> adapter(new file_input(fname, scale));
             check_metadata_lora(adapter.get());
             adapters.push_back(std::move(adapter));
@@ -212,8 +205,7 @@ struct lora_merge_ctx {
         }

         // mapping base tensor to out tensor (same shape with base, but different type)
-
-        std::vector<std::pair<struct ggml_tensor *, struct ggml_tensor *>> base_to_out_tensors;
+        std::vector<tensor_transformation> trans;
         for (auto & it : base_model.tensors) {
             bool t_a = true;
             bool t_b = true;
@@ -226,14 +218,22 @@ struct lora_merge_ctx {
                 // only copy
                 struct ggml_tensor * cpy_tensor = ggml_dup_tensor(ctx_out_ggml, base_tensor);
                 ggml_set_name(cpy_tensor, base_tensor->name);
-
+                trans.push_back({
+                    cpy_tensor,
+                    cpy_tensor,
+                    true,
+                });
                 gguf_add_tensor(ctx_out, cpy_tensor);
             } else if (t_a && t_b) {
                 // need merging
                 struct ggml_tensor * out_tensor = ggml_new_tensor(
                     ctx_out_ggml, get_out_tensor_type(base_tensor), GGML_MAX_DIMS, base_tensor->ne);
                 ggml_set_name(out_tensor, base_tensor->name);
-
+                trans.push_back({
+                    base_tensor,
+                    out_tensor,
+                    false,
+                });
                 gguf_add_tensor(ctx_out, out_tensor);
             } else {
                 throw std::runtime_error("tensor " + it.first + " missing either lora_a or lora_b");
@@ -248,12 +248,12 @@ struct lora_merge_ctx {

         // process base model tensors
         size_t n_merged = 0;
-        for (auto & it :
-            if (it.
-                merge_tensor(it.
+        for (auto & it : trans) {
+            if (!it.is_copy) {
+                merge_tensor(it.in, it.out);
                 n_merged++;
             } else {
-                copy_tensor(it.
+                copy_tensor(it.in);
             }
         }

@@ -266,7 +266,7 @@ struct lora_merge_ctx {
         }

         printf("%s : merged %ld tensors with lora adapters\n", __func__, n_merged);
-        printf("%s : wrote %ld tensors to output file\n", __func__,
+        printf("%s : wrote %ld tensors to output file\n", __func__, trans.size());
     }

     void copy_tensor(struct ggml_tensor * base) {
@@ -299,6 +299,10 @@ struct lora_merge_ctx {
         for (size_t i = 0; i < adapters.size(); ++i) {
             auto t_a = adapters[i]->get_tensor(name_lora_a);
             auto t_b = adapters[i]->get_tensor(name_lora_b);
+            // TODO: add support for quantized lora
+            if (ggml_is_quantized(t_a->type) || ggml_is_quantized(t_b->type)) {
+                throw std::runtime_error("quantized LoRA adapters is not supported, please retry with f16 or f32");
+            }
             inp_a[i] = ggml_dup_tensor(ctx, t_a);
             inp_b[i] = ggml_dup_tensor(ctx, t_b);
         }
@@ -366,7 +370,7 @@ struct lora_merge_ctx {

         // write data to output file
         {
-            auto result = gf
+            auto * result = ggml_graph_node(gf, -1);
             size_t len = ggml_nbytes(result);
             if (read_buf.size() < len) {
                 read_buf.resize(len);
@@ -388,9 +392,7 @@ struct lora_merge_ctx {
     }
 };

-static void print_usage(int
-    gpt_params_print_usage(argc, argv, params);
-
+static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n  %s -m base-model.gguf --lora lora-file.gguf -o merged-model-f16.gguf\n", argv[0]);
     printf("\nNOTE: output model is F16\n");
@@ -400,14 +402,13 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;

-    if (!gpt_params_parse(argc, argv, params)) {
-        print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
         return 1;
     }

-    g_verbose = (params.verbosity
+    g_verbose = (params.verbosity > 1);
     try {
-        lora_merge_ctx ctx(params.model, params.
+        lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.cpuparams.n_threads);
         ctx.run_merge();
     } catch (const std::exception & err) {
         fprintf(stderr, "%s\n", err.what());

package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp

@@ -1,9 +1,5 @@
-#define LLAMA_API_INTERNAL
-
-#include "grammar-parser.h"
-#include "ggml.h"
-#include "llama.h"
 #include "unicode.h"
+#include "llama-grammar.h"

 #include <cstdio>
 #include <cstdlib>
@@ -12,29 +8,28 @@
 #include <string>
 #include <vector>

-static bool
-    auto
-    const auto & code_points = decoded.first;
+static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+    const auto cpts = unicode_cpts_from_utf8(input_str);

     const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
-    llama_grammar_stacks &
+    llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

     size_t pos = 0;
-    for (auto
-        const llama_grammar_stacks
+    for (const auto & cpt : cpts) {
+        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

-        llama_grammar_accept(rules,
+        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);

-        if (
+        if (stacks_cur.empty()) {
             error_pos = pos;
-            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(
-
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
+            stacks_cur = stacks_prev;
             return false;
         }
         ++pos;
     }

-    for (const auto & stack :
+    for (const auto & stack : stacks_cur) {
         if (stack.empty()) {
             return true;
         }
@@ -85,27 +80,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }

-
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // will be empty (default) if there are parse errors
-    if (parsed_grammar.rules.empty()) {
-        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
-        return 1;
-    }
-
-    // Ensure that there is a "root" node.
-    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
-        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
-        return 1;
-    }
-
-    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-
-    // Create the LLAMA grammar
-    auto grammar = llama_grammar_init(
-            grammar_rules.data(),
-            grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
     if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }
@@ -122,7 +97,7 @@ int main(int argc, char** argv) {
     // Validate the input string against the grammar
     size_t error_pos;
     std::string error_msg;
-    bool is_valid =
+    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);

     if (is_valid) {
         fprintf(stdout, "Input string is valid according to the grammar.\n");
@@ -131,7 +106,7 @@ int main(int argc, char** argv) {
     }

     // Clean up
-
+    llama_grammar_free_impl(grammar);

     return 0;
 }

package/src/llama.cpp/examples/gen-docs/gen-docs.cpp

@@ -0,0 +1,83 @@
+#include "arg.h"
+#include "common.h"
+
+#include <fstream>
+#include <string>
+
+// Export usage message (-h) to markdown format
+
+static void write_table_header(std::ofstream & file) {
+    file << "| Argument | Explanation |\n";
+    file << "| -------- | ----------- |\n";
+}
+
+static void write_table_entry(std::ofstream & file, const llama_arg & opt) {
+    file << "| `";
+    // args
+    for (const auto & arg : opt.args) {
+        if (arg == opt.args.front()) {
+            file << arg;
+            if (opt.args.size() > 1) file << ", ";
+        } else {
+            file << arg << (arg != opt.args.back() ? ", " : "");
+        }
+    }
+    // value hint
+    if (opt.value_hint) {
+        std::string md_value_hint(opt.value_hint);
+        string_replace_all(md_value_hint, "|", "\\|");
+        file << " " << md_value_hint;
+    }
+    if (opt.value_hint_2) {
+        std::string md_value_hint_2(opt.value_hint_2);
+        string_replace_all(md_value_hint_2, "|", "\\|");
+        file << " " << md_value_hint_2;
+    }
+    // help text
+    std::string md_help(opt.help);
+    string_replace_all(md_help, "\n", "<br/>");
+    string_replace_all(md_help, "|", "\\|");
+    file << "` | " << md_help << " |\n";
+}
+
+static void write_table(std::ofstream & file, std::vector<llama_arg *> & opts) {
+    write_table_header(file);
+    for (const auto & opt : opts) {
+        write_table_entry(file, *opt);
+    }
+}
+
+static void export_md(std::string fname, llama_example ex) {
+    std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
+
+    gpt_params params;
+    auto ctx_arg = gpt_params_parser_init(params, ex);
+
+    std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> sparam_options;
+    std::vector<llama_arg *> specific_options;
+    for (auto & opt : ctx_arg.options) {
+        // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
+        if (opt.is_sparam) {
+            sparam_options.push_back(&opt);
+        } else if (opt.in_example(ctx_arg.ex)) {
+            specific_options.push_back(&opt);
+        } else {
+            common_options.push_back(&opt);
+        }
+    }
+
+    file << "**Common params**\n\n";
+    write_table(file, common_options);
+    file << "\n\n**Sampling params**\n\n";
+    write_table(file, sparam_options);
+    file << "\n\n**Example-specific params**\n\n";
+    write_table(file, specific_options);
+}
+
+int main(int, char **) {
+    export_md("autogen-main.md", LLAMA_EXAMPLE_MAIN);
+    export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER);
+
+    return 0;
+}

package/src/llama.cpp/examples/gguf-split/gguf-split.cpp

@@ -22,12 +22,20 @@
 #endif

 enum split_operation : uint8_t {
-
-
+    OP_NONE,
+    OP_SPLIT,
+    OP_MERGE,
+};
+
+enum split_mode : uint8_t {
+    MODE_NONE,
+    MODE_TENSOR,
+    MODE_SIZE,
 };

 struct split_params {
-    split_operation operation =
+    split_operation operation = OP_NONE;
+    split_mode mode = MODE_NONE;
     size_t n_bytes_split = 0;
     int n_split_tensors = 128;
     std::string input;
@@ -87,59 +95,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
         }

         bool arg_found = false;
-        bool is_op_set = false;
-        bool is_mode_set = false;
         if (arg == "-h" || arg == "--help") {
             split_print_usage(argv[0]);
             exit(0);
-        }
-        if (arg == "--version") {
+        } else if (arg == "--version") {
             fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
             fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
             exit(0);
-        }
-        if (arg == "--dry-run") {
+        } else if (arg == "--dry-run") {
             arg_found = true;
             params.dry_run = true;
-        }
-        if (arg == "--no-tensor-first-split") {
+        } else if (arg == "--no-tensor-first-split") {
             arg_found = true;
             params.no_tensor_first_split = true;
-        }
-
-        if (is_op_set) {
-            throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
-        }
-        if (arg == "--merge") {
+        } else if (arg == "--merge") {
             arg_found = true;
-
-
-
-
+            if (params.operation != OP_NONE && params.operation != OP_MERGE) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_MERGE;
+        } else if (arg == "--split") {
             arg_found = true;
-
-
-
-
-            if (
-                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
-            }
-        if (arg == "--split-max-tensors") {
+            if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_SPLIT;
+        } else if (arg == "--split-max-tensors") {
             if (++arg_idx >= argc) {
                 invalid_param = true;
                 break;
             }
             arg_found = true;
-
+            if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_TENSOR;
             params.n_split_tensors = atoi(argv[arg_idx]);
-        }
-        if (arg == "--split-max-size") {
+        } else if (arg == "--split-max-size") {
             if (++arg_idx >= argc) {
                 invalid_param = true;
                 break;
             }
             arg_found = true;
-
+            if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_SIZE;
             params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
         }

@@ -148,11 +149,20 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
         }
     }

+    // the operation is split if not specified
+    if (params.operation == OP_NONE) {
+        params.operation = OP_SPLIT;
+    }
+    // the split mode is by tensor if not specified
+    if (params.mode == MODE_NONE) {
+        params.mode = MODE_TENSOR;
+    }
+
     if (invalid_param) {
         throw std::invalid_argument("error: invalid parameter for argument: " + arg);
     }

-    if (argc - arg_idx
+    if (argc - arg_idx != 2) {
         throw std::invalid_argument("error: bad arguments");
     }

@@ -265,13 +275,15 @@ struct split_strategy {
     }

     bool should_split(int i_tensor, size_t next_size) {
-        if (params.
+        if (params.mode == MODE_SIZE) {
             // split by max size per file
             return next_size > params.n_bytes_split;
-        } else {
+        } else if (params.mode == MODE_TENSOR) {
             // split by number of tensors per file
             return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
         }
+        // should never happen
+        GGML_ABORT("invalid mode");
     }

     void print_info() {
@@ -389,10 +401,17 @@ static void gguf_merge(const split_params & split_params) {
     int n_split = 1;
     int total_tensors = 0;

-
+    // avoid overwriting existing output file
+    if (std::ifstream(split_params.output.c_str())) {
+        fprintf(stderr, "%s: output file %s already exists\n", __func__, split_params.output.c_str());
+        exit(EXIT_FAILURE);
+    }
+
     std::ofstream fout(split_params.output.c_str(), std::ios::binary);
     fout.exceptions(std::ofstream::failbit); // fail fast on write errors

+    auto * ctx_out = gguf_init_empty();
+
     std::vector<uint8_t> read_data;
     std::vector<ggml_context *> ctx_metas;
     std::vector<gguf_context *> ctx_ggufs;
@@ -552,9 +571,9 @@ int main(int argc, const char ** argv) {
     split_params_parse(argc, argv, params);

     switch (params.operation) {
-        case
+        case OP_SPLIT: gguf_split(params);
             break;
-        case
+        case OP_MERGE: gguf_merge(params);
             break;
         default: split_print_usage(argv[0]);
             exit(EXIT_FAILURE);

package/src/llama.cpp/examples/gritlm/gritlm.cpp

@@ -1,3 +1,4 @@
+#include "arg.h"
 #include "common.h"
 #include "llama.h"

@@ -9,7 +10,7 @@
 static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
     std::vector<std::vector<float>> result;

-    const llama_model *
+    const llama_model * model = llama_get_model(ctx);

     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

@@ -18,16 +19,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

         const std::string input_string = instruction + sentences[i];

-        std::vector<llama_token> inputs = llama_tokenize(
+        std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);

         const int32_t n_toks = inputs.size();

         // GritLM seems to have EOS = ""
         // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
-        // inputs.push_back(llama_token_eos(
+        // inputs.push_back(llama_token_eos(model));

         // we want to ignore instruction tokens for mean pooling
-        const int32_t n_inst = llama_tokenize(
+        const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();

 #ifdef GRIT_DEBUG
         // debug tokens - should be matching as referenced in the GritLM sample
@@ -51,7 +52,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         llama_decode(ctx, batch);

         // get embedding dimensions
-        uint64_t n_embd = llama_n_embd(
+        uint64_t n_embd = llama_n_embd(model);

         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -92,11 +93,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     return result;
 }

-static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
     std::string result;

-    const llama_model *
-    llama_token eos_token = llama_token_eos(
+    const llama_model * model = llama_get_model(ctx);
+    llama_token eos_token = llama_token_eos(model);

     llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, false);
@@ -104,28 +105,24 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo

     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

-    std::vector<llama_token> inputs = llama_tokenize(
+    std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
     int32_t i_current_token = 0;

     while (true) {
         llama_batch_clear(bat);
-
-
-
+        {
+            const int32_t n_inputs = inputs.size();
+
+            for (int32_t i = 0; i < n_inputs; i++) {
+                llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+            }
         }
         inputs.clear();

         llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

-
-        auto n_candidates = (int32_t)candidates.size();
-        for (int32_t token = 0; token < n_candidates; token++) {
-            candidates[token] = llama_token_data{ token, logits[token], 0.0f };
-        }
-        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
+        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);

-        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
         if (token == eos_token) {
             break;
         }
@@ -157,20 +154,29 @@ static std::string gritlm_instruction(const std::string & instruction) {
 int main(int argc, char * argv[]) {
     gpt_params params;

-    if (!gpt_params_parse(argc, argv, params)) {
-        gpt_params_print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
     }

+    gpt_init();
+
     llama_model_params mparams = llama_model_params_from_gpt_params(params);
     llama_context_params cparams = llama_context_params_from_gpt_params(params);

     llama_backend_init();

-    llama_model *
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);

     // create generation context
-    llama_context * ctx = llama_new_context_with_model(
+    llama_context * ctx = llama_new_context_with_model(model, cparams);
+
+    auto sparams = llama_sampler_chain_default_params();
+
+    sparams.no_perf = false;
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,7 +197,7 @@ int main(int argc, char * argv[]) {
     const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
     const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

-    const int n_embd = llama_n_embd(
+    const int n_embd = llama_n_embd(model);

     const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
     const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@@ -208,11 +214,12 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx, smpl, prompt, true);
     }

+    llama_sampler_free(smpl);
     llama_free(ctx);
-    llama_free_model(
+    llama_free_model(model);
     llama_backend_free();

     return 0;