@shipworthy/ai-sdk-llama-cpp 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -0
- package/LICENSE +21 -0
- package/README.md +274 -0
- package/dist/binding-bun.d.ts +7 -0
- package/dist/binding-bun.d.ts.map +1 -0
- package/dist/binding-bun.js +354 -0
- package/dist/binding-bun.js.map +1 -0
- package/dist/binding-node.d.ts +7 -0
- package/dist/binding-node.d.ts.map +1 -0
- package/dist/binding-node.js +59 -0
- package/dist/binding-node.js.map +1 -0
- package/dist/binding.d.ts +67 -0
- package/dist/binding.d.ts.map +1 -0
- package/dist/binding.js +105 -0
- package/dist/binding.js.map +1 -0
- package/dist/index.d.ts +5 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +8 -0
- package/dist/index.js.map +1 -0
- package/dist/llama-cpp-embedding-model.d.ts +28 -0
- package/dist/llama-cpp-embedding-model.d.ts.map +1 -0
- package/dist/llama-cpp-embedding-model.js +78 -0
- package/dist/llama-cpp-embedding-model.js.map +1 -0
- package/dist/llama-cpp-language-model.d.ts +55 -0
- package/dist/llama-cpp-language-model.d.ts.map +1 -0
- package/dist/llama-cpp-language-model.js +221 -0
- package/dist/llama-cpp-language-model.js.map +1 -0
- package/dist/llama-cpp-provider.d.ts +82 -0
- package/dist/llama-cpp-provider.d.ts.map +1 -0
- package/dist/llama-cpp-provider.js +71 -0
- package/dist/llama-cpp-provider.js.map +1 -0
- package/dist/native-binding.d.ts +51 -0
- package/dist/native-binding.d.ts.map +1 -0
- package/dist/native-binding.js +74 -0
- package/dist/native-binding.js.map +1 -0
- package/native/CMakeLists.txt +74 -0
- package/native/binding.cpp +522 -0
- package/native/llama-wrapper.cpp +519 -0
- package/native/llama-wrapper.h +131 -0
- package/package.json +79 -0
- package/scripts/postinstall.cjs +74 -0

package/native/llama-wrapper.cpp
@@ -0,0 +1,519 @@
#include "llama-wrapper.h"
#include "llama.h"
#include <algorithm>
#include <cstring>
#include <stdexcept>
#include <cstdio>
#include <cmath>

namespace llama_wrapper {

// Global debug flag for log callback
static bool g_debug_mode = false;

// Custom log callback that respects debug mode
static void llama_log_callback(ggml_log_level level, const char* text, void* user_data) {
    (void)level;
    (void)user_data;
    if (g_debug_mode) {
        fprintf(stderr, "%s", text);
    }
}

LlamaModel::LlamaModel() = default;

LlamaModel::~LlamaModel() {
    unload();
}

LlamaModel::LlamaModel(LlamaModel&& other) noexcept
    : model_(other.model_)
    , ctx_(other.ctx_)
    , sampler_(other.sampler_)
    , model_path_(std::move(other.model_path_))
    , chat_template_(std::move(other.chat_template_)) {
    other.model_ = nullptr;
    other.ctx_ = nullptr;
    other.sampler_ = nullptr;
}

LlamaModel& LlamaModel::operator=(LlamaModel&& other) noexcept {
    if (this != &other) {
        unload();
        model_ = other.model_;
        ctx_ = other.ctx_;
        sampler_ = other.sampler_;
        model_path_ = std::move(other.model_path_);
        chat_template_ = std::move(other.chat_template_);
        other.model_ = nullptr;
        other.ctx_ = nullptr;
        other.sampler_ = nullptr;
    }
    return *this;
}

bool LlamaModel::load(const ModelParams& params) {
    if (model_) {
        unload();
    }

    // Set debug mode and install log callback
    g_debug_mode = params.debug;
    llama_log_set(llama_log_callback, nullptr);

    // Initialize llama backend
    llama_backend_init();

    // Set up model parameters
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = params.n_gpu_layers;
    model_params.use_mmap = params.use_mmap;
    model_params.use_mlock = params.use_mlock;

    // Load the model
    model_ = llama_model_load_from_file(params.model_path.c_str(), model_params);
    if (!model_) {
        return false;
    }

    model_path_ = params.model_path;
    chat_template_ = params.chat_template;
    return true;
}

bool LlamaModel::is_loaded() const {
    return model_ != nullptr;
}

void LlamaModel::unload() {
    if (sampler_) {
        llama_sampler_free(sampler_);
        sampler_ = nullptr;
    }
    if (ctx_) {
        llama_free(ctx_);
        ctx_ = nullptr;
    }
    if (model_) {
        llama_model_free(model_);
        model_ = nullptr;
        llama_backend_free();
    }
    model_path_.clear();
}

bool LlamaModel::create_context(const ContextParams& params) {
    if (!model_) {
        return false;
    }

    if (ctx_) {
        llama_free(ctx_);
        ctx_ = nullptr;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = params.n_ctx;
    ctx_params.n_batch = params.n_batch;
    ctx_params.n_threads = params.n_threads;
    ctx_params.n_threads_batch = params.n_threads;

    if (params.embedding) {
        ctx_params.embeddings = true;
        ctx_params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
    }

    ctx_ = llama_init_from_model(model_, ctx_params);
    return ctx_ != nullptr;
}

void LlamaModel::normalize_embedding(float* embedding, int n_embd) {
    float sum = 0.0f;
    for (int i = 0; i < n_embd; i++) {
        sum += embedding[i] * embedding[i];
    }
    float norm = std::sqrt(sum);
    if (norm > 0.0f) {
        for (int i = 0; i < n_embd; i++) {
            embedding[i] /= norm;
        }
    }
}

EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
    EmbeddingResult result;
    result.total_tokens = 0;

    if (!ctx_ || !model_) {
        return result;
    }

    const int n_embd = llama_model_n_embd(model_);
    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx_);

    // Process each text
    for (size_t seq_id = 0; seq_id < texts.size(); seq_id++) {
        const std::string& text = texts[seq_id];

        // Tokenize the text
        std::vector<int32_t> tokens = tokenize(text, true);
        result.total_tokens += tokens.size();

        if (tokens.empty()) {
            // Return zero embedding for empty text
            result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
            continue;
        }

        // Clear the memory/KV cache
        llama_memory_t mem = llama_get_memory(ctx_);
        if (mem) {
            llama_memory_clear(mem, true);
        }

        // Create batch with sequence ID
        llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
        for (size_t i = 0; i < tokens.size(); i++) {
            batch.token[i] = tokens[i];
            batch.pos[i] = i;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = seq_id;
            batch.logits[i] = true; // We want embeddings for all tokens
        }
        batch.n_tokens = tokens.size();

        // Decode to get embeddings
        if (llama_decode(ctx_, batch) != 0) {
            llama_batch_free(batch);
            result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
            continue;
        }

        // Extract embedding based on pooling type
        const float* embd = nullptr;
        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            // Get embedding for last token
            embd = llama_get_embeddings_ith(ctx_, tokens.size() - 1);
        } else {
            // Get pooled embedding for the sequence
            embd = llama_get_embeddings_seq(ctx_, seq_id);
        }

        if (embd) {
            std::vector<float> embedding(n_embd);
            std::copy(embd, embd + n_embd, embedding.begin());
            // Normalize the embedding (L2 normalization)
            normalize_embedding(embedding.data(), n_embd);
            result.embeddings.push_back(std::move(embedding));
        } else {
            result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
        }

        llama_batch_free(batch);
    }

    return result;
}

std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& messages) {
    if (!model_) {
        return "";
    }

    // Determine which template to use
    const char* tmpl = nullptr;
    if (chat_template_ == "auto") {
        // Use the template embedded in the model
        tmpl = llama_model_chat_template(model_, nullptr);
    } else {
        // Use the specified template name
        tmpl = chat_template_.c_str();
    }

    // Convert messages to llama_chat_message format
    std::vector<llama_chat_message> chat_messages;
    chat_messages.reserve(messages.size());
    for (const auto& msg : messages) {
        llama_chat_message chat_msg;
        chat_msg.role = msg.role.c_str();
        chat_msg.content = msg.content.c_str();
        chat_messages.push_back(chat_msg);
    }

    // First call to get required buffer size
    int32_t result_size = llama_chat_apply_template(
        tmpl,
        chat_messages.data(),
        chat_messages.size(),
        true, // add_ass: add assistant prompt
        nullptr,
        0
    );

    if (result_size < 0) {
        // Template not supported, return empty string
        return "";
    }

    // Allocate buffer and apply template
    std::vector<char> buffer(result_size + 1);
    llama_chat_apply_template(
        tmpl,
        chat_messages.data(),
        chat_messages.size(),
        true,
        buffer.data(),
        buffer.size()
    );

    return std::string(buffer.data(), result_size);
}

void LlamaModel::create_sampler(const GenerationParams& params) {
    if (sampler_) {
        llama_sampler_free(sampler_);
    }

    // Create a sampler chain
    sampler_ = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // Add samplers to the chain
    llama_sampler_chain_add(sampler_, llama_sampler_init_top_k(params.top_k));
    llama_sampler_chain_add(sampler_, llama_sampler_init_top_p(params.top_p, 1));
    llama_sampler_chain_add(sampler_, llama_sampler_init_temp(params.temperature));
    llama_sampler_chain_add(sampler_, llama_sampler_init_dist(42)); // Random seed
}

std::vector<int32_t> LlamaModel::tokenize(const std::string& text, bool add_bos) {
    const llama_vocab* vocab = llama_model_get_vocab(model_);

    // First, get the number of tokens needed
    // When passing 0 for n_tokens_max, llama_tokenize returns negative of required size
    int n_tokens = llama_tokenize(vocab, text.c_str(), text.length(), nullptr, 0, add_bos, true);

    if (n_tokens < 0) {
        n_tokens = -n_tokens; // Convert to positive size
    }

    if (n_tokens == 0) {
        return {}; // Empty input
    }

    std::vector<int32_t> tokens(n_tokens);
    int actual_tokens = llama_tokenize(vocab, text.c_str(), text.length(), tokens.data(), tokens.size(), add_bos, true);

    if (actual_tokens < 0) {
        // Buffer still too small, resize and try again
        tokens.resize(-actual_tokens);
        actual_tokens = llama_tokenize(vocab, text.c_str(), text.length(), tokens.data(), tokens.size(), add_bos, true);
    }

    if (actual_tokens > 0) {
        tokens.resize(actual_tokens);
    } else {
        tokens.clear();
    }

    return tokens;
}

std::string LlamaModel::detokenize(int32_t token) {
    const llama_vocab* vocab = llama_model_get_vocab(model_);

    char buf[256];
    int n = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true);
    if (n < 0) {
        return "";
    }
    return std::string(buf, n);
}

bool LlamaModel::is_eos_token(int32_t token) {
    const llama_vocab* vocab = llama_model_get_vocab(model_);
    return llama_vocab_is_eog(vocab, token);
}

GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages, const GenerationParams& params) {
    GenerationResult result;
    result.finish_reason = "error";

    if (!ctx_ || !model_) {
        return result;
    }

    // Apply chat template to get the prompt
    std::string prompt = apply_chat_template(messages);
    if (prompt.empty()) {
        return result;
    }

    // Tokenize the prompt
    std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
    result.prompt_tokens = prompt_tokens.size();
    result.completion_tokens = 0;

    // Clear the memory/KV cache
    llama_memory_t mem = llama_get_memory(ctx_);
    if (mem) {
        llama_memory_clear(mem, true);
    }

    // Create sampler
    create_sampler(params);

    // Create batch for prompt processing
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

    if (llama_decode(ctx_, batch) != 0) {
        return result;
    }

    // Generate tokens
    std::string generated_text;
    int n_cur = prompt_tokens.size();

    for (int i = 0; i < params.max_tokens; i++) {
        // Sample the next token
        int32_t new_token = llama_sampler_sample(sampler_, ctx_, -1);

        // Check for end of sequence
        if (is_eos_token(new_token)) {
            result.finish_reason = "stop";
            break;
        }

        // Convert token to string
        std::string token_str = detokenize(new_token);
        generated_text += token_str;
        result.completion_tokens++;

        // Check for stop sequences
        bool should_stop = false;
        for (const auto& stop_seq : params.stop_sequences) {
            if (generated_text.length() >= stop_seq.length()) {
                if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
                    // Remove the stop sequence from output
                    generated_text = generated_text.substr(0, generated_text.length() - stop_seq.length());
                    should_stop = true;
                    result.finish_reason = "stop";
                    break;
                }
            }
        }
        if (should_stop) break;

        // Prepare for next iteration
        batch = llama_batch_get_one(&new_token, 1);
        if (llama_decode(ctx_, batch) != 0) {
            break;
        }
        n_cur++;
    }

    if (result.finish_reason == "error" && result.completion_tokens >= params.max_tokens) {
        result.finish_reason = "length";
    } else if (result.finish_reason == "error") {
        result.finish_reason = "stop";
    }

    result.text = generated_text;
    return result;
}

GenerationResult LlamaModel::generate_streaming(
    const std::vector<ChatMessage>& messages,
    const GenerationParams& params,
    TokenCallback callback
) {
    GenerationResult result;
    result.finish_reason = "error";

    if (!ctx_ || !model_) {
        return result;
    }

    // Apply chat template to get the prompt
    std::string prompt = apply_chat_template(messages);
    if (prompt.empty()) {
        return result;
    }

    // Tokenize the prompt
    std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
    result.prompt_tokens = prompt_tokens.size();
    result.completion_tokens = 0;

    // Clear the memory/KV cache
    llama_memory_t mem = llama_get_memory(ctx_);
    if (mem) {
        llama_memory_clear(mem, true);
    }

    // Create sampler
    create_sampler(params);

    // Create batch for prompt processing
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

    if (llama_decode(ctx_, batch) != 0) {
        return result;
    }

    // Generate tokens
    std::string generated_text;
    int n_cur = prompt_tokens.size();

    for (int i = 0; i < params.max_tokens; i++) {
        // Sample the next token
        int32_t new_token = llama_sampler_sample(sampler_, ctx_, -1);

        // Check for end of sequence
        if (is_eos_token(new_token)) {
            result.finish_reason = "stop";
            break;
        }

        // Convert token to string
        std::string token_str = detokenize(new_token);
        generated_text += token_str;
        result.completion_tokens++;

        // Call the callback with the new token
        if (!callback(token_str)) {
            result.finish_reason = "stop";
            break;
        }

        // Check for stop sequences
        bool should_stop = false;
        for (const auto& stop_seq : params.stop_sequences) {
            if (generated_text.length() >= stop_seq.length()) {
                if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
                    should_stop = true;
                    result.finish_reason = "stop";
                    break;
                }
            }
        }
        if (should_stop) break;

        // Prepare for next iteration
        batch = llama_batch_get_one(&new_token, 1);
        if (llama_decode(ctx_, batch) != 0) {
            break;
        }
        n_cur++;
    }

    if (result.finish_reason == "error" && result.completion_tokens >= params.max_tokens) {
        result.finish_reason = "length";
    } else if (result.finish_reason == "error") {
        result.finish_reason = "stop";
    }

    result.text = generated_text;
    return result;
}

} // namespace llama_wrapper
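
For illustration only (an editor's sketch, not a file shipped in this package): in the package itself this implementation is driven through the N-API layer in package/native/binding.cpp, but the same flow can be exercised directly in C++ against the `LlamaModel` API declared in llama-wrapper.h — load a GGUF model, create a context, then stream tokens through a callback. The model path, messages, and sampling settings below are placeholders.

```cpp
// Hypothetical usage sketch of the wrapper above; not part of the published package.
#include "llama-wrapper.h"
#include <cstdio>
#include <string>
#include <vector>

int main() {
    using namespace llama_wrapper;

    ModelParams mparams;
    mparams.model_path = "model.gguf"; // placeholder: path to a local GGUF model

    LlamaModel model;
    if (!model.load(mparams) || !model.create_context(ContextParams{})) {
        fprintf(stderr, "failed to load model or create context\n");
        return 1;
    }

    GenerationParams gparams;
    gparams.max_tokens = 64;
    gparams.stop_sequences = {"\n\n"};

    std::vector<ChatMessage> messages = {
        {"system", "You are a helpful assistant."},
        {"user", "Say hello in one sentence."},
    };

    // The callback receives each decoded token piece; returning false aborts generation.
    GenerationResult res = model.generate_streaming(messages, gparams,
        [](const std::string& token) {
            fputs(token.c_str(), stdout);
            fflush(stdout);
            return true;
        });

    printf("\n[finish_reason=%s prompt=%d completion=%d]\n",
           res.finish_reason.c_str(), res.prompt_tokens, res.completion_tokens);
    return 0;
}
```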

package/native/llama-wrapper.h
@@ -0,0 +1,131 @@
#ifndef LLAMA_WRAPPER_H
#define LLAMA_WRAPPER_H

#include <string>
#include <vector>
#include <functional>
#include <memory>

// Forward declarations for llama.cpp types
struct llama_model;
struct llama_context;
struct llama_sampler;

namespace llama_wrapper {

struct ModelParams {
    std::string model_path;
    int n_gpu_layers = 99;   // Use GPU by default if available
    bool use_mmap = true;
    bool use_mlock = false;
    bool debug = false;      // Show verbose llama.cpp output
    std::string chat_template = "auto"; // "auto" uses template from model, or specify a built-in template
};

struct ChatMessage {
    std::string role;
    std::string content;
};

struct ContextParams {
    int n_ctx = 2048;       // Context size
    int n_batch = 512;      // Batch size for prompt processing
    int n_threads = 4;      // Number of threads
    bool embedding = false; // Enable embedding mode with mean pooling
};

struct GenerationParams {
    int max_tokens = 256;
    float temperature = 0.7f;
    float top_p = 0.9f;
    int top_k = 40;
    float repeat_penalty = 1.1f;
    std::vector<std::string> stop_sequences;
};

struct GenerationResult {
    std::string text;
    int prompt_tokens;
    int completion_tokens;
    std::string finish_reason; // "stop", "length", or "error"
};

struct EmbeddingResult {
    std::vector<std::vector<float>> embeddings; // One embedding vector per input text
    int total_tokens;
};

// Token callback for streaming: returns false to stop generation
using TokenCallback = std::function<bool(const std::string& token)>;

class LlamaModel {
public:
    LlamaModel();
    ~LlamaModel();

    // Disable copy
    LlamaModel(const LlamaModel&) = delete;
    LlamaModel& operator=(const LlamaModel&) = delete;

    // Enable move
    LlamaModel(LlamaModel&& other) noexcept;
    LlamaModel& operator=(LlamaModel&& other) noexcept;

    // Load a model from a GGUF file
    bool load(const ModelParams& params);

    // Check if model is loaded
    bool is_loaded() const;

    // Unload the model
    void unload();

    // Get the model path
    const std::string& get_model_path() const { return model_path_; }

    // Create a context for inference (or embeddings if params.embedding is true)
    bool create_context(const ContextParams& params);

    // Apply chat template to messages and return formatted prompt
    std::string apply_chat_template(const std::vector<ChatMessage>& messages);

    // Generate text from messages (non-streaming)
    GenerationResult generate(const std::vector<ChatMessage>& messages, const GenerationParams& params);

    // Generate text from messages (streaming)
    GenerationResult generate_streaming(
        const std::vector<ChatMessage>& messages,
        const GenerationParams& params,
        TokenCallback callback
    );

    // Generate embeddings for multiple texts
    EmbeddingResult embed(const std::vector<std::string>& texts);

private:
    llama_model* model_ = nullptr;
    llama_context* ctx_ = nullptr;
    llama_sampler* sampler_ = nullptr;
    std::string model_path_;
    std::string chat_template_;

    // Tokenize a string
    std::vector<int32_t> tokenize(const std::string& text, bool add_bos);

    // Normalize an embedding vector (L2 normalization)
    static void normalize_embedding(float* embedding, int n_embd);

    // Detokenize a single token
    std::string detokenize(int32_t token);

    // Create sampler with given params
    void create_sampler(const GenerationParams& params);

    // Check if token is end-of-sequence
    bool is_eos_token(int32_t token);
};

} // namespace llama_wrapper

#endif // LLAMA_WRAPPER_H
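
A second illustrative sketch (again an editor's addition, not part of the package): the embedding path declared above needs a context created with `embedding = true`, which `create_context()` maps to mean pooling, after which `embed()` returns one L2-normalized vector per input string. The model path below is a placeholder.

```cpp
// Hypothetical embedding usage sketch of the wrapper API; not part of the published package.
#include "llama-wrapper.h"
#include <cstdio>

int main() {
    using namespace llama_wrapper;

    ModelParams mparams;
    mparams.model_path = "embedding-model.gguf"; // placeholder: path to a GGUF embedding model

    LlamaModel model;
    if (!model.load(mparams)) {
        return 1;
    }

    ContextParams cparams;
    cparams.embedding = true; // enables ctx_params.embeddings + mean pooling in create_context()
    if (!model.create_context(cparams)) {
        return 1;
    }

    EmbeddingResult res = model.embed({"hello world", "good morning"});
    for (size_t i = 0; i < res.embeddings.size(); i++) {
        printf("text %zu -> %zu dimensions\n", i, res.embeddings[i].size());
    }
    printf("total tokens: %d\n", res.total_tokens);
    return 0;
}
```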