@shipworthy/ai-sdk-llama-cpp 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,12 @@
1
1
  #include "llama-wrapper.h"
2
+
2
3
  #include "llama.h"
4
+
3
5
  #include <algorithm>
6
+ #include <cmath>
7
+ #include <cstdio>
4
8
  #include <cstring>
5
9
  #include <stdexcept>
6
- #include <cstdio>
7
- #include <cmath>
8
10
 
9
11
  namespace llama_wrapper {
10
12
 
@@ -12,9 +14,9 @@ namespace llama_wrapper {
12
14
  static bool g_debug_mode = false;
13
15
 
14
16
  // Custom log callback that respects debug mode
15
- static void llama_log_callback(ggml_log_level level, const char* text, void* user_data) {
16
- (void)level;
17
- (void)user_data;
17
+ static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
18
+ (void) level;
19
+ (void) user_data;
18
20
  if (g_debug_mode) {
19
21
  fprintf(stderr, "%s", text);
20
22
  }
@@ -26,33 +28,33 @@ LlamaModel::~LlamaModel() {
26
28
  unload();
27
29
  }
28
30
 
29
- LlamaModel::LlamaModel(LlamaModel&& other) noexcept
30
- : model_(other.model_)
31
- , ctx_(other.ctx_)
32
- , sampler_(other.sampler_)
33
- , model_path_(std::move(other.model_path_))
34
- , chat_template_(std::move(other.chat_template_)) {
35
- other.model_ = nullptr;
36
- other.ctx_ = nullptr;
31
+ LlamaModel::LlamaModel(LlamaModel && other) noexcept :
32
+ model_(other.model_),
33
+ ctx_(other.ctx_),
34
+ sampler_(other.sampler_),
35
+ model_path_(std::move(other.model_path_)),
36
+ chat_template_(std::move(other.chat_template_)) {
37
+ other.model_ = nullptr;
38
+ other.ctx_ = nullptr;
37
39
  other.sampler_ = nullptr;
38
40
  }
39
41
 
40
- LlamaModel& LlamaModel::operator=(LlamaModel&& other) noexcept {
42
+ LlamaModel & LlamaModel::operator=(LlamaModel && other) noexcept {
41
43
  if (this != &other) {
42
44
  unload();
43
- model_ = other.model_;
44
- ctx_ = other.ctx_;
45
- sampler_ = other.sampler_;
46
- model_path_ = std::move(other.model_path_);
45
+ model_ = other.model_;
46
+ ctx_ = other.ctx_;
47
+ sampler_ = other.sampler_;
48
+ model_path_ = std::move(other.model_path_);
47
49
  chat_template_ = std::move(other.chat_template_);
48
- other.model_ = nullptr;
49
- other.ctx_ = nullptr;
50
+ other.model_ = nullptr;
51
+ other.ctx_ = nullptr;
50
52
  other.sampler_ = nullptr;
51
53
  }
52
54
  return *this;
53
55
  }
54
56
 
55
- bool LlamaModel::load(const ModelParams& params) {
57
+ bool LlamaModel::load(const ModelParams & params) {
56
58
  if (model_) {
57
59
  unload();
58
60
  }
@@ -66,9 +68,9 @@ bool LlamaModel::load(const ModelParams& params) {
66
68
 
67
69
  // Set up model parameters
68
70
  llama_model_params model_params = llama_model_default_params();
69
- model_params.n_gpu_layers = params.n_gpu_layers;
70
- model_params.use_mmap = params.use_mmap;
71
- model_params.use_mlock = params.use_mlock;
71
+ model_params.n_gpu_layers = params.n_gpu_layers;
72
+ model_params.use_mmap = params.use_mmap;
73
+ model_params.use_mlock = params.use_mlock;
72
74
 
73
75
  // Load the model
74
76
  model_ = llama_model_load_from_file(params.model_path.c_str(), model_params);
@@ -76,7 +78,7 @@ bool LlamaModel::load(const ModelParams& params) {
76
78
  return false;
77
79
  }
78
80
 
79
- model_path_ = params.model_path;
81
+ model_path_ = params.model_path;
80
82
  chat_template_ = params.chat_template;
81
83
  return true;
82
84
  }
@@ -102,7 +104,7 @@ void LlamaModel::unload() {
102
104
  model_path_.clear();
103
105
  }
104
106
 
105
- bool LlamaModel::create_context(const ContextParams& params) {
107
+ bool LlamaModel::create_context(const ContextParams & params) {
106
108
  if (!model_) {
107
109
  return false;
108
110
  }
@@ -113,21 +115,33 @@ bool LlamaModel::create_context(const ContextParams& params) {
113
115
  }
114
116
 
115
117
  llama_context_params ctx_params = llama_context_default_params();
116
- ctx_params.n_ctx = params.n_ctx;
117
- ctx_params.n_batch = params.n_batch;
118
- ctx_params.n_threads = params.n_threads;
118
+
119
+ // Only override defaults if non-zero values are provided
120
+ if (params.n_ctx > 0) {
121
+ ctx_params.n_ctx = params.n_ctx;
122
+ }
123
+ if (params.n_batch > 0) {
124
+ ctx_params.n_batch = params.n_batch;
125
+ }
126
+
127
+ ctx_params.n_threads = params.n_threads;
119
128
  ctx_params.n_threads_batch = params.n_threads;
120
-
129
+
121
130
  if (params.embedding) {
122
- ctx_params.embeddings = true;
131
+ ctx_params.embeddings = true;
123
132
  ctx_params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
133
+ // For embeddings, batch size must be at least as large as context size
134
+ // (see llama.cpp/examples/embedding/embedding.cpp)
135
+ if (ctx_params.n_batch < ctx_params.n_ctx) {
136
+ ctx_params.n_batch = ctx_params.n_ctx;
137
+ }
124
138
  }
125
139
 
126
140
  ctx_ = llama_init_from_model(model_, ctx_params);
127
141
  return ctx_ != nullptr;
128
142
  }
129
143
 
130
- void LlamaModel::normalize_embedding(float* embedding, int n_embd) {
144
+ void LlamaModel::normalize_embedding(float * embedding, int n_embd) {
131
145
  float sum = 0.0f;
132
146
  for (int i = 0; i < n_embd; i++) {
133
147
  sum += embedding[i] * embedding[i];
@@ -140,7 +154,58 @@ void LlamaModel::normalize_embedding(float* embedding, int n_embd) {
140
154
  }
141
155
  }
142
156
 
143
- EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
157
+ std::vector<float> LlamaModel::embed_chunk(const std::vector<int32_t> & tokens,
158
+ int seq_id,
159
+ int n_embd,
160
+ int pooling_type) {
161
+ std::vector<float> embedding(n_embd, 0.0f);
162
+
163
+ if (tokens.empty()) {
164
+ return embedding;
165
+ }
166
+
167
+ // Clear the memory/KV cache
168
+ llama_memory_t mem = llama_get_memory(ctx_);
169
+ if (mem) {
170
+ llama_memory_clear(mem, true);
171
+ }
172
+
173
+ // Create batch with sequence ID
174
+ llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
175
+ for (size_t i = 0; i < tokens.size(); i++) {
176
+ batch.token[i] = tokens[i];
177
+ batch.pos[i] = i;
178
+ batch.n_seq_id[i] = 1;
179
+ batch.seq_id[i][0] = seq_id;
180
+ batch.logits[i] = true; // We want embeddings for all tokens
181
+ }
182
+ batch.n_tokens = tokens.size();
183
+
184
+ // Decode to get embeddings
185
+ if (llama_decode(ctx_, batch) != 0) {
186
+ llama_batch_free(batch);
187
+ return embedding; // Return zero embedding on failure
188
+ }
189
+
190
+ // Extract embedding based on pooling type
191
+ const float * embd = nullptr;
192
+ if (static_cast<enum llama_pooling_type>(pooling_type) == LLAMA_POOLING_TYPE_NONE) {
193
+ // Get embedding for last token
194
+ embd = llama_get_embeddings_ith(ctx_, tokens.size() - 1);
195
+ } else {
196
+ // Get pooled embedding for the sequence
197
+ embd = llama_get_embeddings_seq(ctx_, seq_id);
198
+ }
199
+
200
+ if (embd) {
201
+ std::copy(embd, embd + n_embd, embedding.begin());
202
+ }
203
+
204
+ llama_batch_free(batch);
205
+ return embedding;
206
+ }
207
+
208
+ EmbeddingResult LlamaModel::embed(const std::vector<std::string> & texts) {
144
209
  EmbeddingResult result;
145
210
  result.total_tokens = 0;
146
211
 
@@ -148,12 +213,17 @@ EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
148
213
  return result;
149
214
  }
150
215
 
151
- const int n_embd = llama_model_n_embd(model_);
216
+ const int n_embd = llama_model_n_embd(model_);
152
217
  const enum llama_pooling_type pooling_type = llama_pooling_type(ctx_);
153
218
 
219
+ // Get context size for chunking
220
+ const int n_ctx = llama_n_ctx(ctx_);
221
+ const int overlap = n_ctx / 10; // 10% overlap between chunks
222
+ const int step = n_ctx - overlap;
223
+
154
224
  // Process each text
155
225
  for (size_t seq_id = 0; seq_id < texts.size(); seq_id++) {
156
- const std::string& text = texts[seq_id];
226
+ const std::string & text = texts[seq_id];
157
227
 
158
228
  // Tokenize the text
159
229
  std::vector<int32_t> tokens = tokenize(text, true);
@@ -165,63 +235,64 @@ EmbeddingResult LlamaModel::embed(const std::vector<std::string>& texts) {
165
235
  continue;
166
236
  }
167
237
 
168
- // Clear the memory/KV cache
169
- llama_memory_t mem = llama_get_memory(ctx_);
170
- if (mem) {
171
- llama_memory_clear(mem, true);
172
- }
238
+ // Check if text fits in context (no chunking needed)
239
+ if (static_cast<int>(tokens.size()) <= n_ctx) {
240
+ // Process single chunk
241
+ std::vector<float> embedding = embed_chunk(tokens, seq_id, n_embd, pooling_type);
242
+ normalize_embedding(embedding.data(), n_embd);
243
+ result.embeddings.push_back(std::move(embedding));
244
+ } else {
245
+ // Text exceeds context size - split into overlapping chunks
246
+ std::vector<std::vector<float>> chunk_embeddings;
173
247
 
174
- // Create batch with sequence ID
175
- llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
176
- for (size_t i = 0; i < tokens.size(); i++) {
177
- batch.token[i] = tokens[i];
178
- batch.pos[i] = i;
179
- batch.n_seq_id[i] = 1;
180
- batch.seq_id[i][0] = seq_id;
181
- batch.logits[i] = true; // We want embeddings for all tokens
182
- }
183
- batch.n_tokens = tokens.size();
248
+ for (size_t start = 0; start < tokens.size(); start += step) {
249
+ // Calculate chunk end position
250
+ size_t end = std::min(start + n_ctx, tokens.size());
184
251
 
185
- // Decode to get embeddings
186
- if (llama_decode(ctx_, batch) != 0) {
187
- llama_batch_free(batch);
188
- result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
189
- continue;
190
- }
252
+ // Extract chunk tokens
253
+ std::vector<int32_t> chunk_tokens(tokens.begin() + start, tokens.begin() + end);
191
254
 
192
- // Extract embedding based on pooling type
193
- const float* embd = nullptr;
194
- if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
195
- // Get embedding for last token
196
- embd = llama_get_embeddings_ith(ctx_, tokens.size() - 1);
197
- } else {
198
- // Get pooled embedding for the sequence
199
- embd = llama_get_embeddings_seq(ctx_, seq_id);
200
- }
255
+ // Get embedding for this chunk
256
+ std::vector<float> chunk_emb = embed_chunk(chunk_tokens, seq_id, n_embd, pooling_type);
257
+ chunk_embeddings.push_back(std::move(chunk_emb));
201
258
 
202
- if (embd) {
203
- std::vector<float> embedding(n_embd);
204
- std::copy(embd, embd + n_embd, embedding.begin());
205
- // Normalize the embedding (L2 normalization)
206
- normalize_embedding(embedding.data(), n_embd);
207
- result.embeddings.push_back(std::move(embedding));
208
- } else {
209
- result.embeddings.push_back(std::vector<float>(n_embd, 0.0f));
210
- }
259
+ // If this chunk reached the end, we're done
260
+ if (end == tokens.size()) {
261
+ break;
262
+ }
263
+ }
211
264
 
212
- llama_batch_free(batch);
265
+ // Mean-pool all chunk embeddings
266
+ std::vector<float> final_embedding(n_embd, 0.0f);
267
+ if (!chunk_embeddings.empty()) {
268
+ for (const auto & chunk_emb : chunk_embeddings) {
269
+ for (int i = 0; i < n_embd; i++) {
270
+ final_embedding[i] += chunk_emb[i];
271
+ }
272
+ }
273
+ // Divide by number of chunks to get mean
274
+ float num_chunks = static_cast<float>(chunk_embeddings.size());
275
+ for (int i = 0; i < n_embd; i++) {
276
+ final_embedding[i] /= num_chunks;
277
+ }
278
+ }
279
+
280
+ // Normalize the final averaged embedding
281
+ normalize_embedding(final_embedding.data(), n_embd);
282
+ result.embeddings.push_back(std::move(final_embedding));
283
+ }
213
284
  }
214
285
 
215
286
  return result;
216
287
  }
217
288
 
218
- std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& messages) {
289
+ std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage> & messages) {
219
290
  if (!model_) {
220
291
  return "";
221
292
  }
222
293
 
223
294
  // Determine which template to use
224
- const char* tmpl = nullptr;
295
+ const char * tmpl = nullptr;
225
296
  if (chat_template_ == "auto") {
226
297
  // Use the template embedded in the model
227
298
  tmpl = llama_model_chat_template(model_, nullptr);
@@ -233,22 +304,17 @@ std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& mess
233
304
  // Convert messages to llama_chat_message format
234
305
  std::vector<llama_chat_message> chat_messages;
235
306
  chat_messages.reserve(messages.size());
236
- for (const auto& msg : messages) {
307
+ for (const auto & msg : messages) {
237
308
  llama_chat_message chat_msg;
238
- chat_msg.role = msg.role.c_str();
309
+ chat_msg.role = msg.role.c_str();
239
310
  chat_msg.content = msg.content.c_str();
240
311
  chat_messages.push_back(chat_msg);
241
312
  }
242
313
 
243
314
  // First call to get required buffer size
244
- int32_t result_size = llama_chat_apply_template(
245
- tmpl,
246
- chat_messages.data(),
247
- chat_messages.size(),
248
- true, // add_ass: add assistant prompt
249
- nullptr,
250
- 0
251
- );
315
+ int32_t result_size = llama_chat_apply_template(tmpl, chat_messages.data(), chat_messages.size(),
316
+ true, // add_ass: add assistant prompt
317
+ nullptr, 0);
252
318
 
253
319
  if (result_size < 0) {
254
320
  // Template not supported, return empty string
@@ -257,19 +323,12 @@ std::string LlamaModel::apply_chat_template(const std::vector<ChatMessage>& mess
257
323
 
258
324
  // Allocate buffer and apply template
259
325
  std::vector<char> buffer(result_size + 1);
260
- llama_chat_apply_template(
261
- tmpl,
262
- chat_messages.data(),
263
- chat_messages.size(),
264
- true,
265
- buffer.data(),
266
- buffer.size()
267
- );
326
+ llama_chat_apply_template(tmpl, chat_messages.data(), chat_messages.size(), true, buffer.data(), buffer.size());
268
327
 
269
328
  return std::string(buffer.data(), result_size);
270
329
  }
271
330
 
272
- void LlamaModel::create_sampler(const GenerationParams& params) {
331
+ void LlamaModel::create_sampler(const GenerationParams & params) {
273
332
  if (sampler_) {
274
333
  llama_sampler_free(sampler_);
275
334
  }
@@ -281,14 +340,15 @@ void LlamaModel::create_sampler(const GenerationParams& params) {
281
340
  llama_sampler_chain_add(sampler_, llama_sampler_init_top_k(params.top_k));
282
341
  llama_sampler_chain_add(sampler_, llama_sampler_init_top_p(params.top_p, 1));
283
342
  llama_sampler_chain_add(sampler_, llama_sampler_init_temp(params.temperature));
284
- llama_sampler_chain_add(sampler_, llama_sampler_init_dist(42)); // Random seed
343
+ llama_sampler_chain_add(sampler_, llama_sampler_init_dist(42)); // Random seed
285
344
  }
286
345
 
287
- std::vector<int32_t> LlamaModel::tokenize(const std::string& text, bool add_bos) {
288
- const llama_vocab* vocab = llama_model_get_vocab(model_);
346
+ std::vector<int32_t> LlamaModel::tokenize(const std::string & text, bool add_bos) {
347
+ const llama_vocab * vocab = llama_model_get_vocab(model_);
289
348
 
290
349
  // First, get the number of tokens needed
291
- // When passing 0 for n_tokens_max, llama_tokenize returns negative of required size
350
+ // When passing 0 for n_tokens_max, llama_tokenize returns negative of
351
+ // required size
292
352
  int n_tokens = llama_tokenize(vocab, text.c_str(), text.length(), nullptr, 0, add_bos, true);
293
353
 
294
354
  if (n_tokens < 0) {
@@ -318,10 +378,10 @@ std::vector<int32_t> LlamaModel::tokenize(const std::string& text, bool add_bos)
318
378
  }
319
379
 
320
380
  std::string LlamaModel::detokenize(int32_t token) {
321
- const llama_vocab* vocab = llama_model_get_vocab(model_);
381
+ const llama_vocab * vocab = llama_model_get_vocab(model_);
322
382
 
323
383
  char buf[256];
324
- int n = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true);
384
+ int n = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true);
325
385
  if (n < 0) {
326
386
  return "";
327
387
  }
@@ -329,11 +389,11 @@ std::string LlamaModel::detokenize(int32_t token) {
329
389
  }
330
390
 
331
391
  bool LlamaModel::is_eos_token(int32_t token) {
332
- const llama_vocab* vocab = llama_model_get_vocab(model_);
392
+ const llama_vocab * vocab = llama_model_get_vocab(model_);
333
393
  return llama_vocab_is_eog(vocab, token);
334
394
  }
335
395
 
336
- GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages, const GenerationParams& params) {
396
+ GenerationResult LlamaModel::generate(const std::vector<ChatMessage> & messages, const GenerationParams & params) {
337
397
  GenerationResult result;
338
398
  result.finish_reason = "error";
339
399
 
@@ -349,8 +409,8 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
349
409
 
350
410
  // Tokenize the prompt
351
411
  std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
352
- result.prompt_tokens = prompt_tokens.size();
353
- result.completion_tokens = 0;
412
+ result.prompt_tokens = prompt_tokens.size();
413
+ result.completion_tokens = 0;
354
414
 
355
415
  // Clear the memory/KV cache
356
416
  llama_memory_t mem = llama_get_memory(ctx_);
@@ -370,7 +430,7 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
370
430
 
371
431
  // Generate tokens
372
432
  std::string generated_text;
373
- int n_cur = prompt_tokens.size();
433
+ int n_cur = prompt_tokens.size();
374
434
 
375
435
  for (int i = 0; i < params.max_tokens; i++) {
376
436
  // Sample the next token
@@ -389,18 +449,20 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
389
449
 
390
450
  // Check for stop sequences
391
451
  bool should_stop = false;
392
- for (const auto& stop_seq : params.stop_sequences) {
452
+ for (const auto & stop_seq : params.stop_sequences) {
393
453
  if (generated_text.length() >= stop_seq.length()) {
394
454
  if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
395
455
  // Remove the stop sequence from output
396
- generated_text = generated_text.substr(0, generated_text.length() - stop_seq.length());
397
- should_stop = true;
456
+ generated_text = generated_text.substr(0, generated_text.length() - stop_seq.length());
457
+ should_stop = true;
398
458
  result.finish_reason = "stop";
399
459
  break;
400
460
  }
401
461
  }
402
462
  }
403
- if (should_stop) break;
463
+ if (should_stop) {
464
+ break;
465
+ }
404
466
 
405
467
  // Prepare for next iteration
406
468
  batch = llama_batch_get_one(&new_token, 1);
@@ -420,11 +482,9 @@ GenerationResult LlamaModel::generate(const std::vector<ChatMessage>& messages,
420
482
  return result;
421
483
  }
422
484
 
423
- GenerationResult LlamaModel::generate_streaming(
424
- const std::vector<ChatMessage>& messages,
425
- const GenerationParams& params,
426
- TokenCallback callback
427
- ) {
485
+ GenerationResult LlamaModel::generate_streaming(const std::vector<ChatMessage> & messages,
486
+ const GenerationParams & params,
487
+ TokenCallback callback) {
428
488
  GenerationResult result;
429
489
  result.finish_reason = "error";
430
490
 
@@ -440,8 +500,8 @@ GenerationResult LlamaModel::generate_streaming(
440
500
 
441
501
  // Tokenize the prompt
442
502
  std::vector<int32_t> prompt_tokens = tokenize(prompt, true);
443
- result.prompt_tokens = prompt_tokens.size();
444
- result.completion_tokens = 0;
503
+ result.prompt_tokens = prompt_tokens.size();
504
+ result.completion_tokens = 0;
445
505
 
446
506
  // Clear the memory/KV cache
447
507
  llama_memory_t mem = llama_get_memory(ctx_);
@@ -461,7 +521,7 @@ GenerationResult LlamaModel::generate_streaming(
461
521
 
462
522
  // Generate tokens
463
523
  std::string generated_text;
464
- int n_cur = prompt_tokens.size();
524
+ int n_cur = prompt_tokens.size();
465
525
 
466
526
  for (int i = 0; i < params.max_tokens; i++) {
467
527
  // Sample the next token
@@ -486,16 +546,18 @@ GenerationResult LlamaModel::generate_streaming(
486
546
 
487
547
  // Check for stop sequences
488
548
  bool should_stop = false;
489
- for (const auto& stop_seq : params.stop_sequences) {
549
+ for (const auto & stop_seq : params.stop_sequences) {
490
550
  if (generated_text.length() >= stop_seq.length()) {
491
551
  if (generated_text.substr(generated_text.length() - stop_seq.length()) == stop_seq) {
492
- should_stop = true;
552
+ should_stop = true;
493
553
  result.finish_reason = "stop";
494
554
  break;
495
555
  }
496
556
  }
497
557
  }
498
- if (should_stop) break;
558
+ if (should_stop) {
559
+ break;
560
+ }
499
561
 
500
562
  // Prepare for next iteration
501
563
  batch = llama_batch_get_one(&new_token, 1);
@@ -515,5 +577,4 @@ GenerationResult LlamaModel::generate_streaming(
515
577
  return result;
516
578
  }
517
579
 
518
- } // namespace llama_wrapper
519
-
580
+ } // namespace llama_wrapper