@shipworthy/ai-sdk-llama-cpp 0.2.3 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/llama-cpp-embedding-model.d.ts +7 -0
- package/dist/llama-cpp-embedding-model.d.ts.map +1 -1
- package/dist/llama-cpp-embedding-model.js +12 -2
- package/dist/llama-cpp-embedding-model.js.map +1 -1
- package/dist/llama-cpp-language-model.d.ts +7 -0
- package/dist/llama-cpp-language-model.d.ts.map +1 -1
- package/dist/llama-cpp-language-model.js +12 -2
- package/dist/llama-cpp-language-model.js.map +1 -1
- package/dist/native-binding.d.ts +5 -0
- package/dist/native-binding.d.ts.map +1 -1
- package/dist/native-binding.js +3 -0
- package/dist/native-binding.js.map +1 -1
- package/native/binding.cpp +187 -183
- package/native/llama-wrapper.cpp +185 -124
- package/native/llama-wrapper.h +48 -48
- package/package.json +1 -1
package/native/llama-wrapper.cpp
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
#include "llama-wrapper.h"
|
|
2
|
+
|
|
2
3
|
#include "llama.h"
|
|
4
|
+
|
|
3
5
|
#include <algorithm>
|
|
6
|
+
#include <cmath>
|
|
7
|
+
#include <cstdio>
|
|
4
8
|
#include <cstring>
|
|
5
9
|
#include <stdexcept>
|
|
6
|
-
#include <cstdio>
|
|
7
|
-
#include <cmath>
|
|
8
10
|
|
|
9
11
|
namespace llama_wrapper {
|
|
10
12
|
|
|
@@ -12,9 +14,9 @@ namespace llama_wrapper {
|
|
|
12
14
|
static bool g_debug_mode = false;
|
|
13
15
|
|
|
14
16
|
// Custom log callback that respects debug mode
|
|
15
|
-
static void llama_log_callback(ggml_log_level level, const char* text, void* user_data) {
|
|
16
|
-
(void)level;
|
|
17
|
-
(void)user_data;
|
|
17
|
+
static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
|
|
18
|
+
(void) level;
|
|
19
|
+
(void) user_data;
|
|
18
20
|
if (g_debug_mode) {
|
|
19
21
|
fprintf(stderr, "%s", text);
|
|
20
22
|
}
|
|
@@ -26,33 +28,33 @@ LlamaModel::~LlamaModel() {
|
|
|
26
28
|
unload();
|
|
27
29
|
}
|
|
28
30
|
|
|
29
|
-
LlamaModel::LlamaModel(LlamaModel&& other) noexcept
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
other.model_
|
|
36
|
-
other.ctx_
|
|
31
|
+
LlamaModel::LlamaModel(LlamaModel && other) noexcept :
|
|
32
|
+
model_(other.model_),
|
|
33
|
+
ctx_(other.ctx_),
|
|
34
|
+
sampler_(other.sampler_),
|
|
35
|
+
model_path_(std::move(other.model_path_)),
|
|
36
|
+
chat_template_(std::move(other.chat_template_)) {
|
|
37
|
+
other.model_ = nullptr;
|
|
38
|
+
other.ctx_ = nullptr;
|
|
37
39
|
other.sampler_ = nullptr;
|
|
38
40
|
}
|
|
39
41
|
|
|
40
|
-
LlamaModel& LlamaModel::operator=(LlamaModel&& other) noexcept {
|
|
42
|
+
LlamaModel & LlamaModel::operator=(LlamaModel && other) noexcept {
|
|
41
43
|
if (this != &other) {
|
|
42
44
|
unload();
|
|
43
|
-
model_
|
|
44
|
-
ctx_
|
|
45
|
-
sampler_
|
|
46
|
-
model_path_
|
|
45
|
+
model_ = other.model_;
|
|
46
|
+
ctx_ = other.ctx_;
|
|
47
|
+
sampler_ = other.sampler_;
|
|
48
|
+
model_path_ = std::move(other.model_path_);
|
|
47
49
|
chat_template_ = std::move(other.chat_template_);
|
|
48
|
-
other.model_
|
|
49
|
-
other.ctx_
|
|
50
|
+
other.model_ = nullptr;
|
|
51
|
+
other.ctx_ = nullptr;
|
|
50
52
|
other.sampler_ = nullptr;
|
|
51
53
|
}
|
|
52
54
|
return *this;
|
|
53
55
|
}
|
|
54
56
|
|
|
55
|
-
bool LlamaModel::load(const ModelParams& params) {
|
|
57
|
+
bool LlamaModel::load(const ModelParams & params) {
|
|
56
58
|
if (model_) {
|
|
57
59
|
unload();
|
|
58
60
|
}
|
|
@@ -66,9 +68,9 @@ bool LlamaModel::load(const ModelParams& params) {
|
|
|
66
68
|
|
|
67
69
|
// Set up model parameters
|
|
68
70
|
llama_model_params model_params = llama_model_default_params();
|
|
69
|
-
model_params.n_gpu_layers
|
|
70
|
-
model_params.use_mmap
|
|
71
|
-
model_params.use_mlock
|
|
71
|
+
model_params.n_gpu_layers = params.n_gpu_layers;
|
|
72
|
+
model_params.use_mmap = params.use_mmap;
|
|
73
|
+
model_params.use_mlock = params.use_mlock;
|
|
72
74
|
|
|
73
75
|
// Load the model
|
|
74
76
|
model_ = llama_model_load_from_file(params.model_path.c_str(), model_params);
|
|
@@ -76,7 +78,7 @@ bool LlamaModel::load(const ModelParams& params) {
|
|
|
76
78
|
return false;
|
|
77
79
|
}
|
|
78
80
|
|
|
79
|
-
model_path_
|
|
81
|
+
model_path_ = params.model_path;
|
|
80
82
|
chat_template_ = params.chat_template;
|
|
81
83
|
return true;
|
|
82
84
|
}
|
|
@@ -102,7 +104,7 @@ void LlamaModel::unload() {
|
|
|
102
104
|
model_path_.clear();
|
|
103
105
|
}
|
|
104
106
|
|
|
105
|
-
bool LlamaModel::create_context(const ContextParams& params) {
|
|
107
|
+
bool LlamaModel::create_context(const ContextParams & params) {
|
|
106
108
|
if (!model_) {
|
|
107
109
|
return false;
|
|
108
110
|
}
|
|
@@ -113,21 +115,33 @@ bool LlamaModel::create_context(const ContextParams& params) {
|
|
|
113
115
|
}
|
|
114
116
|
|
|
115
117
|
llama_context_params ctx_params = llama_context_default_params();
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
118
|
+
|
|
119
|
+
// Only override defaults if non-zero values are provided
|
|
120
|
+
if (params.n_ctx > 0) {
|
|
121
|
+
ctx_params.n_ctx = params.n_ctx;
|
|
122
|
+
}
|
|
123
|
+
if (params.n_batch > 0) {
|
|
124
|
+
ctx_params.n_batch = params.n_batch;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
ctx_params.n_threads = params.n_threads;
|
|
119
128
|
ctx_params.n_threads_batch = params.n_threads;
|
|
120
|
-
|
|
129
|
+
|
|
121
130
|
if (params.embedding) {
|
|
122
|
-
ctx_params.embeddings
|
|
131
|
+
ctx_params.embeddings = true;
|
|
123
132
|
ctx_params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
|
|
133
|
+
// For embeddings, batch size must be at least as large as context size
|
|
134
|
+
// (see llama.cpp/examples/embedding/embedding.cpp)
|
|
135
|
+
if (ctx_params.n_batch < ctx_params.n_ctx) {
|
|
136
|
+
ctx_params.n_batch = ctx_params.n_ctx;
|
|
137
|
+
}
|
|
124
138
|
}
|
|
125
139
|
|
|
126
140
|
ctx_ = llama_init_from_model(model_, ctx_params);
|
|
127
141
|
return ctx_ != nullptr;
|
|
128
142
|
}
|
|
129
143
|
|
|
130
|
-
void LlamaModel::normalize_embedding(float* embedding, int n_embd) {
|
|
144
|
+
void LlamaModel::normalize_embedding(float * embedding, int n_embd) {
|
|
131
145
|
float sum = 0.0f;
|
|
132
146
|
for (int i = 0; i < n_embd; i++) {
|
|
133
147
|
sum += embedding[i] * embedding[i];
|
|
@@ -140,7 +154,58 @@ void LlamaModel::normalize_embedding(float* embedding, int n_embd) {
|
|
|
140
154
|
}
|
|
141
155
|
}
|
|
142
156
|
|
|
143
|
-
|
|
157
|
+
std::vector<float> LlamaModel::embed_chunk(const std::vector<int32_t> & tokens,
|
|
158
|
+
int seq_id,
|
|
159
|
+
int n_embd,
|
|
160
|
+
int pooling_type) {
|
|
161
|
+
std::vector<float> embedding(n_embd, 0.0f);
|
|
162
|
+
|
|
163
|
+
if (tokens.empty()) {
|
|
164
|
+
return embedding;
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
// Clear the memory/KV cache
|
|
168
|
+
llama_memory_t mem = llama_get_memory(ctx_);
|
|
169
|
+
if (mem) {
|
|
170
|
+
llama_memory_clear(mem, true);
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
// Create batch with sequence ID
|
|
174
|
+
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
|
175
|
+
for (size_t i = 0; i < tokens.size(); i++) {
|
|
176
|
+
batch.token[i] = tokens[i];
|
|
177
|
+
batch.pos[i] = i;
|
|
178
|
+
batch.n_seq_id[i] = 1;
|
|
179
|
+
batch.seq_id[i][0] = seq_id;
|
|
180
|
+
batch.logits[i] = true; // We want embeddings for all tokens
|
|
181
|
+
}
|
|
182
|
+
batch.n_tokens = tokens.size();
|
|
183
|
+
|
|
184
|
+
// Decode to get embeddings
|
|
185
|
+
if (llama_decode(ctx_, batch) != 0) {
|
|
186
|
+
llama_batch_free(batch);
|
|
187
|
+
return embedding; // Return zero embedding on failure
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
// Extract embedding based on pooling type
|
|
191
|
+
const float * embd = nullptr;
|
|
192
|
+
if (static_cast<enum llama_pooling_type>(pooling_type) == LLAMA_POOLING_TYPE_NONE) {
|
|
193
|
+
// Get embedding for last token
|
|
194
|
+
embd = llama_get_embeddings_ith(ctx_, tokens.size() - 1);
|
|
195
|
+
} else {
|
|
196
|
+
// Get pooled embedding for the sequence
|
|
197
|
+
embd = llama_get_embeddings_seq(ctx_, seq_id);
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
if (embd) {
|
|
201
|
+
std::copy(embd, embd + n_embd, embedding.begin());
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
llama_batch_free(batch);
|
|
205
|
+
return embedding;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
EmbeddingResult LlamaModel::embed(const std::vector<std::string> & texts) {
|
|
144
209
|
EmbeddingResult result;
|
|
145
210
|
result.total_tokens = 0;
|
|
146
211
|
|
|
@@ -148,12 +213,17 @@ EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
|
|
|
148
213
|
return result;
|
|
149
214
|
}
|
|
150
215
|
|
|
151
|
-
const int
|
|
216
|
+
const int n_embd = llama_model_n_embd(model_);
|
|
152
217
|
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx_);
|
|
153
218
|
|
|
219
|
+
// Get context size for chunking
|
|
220
|
+
const int n_ctx = llama_n_ctx(ctx_);
|
|
221
|
+
const int overlap = n_ctx / 10; // 10% overlap between chunks
|
|
222
|
+
const int step = n_ctx - overlap;
|
|
223
|
+
|
|
154
224
|
// Process each text
|
|
155
225
|
for (size_t seq_id = 0; seq_id < texts.size(); seq_id++) {
|
|
156
|
-
const std::string& text = texts[seq_id];
|
|
226
|
+
const std::string & text = texts[seq_id];
|
|
157
227
|
|
|
158
228
|
// Tokenize the text
|
|
159
229
|
std::vector<int32_t> tokens = tokenize(text, true);
|
|
@@ -165,63 +235,64 @@ EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
|
|
|
165
235
|
continue;
|
|
166
236
|
}
|
|
167
237
|
|
|
168
|
-
//
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
238
|
+
// Check if text fits in context (no chunking needed)
|
|
239
|
+
if (static_cast<int>(tokens.size()) <= n_ctx) {
|
|
240
|
+
// Process single chunk
|
|
241
|
+
std::vector<float> embedding = embed_chunk(tokens, seq_id, n_embd, pooling_type);
|
|
242
|
+
normalize_embedding(embedding.data(), n_embd);
|
|
243
|
+
result.embeddings.push_back(std::move(embedding));
|
|
244
|
+
} else {
|
|
245
|
+
// Text exceeds context size - split into overlapping chunks
|
|
246
|
+
std::vector<std::vector<float>> chunk_embeddings;
|
|
173
247
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
batch.token[i] = tokens[i];
|
|
178
|
-
batch.pos[i] = i;
|
|
179
|
-
batch.n_seq_id[i] = 1;
|
|
180
|
-
batch.seq_id[i][0] = seq_id;
|
|
181
|
-
batch.logits[i] = true; // We want embeddings for all tokens
|
|
182
|
-
}
|
|
183
|
-
batch.n_tokens = tokens.size();
|
|
248
|
+
for (size_t start = 0; start < tokens.size(); start += step) {
|
|
249
|
+
// Calculate chunk end position
|
|
250
|
+
size_t end = std::min(start + n_ctx, tokens.size());
|
|
184
251
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
llama_batch_free(batch);
|
|
188
|
-
result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
|
|
189
|
-
continue;
|
|
190
|
-
}
|
|
252
|
+
// Extract chunk tokens
|
|
253
|
+
std::vector<int32_t> chunk_tokens(tokens.begin() + start, tokens.begin() + end);
|
|
191
254
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
// Get embedding for last token
|
|
196
|
-
embd = llama_get_embeddings_ith(ctx_, tokens.size() - 1);
|
|
197
|
-
} else {
|
|
198
|
-
// Get pooled embedding for the sequence
|
|
199
|
-
embd = llama_get_embeddings_seq(ctx_, seq_id);
|
|
200
|
-
}
|
|
255
|
+
// Get embedding for this chunk
|
|
256
|
+
std::vector<float> chunk_emb = embed_chunk(chunk_tokens, seq_id, n_embd, pooling_type);
|
|
257
|
+
chunk_embeddings.push_back(std::move(chunk_emb));
|
|
201
258
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
result.embeddings.push_back(std::move(embedding));
|
|
208
|
-
} else {
|
|
209
|
-
result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
|
|
210
|
-
}
|
|
259
|
+
// If this chunk reached the end, we're done
|
|
260
|
+
if (end == tokens.size()) {
|
|
261
|
+
break;
|
|
262
|
+
}
|
|
263
|
+
}
|
|
211
264
|
|
|
212
|
-
|
|
265
|
+
// Mean-pool all chunk embeddings
|
|
266
|
+
std::vector<float> final_embedding(n_embd, 0.0f);
|
|
267
|
+
if (!chunk_embeddings.empty()) {
|
|
268
|
+
for (const auto & chunk_emb : chunk_embeddings) {
|
|
269
|
+
for (int i = 0; i < n_embd; i++) {
|
|
270
|
+
final_embedding[i] += chunk_emb[i];
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
// Divide by number of chunks to get mean
|
|
274
|
+
float num_chunks = static_cast<float>(chunk_embeddings.size());
|
|
275
|
+
for (int i = 0; i < n_embd; i++) {
|
|
276
|
+
final_embedding[i] /= num_chunks;
|
|
277
|
+
}
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
// Normalize the final averaged embedding
|
|
281
|
+
normalize_embedding(final_embedding.data(), n_embd);
|
|
282
|
+
result.embeddings.push_back(std::move(final_embedding));
|
|
283
|
+
}
|
|
213
284
|
}
|
|
214
285
|
|
|
215
286
|
return result;
|
|
216
287
|
}
|
|
217
288
|
|
|
218
|
-
std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage
|
|
289
|
+
std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage> & messages) {
|
|
219
290
|
if (!model_) {
|
|
220
291
|
return "";
|
|
221
292
|
}
|
|
222
293
|
|
|
223
294
|
// Determine which template to use
|
|
224
|
-
const char* tmpl = nullptr;
|
|
295
|
+
const char * tmpl = nullptr;
|
|
225
296
|
if (chat_template_ == "auto") {
|
|
226
297
|
// Use the template embedded in the model
|
|
227
298
|
tmpl = llama_model_chat_template(model_, nullptr);
|
|
@@ -233,22 +304,17 @@ std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& mess
|
|
|
233
304
|
// Convert messages to llama_chat_message format
|
|
234
305
|
std::vector<llama_chat_message> chat_messages;
|
|
235
306
|
chat_messages.reserve(messages.size());
|
|
236
|
-
for (const auto& msg : messages) {
|
|
307
|
+
for (const auto & msg : messages) {
|
|
237
308
|
llama_chat_message chat_msg;
|
|
238
|
-
chat_msg.role
|
|
309
|
+
chat_msg.role = msg.role.c_str();
|
|
239
310
|
chat_msg.content = msg.content.c_str();
|
|
240
311
|
chat_messages.push_back(chat_msg);
|
|
241
312
|
}
|
|
242
313
|
|
|
243
314
|
// First call to get required buffer size
|
|
244
|
-
int32_t result_size = llama_chat_apply_template(
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
chat_messages.size(),
|
|
248
|
-
true, // add_ass: add assistant prompt
|
|
249
|
-
nullptr,
|
|
250
|
-
0
|
|
251
|
-
);
|
|
315
|
+
int32_t result_size = llama_chat_apply_template(tmpl, chat_messages.data(), chat_messages.size(),
|
|
316
|
+
true, // add_ass: add assistant prompt
|
|
317
|
+
nullptr, 0);
|
|
252
318
|
|
|
253
319
|
if (result_size < 0) {
|
|
254
320
|
// Template not supported, return empty string
|
|
@@ -257,19 +323,12 @@ std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& mess
|
|
|
257
323
|
|
|
258
324
|
// Allocate buffer and apply template
|
|
259
325
|
std::vector<char> buffer(result_size + 1);
|
|
260
|
-
llama_chat_apply_template(
|
|
261
|
-
tmpl,
|
|
262
|
-
chat_messages.data(),
|
|
263
|
-
chat_messages.size(),
|
|
264
|
-
true,
|
|
265
|
-
buffer.data(),
|
|
266
|
-
buffer.size()
|
|
267
|
-
);
|
|
326
|
+
llama_chat_apply_template(tmpl, chat_messages.data(), chat_messages.size(), true, buffer.data(), buffer.size());
|
|
268
327
|
|
|
269
328
|
return std::string(buffer.data(), result_size);
|
|
270
329
|
}
|
|
271
330
|
|
|
272
|
-
void LlamaModel::create_sampler(const GenerationParams& params) {
|
|
331
|
+
void LlamaModel::create_sampler(const GenerationParams & params) {
|
|
273
332
|
if (sampler_) {
|
|
274
333
|
llama_sampler_free(sampler_);
|
|
275
334
|
}
|
|
@@ -281,14 +340,15 @@ void LlamaModel::create_sampler(const GenerationParams& params) {
|
|
|
281
340
|
llama_sampler_chain_add(sampler_, llama_sampler_init_top_k(params.top_k));
|
|
282
341
|
llama_sampler_chain_add(sampler_, llama_sampler_init_top_p(params.top_p, 1));
|
|
283
342
|
llama_sampler_chain_add(sampler_, llama_sampler_init_temp(params.temperature));
|
|
284
|
-
llama_sampler_chain_add(sampler_, llama_sampler_init_dist(42));
|
|
343
|
+
llama_sampler_chain_add(sampler_, llama_sampler_init_dist(42)); // Random seed
|
|
285
344
|
}
|
|
286
345
|
|
|
287
|
-
std::vector<int32_t> LlamaModel::tokenize(const std::string& text, bool add_bos) {
|
|
288
|
-
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
|
346
|
+
std::vector<int32_t> LlamaModel::tokenize(const std::string & text, bool add_bos) {
|
|
347
|
+
const llama_vocab * vocab = llama_model_get_vocab(model_);
|
|
289
348
|
|
|
290
349
|
// First, get the number of tokens needed
|
|
291
|
-
// When passing 0 for n_tokens_max, llama_tokenize returns negative of
|
|
350
|
+
// When passing 0 for n_tokens_max, llama_tokenize returns negative of
|
|
351
|
+
// required size
|
|
292
352
|
int n_tokens = llama_tokenize(vocab, text.c_str(), text.length(), nullptr, 0, add_bos, true);
|
|
293
353
|
|
|
294
354
|
if (n_tokens < 0) {
|
|
@@ -318,10 +378,10 @@ std::vector<int32_t> LlamaModel::tokenize(const std::string& text, bool add_bos)
|
|
|
318
378
|
}
|
|
319
379
|
|
|
320
380
|
std::string LlamaModel::detokenize(int32_t token) {
|
|
321
|
-
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
|
381
|
+
const llama_vocab * vocab = llama_model_get_vocab(model_);
|
|
322
382
|
|
|
323
383
|
char buf[256];
|
|
324
|
-
int
|
|
384
|
+
int n = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true);
|
|
325
385
|
if (n < 0) {
|
|
326
386
|
return "";
|
|
327
387
|
}
|
|
@@ -329,11 +389,11 @@ std::string LlamaModel::detokenize(int32_t token) {
|
|
|
329
389
|
}
|
|
330
390
|
|
|
331
391
|
bool LlamaModel::is_eos_token(int32_t token) {
|
|
332
|
-
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
|
392
|
+
const llama_vocab * vocab = llama_model_get_vocab(model_);
|
|
333
393
|
return llama_vocab_is_eog(vocab, token);
|
|
334
394
|
}
|
|
335
395
|
|
|
336
|
-
GenerationResult LlamaModel::generate(const std::vector<ChatMessage
|
|
396
|
+
GenerationResult LlamaModel::generate(const std::vector<ChatMessage> & messages, const GenerationParams & params) {
|
|
337
397
|
GenerationResult result;
|
|
338
398
|
result.finish_reason = "error";
|
|
339
399
|
|
|
@@ -349,8 +409,8 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
|
|
|
349
409
|
|
|
350
410
|
// Tokenize the prompt
|
|
351
411
|
std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
|
|
352
|
-
result.prompt_tokens
|
|
353
|
-
result.completion_tokens
|
|
412
|
+
result.prompt_tokens = prompt_tokens.size();
|
|
413
|
+
result.completion_tokens = 0;
|
|
354
414
|
|
|
355
415
|
// Clear the memory/KV cache
|
|
356
416
|
llama_memory_t mem = llama_get_memory(ctx_);
|
|
@@ -370,7 +430,7 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
|
|
|
370
430
|
|
|
371
431
|
// Generate tokens
|
|
372
432
|
std::string generated_text;
|
|
373
|
-
int
|
|
433
|
+
int n_cur = prompt_tokens.size();
|
|
374
434
|
|
|
375
435
|
for (int i = 0; i < params.max_tokens; i++) {
|
|
376
436
|
// Sample the next token
|
|
@@ -389,18 +449,20 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
|
|
|
389
449
|
|
|
390
450
|
// Check for stop sequences
|
|
391
451
|
bool should_stop = false;
|
|
392
|
-
for (const auto& stop_seq : params.stop_sequences) {
|
|
452
|
+
for (const auto & stop_seq : params.stop_sequences) {
|
|
393
453
|
if (generated_text.length() >= stop_seq.length()) {
|
|
394
454
|
if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
|
|
395
455
|
// Remove the stop sequence from output
|
|
396
|
-
generated_text
|
|
397
|
-
should_stop
|
|
456
|
+
generated_text = generated_text.substr(0, generated_text.length() - stop_seq.length());
|
|
457
|
+
should_stop = true;
|
|
398
458
|
result.finish_reason = "stop";
|
|
399
459
|
break;
|
|
400
460
|
}
|
|
401
461
|
}
|
|
402
462
|
}
|
|
403
|
-
if (should_stop)
|
|
463
|
+
if (should_stop) {
|
|
464
|
+
break;
|
|
465
|
+
}
|
|
404
466
|
|
|
405
467
|
// Prepare for next iteration
|
|
406
468
|
batch = llama_batch_get_one(&new_token, 1);
|
|
@@ -420,11 +482,9 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
|
|
|
420
482
|
return result;
|
|
421
483
|
}
|
|
422
484
|
|
|
423
|
-
GenerationResult LlamaModel::generate_streaming(
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
TokenCallback callback
|
|
427
|
-
) {
|
|
485
|
+
GenerationResult LlamaModel::generate_streaming(const std::vector<ChatMessage> & messages,
|
|
486
|
+
const GenerationParams & params,
|
|
487
|
+
TokenCallback callback) {
|
|
428
488
|
GenerationResult result;
|
|
429
489
|
result.finish_reason = "error";
|
|
430
490
|
|
|
@@ -440,8 +500,8 @@ GenerationResult LlamaModel::generate_streaming(
|
|
|
440
500
|
|
|
441
501
|
// Tokenize the prompt
|
|
442
502
|
std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
|
|
443
|
-
result.prompt_tokens
|
|
444
|
-
result.completion_tokens
|
|
503
|
+
result.prompt_tokens = prompt_tokens.size();
|
|
504
|
+
result.completion_tokens = 0;
|
|
445
505
|
|
|
446
506
|
// Clear the memory/KV cache
|
|
447
507
|
llama_memory_t mem = llama_get_memory(ctx_);
|
|
@@ -461,7 +521,7 @@ GenerationResult LlamaModel::generate_streaming(
|
|
|
461
521
|
|
|
462
522
|
// Generate tokens
|
|
463
523
|
std::string generated_text;
|
|
464
|
-
int
|
|
524
|
+
int n_cur = prompt_tokens.size();
|
|
465
525
|
|
|
466
526
|
for (int i = 0; i < params.max_tokens; i++) {
|
|
467
527
|
// Sample the next token
|
|
@@ -486,16 +546,18 @@ GenerationResult LlamaModel::generate_streaming(
|
|
|
486
546
|
|
|
487
547
|
// Check for stop sequences
|
|
488
548
|
bool should_stop = false;
|
|
489
|
-
for (const auto& stop_seq : params.stop_sequences) {
|
|
549
|
+
for (const auto & stop_seq : params.stop_sequences) {
|
|
490
550
|
if (generated_text.length() >= stop_seq.length()) {
|
|
491
551
|
if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
|
|
492
|
-
should_stop
|
|
552
|
+
should_stop = true;
|
|
493
553
|
result.finish_reason = "stop";
|
|
494
554
|
break;
|
|
495
555
|
}
|
|
496
556
|
}
|
|
497
557
|
}
|
|
498
|
-
if (should_stop)
|
|
558
|
+
if (should_stop) {
|
|
559
|
+
break;
|
|
560
|
+
}
|
|
499
561
|
|
|
500
562
|
// Prepare for next iteration
|
|
501
563
|
batch = llama_batch_get_one(&new_token, 1);
|
|
@@ -515,5 +577,4 @@ GenerationResult LlamaModel::generate_streaming(
|
|
|
515
577
|
return result;
|
|
516
578
|
}
|
|
517
579
|
|
|
518
|
-
}
|
|
519
|
-
|
|
580
|
+
} // namespace llama_wrapper
|