red-candle 1.0.0.pre.4 → 1.0.0.pre.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 53283961925b76f6fdc39634e01d3d08f58805584488053f7e82fbd484b96b21
-   data.tar.gz: 91d3dae0c3c1f6686980708fc898132178454712fe9adf2a79d59f75038c21f1
+   metadata.gz: 91a4c43a1a12d6d8960f1a1d190c9bfe8ea60db75f687233012d09b8c90b5020
+   data.tar.gz: 3f6ce143cd38856365231baebe25a188e7d5824d930ecdae79c1660b3ad6c787
  SHA512:
-   metadata.gz: 809ab126883fc6a0b706f46b34922ccbbb5eb18b6b153b703404d17bdd2626a0bafe5cebbdd9168389c44f7506b4f0dabd2b9ed7a6ad6348b7406a899e3e5484
-   data.tar.gz: 1d908b031b207bf43e65d81c3f8ac34cd1451cdd38d32476a30e83c23b34c6148a2aa7e235b804ad47041b0d781b1ca36585c2c83b0fa7cd1e9ba26d7cbfc377
+   metadata.gz: 273a01c438b085509a433602097b5ad4bcdb3420fc19ebe84c4fd37bf43ae5e6e1040701ecaad9e2864c0cda31d6128d4b2df6c395290994c17f135620282be6
+   data.tar.gz: 9193d01d8bfc704b982c839f9aebac23217ae5d56b5afae6228daae73b60e219b049fe8fc8d5ad13657ceb110476489c68bbb9aa6397688bf70a976a8d62d41e
data/README.md CHANGED
@@ -45,6 +45,11 @@ results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
 
  Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
 
+ ### Supported Models
+
+ - **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
+ - **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
+
  > ### ⚠️ Huggingface login warning
  >
  > Many models, including the one below, require you to agree to the terms. You'll need to:
@@ -62,7 +67,9 @@ device = Candle::Device.cpu # CPU (default)
  device = Candle::Device.metal # Apple GPU (Metal)
  device = Candle::Device.cuda # NVIDIA GPU (CUDA)
 
- # Load a model
+ # Load a Llama model
+ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
+ # Or a Mistral model
  llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
 
  # Generate text
@@ -86,7 +93,7 @@ response = llm.chat(messages)
  ```ruby
  # CPU works for all models
  device = Candle::Device.cpu
- llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
+ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
 
  # Metal
  device = Candle::Device.metal
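Taken together, the README changes above read as the following end-to-end sketch. It is a minimal example assembled from the snippets in this diff; the prompt text is illustrative, and the block form for streaming mirrors the gem's `generate_stream(prompt, &block)` usage rather than a documented signature.

```ruby
require "candle"

# CPU works for every supported model; use Candle::Device.metal or
# Candle::Device.cuda (as shown above) for GPU acceleration.
device = Candle::Device.cpu

# The small Llama-family chat model used throughout the README examples.
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)

# One-shot generation.
puts llm.generate("Write a haiku about Ruby")

# Streaming generation: each decoded fragment is yielded as it is produced.
llm.generate_stream("Write a haiku about Ruby") { |fragment| print fragment }
```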
data/Rakefile CHANGED
@@ -4,8 +4,6 @@ require "bundler/gem_tasks"
  require "rake/testtask"
  require "rake/extensiontask"
 
- ENV['CANDLE_TEST_SKIP_LLM'] = 'true'
-
  task default: :test
  Rake::TestTask.new do |t|
    t.deps << :compile
ext/candle/src/llm/llama.rs ADDED
@@ -0,0 +1,402 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_nn::VarBuilder;
+ use candle_transformers::models::llama::{Config, LlamaConfig, Llama as LlamaModel, Cache};
+ use hf_hub::{api::tokio::Api, Repo};
+ use tokenizers::Tokenizer;
+
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ #[derive(Debug)]
+ pub struct Llama {
+     model: LlamaModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+     cache: Cache,
+     config: Config,
+ }
+
+ impl Llama {
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         // Since Cache doesn't expose a reset method and kvs is private,
+         // we'll recreate the cache to clear it
+         // This is a workaround until candle provides a proper reset method
+         if let Ok(new_cache) = Cache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device) {
+             self.cache = new_cache;
+         }
+     }
+
+     /// Load a Llama model from HuggingFace Hub
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.repo(Repo::model(model_id.to_string()));
+
+         // Download model files
+         let config_filename = repo
+             .get("config.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+         let tokenizer_filename = repo
+             .get("tokenizer.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+         // Try different file patterns for model weights
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
+             vec![consolidated_file]
+         } else {
+             // Try to find sharded model files
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts for Llama models
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 return Err(candle_core::Error::Msg(
+                     "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
+                 ));
+             }
+             sharded_files
+         };
+
+         // Load config
+         let llama_config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config_filename)?)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+         let config = llama_config.into_config(false); // Don't use flash attention for now
+
+         // Load tokenizer
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Determine EOS token ID based on model type
+         let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
+             // Llama 3 uses different special tokens
+             {
+                 let vocab = tokenizer.get_vocab(true);
+                 vocab.get("<|eot_id|>")
+                     .or_else(|| vocab.get("<|end_of_text|>"))
+                     .copied()
+                     .unwrap_or(128009) // Default Llama 3 EOS
+             }
+         } else {
+             // Llama 2 and earlier
+             tokenizer
+                 .get_vocab(true)
+                 .get("</s>")
+                 .copied()
+                 .unwrap_or(2)
+         };
+
+         // Load model weights
+         let vb = unsafe {
+             VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+         };
+
+         let model = LlamaModel::load(vb, &config)?;
+         let cache = Cache::new(true, DType::F32, &config, &device)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+             cache,
+             config,
+         })
+     }
+
+     /// Create from existing components (useful for testing)
+     pub fn new(
+         model: LlamaModel,
+         tokenizer: Tokenizer,
+         device: Device,
+         model_id: String,
+         config: &Config,
+     ) -> CandleResult<Self> {
+         let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
+             {
+                 let vocab = tokenizer.get_vocab(true);
+                 vocab.get("<|eot_id|>")
+                     .or_else(|| vocab.get("<|end_of_text|>"))
+                     .copied()
+                     .unwrap_or(128009)
+             }
+         } else {
+             tokenizer
+                 .get_vocab(true)
+                 .get("</s>")
+                 .copied()
+                 .unwrap_or(2)
+         };
+
+         let cache = Cache::new(true, DType::F32, config, &device)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id,
+             eos_token_id,
+             cache,
+             config: config.clone(),
+         })
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
+
+             let logits = logits.squeeze(0)?;
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 let token_text = self.tokenizer.token_to_piece(next_token)?;
+                 cb(&token_text);
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+
+     fn generate_tokens_decoded(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+         let mut previously_decoded = String::new();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
+
+             let logits = logits.squeeze(0)?;
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback with incremental decoding
+             if let Some(ref mut cb) = callback {
+                 let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                 if current_decoded.len() > previously_decoded.len() {
+                     let new_text = &current_decoded[previously_decoded.len()..];
+                     cb(new_text);
+                     previously_decoded = current_decoded;
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = if callback.is_some() {
+                 previously_decoded.clone()
+             } else {
+                 self.tokenizer.decode(&all_tokens[start_gen..], true)?
+             };
+
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+
+     /// Apply chat template based on Llama version
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let is_llama3 = self.model_id.contains("Llama-3") || self.model_id.contains("llama-3");
+
+         if is_llama3 {
+             self.apply_llama3_template(messages)
+         } else {
+             self.apply_llama2_template(messages)
+         }
+     }
+
+     fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+         let mut system_message = String::new();
+
+         for (i, message) in messages.iter().enumerate() {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     system_message = content.to_string();
+                 }
+                 "user" => {
+                     if i == 1 || (i == 0 && system_message.is_empty()) {
+                         // First user message
+                         if !system_message.is_empty() {
+                             prompt.push_str(&format!("<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
+                         } else {
+                             prompt.push_str(&format!("<s>[INST] {} [/INST]", content));
+                         }
+                     } else {
+                         prompt.push_str(&format!(" [INST] {} [/INST]", content));
+                     }
+                 }
+                 "assistant" => {
+                     prompt.push_str(&format!(" {} </s>", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         Ok(prompt)
+     }
+
+     fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         prompt.push_str("<|begin_of_text|>");
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
+         }
+
+         prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
+
+         Ok(prompt)
+     }
+ }
+
+ impl TextGenerator for Llama {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
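Since `apply_chat_template` above dispatches on the model id, a short Ruby sketch of what the new templating produces may be useful. The expected strings in the comments are derived from `apply_llama2_template` and `apply_llama3_template` as written above; the model id is simply the README's example, and the output comments describe what this code would build, not an official prompt format guarantee.

```ruby
messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user",   content: "What is Ruby?" }
]

llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: Candle::Device.cpu)

prompt = llm.apply_chat_template(messages)
# A model id without "Llama-3" takes the Llama 2 path above, so prompt would be:
#   "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nWhat is Ruby? [/INST]"
# A model id containing "Llama-3" would instead wrap each message in
#   <|start_header_id|>role<|end_header_id|> ... <|eot_id|> blocks.

response = llm.chat(messages)  # chat() applies the same template, then generates
```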
ext/candle/src/llm/mod.rs CHANGED
@@ -2,6 +2,7 @@ use candle_core::{Device, Result as CandleResult};
  use tokenizers::Tokenizer;
 
  pub mod mistral;
+ pub mod llama;
  pub mod generation_config;
  pub mod text_generation;
 
@@ -1,19 +1,21 @@
  use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
  use std::cell::RefCell;
 
- use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral};
+ use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama};
  use crate::ruby::{Result as RbResult, Device as RbDevice};
 
  // Use an enum to handle different model types instead of trait objects
  #[derive(Debug)]
  enum ModelType {
      Mistral(RustMistral),
+     Llama(RustLlama),
  }
 
  impl ModelType {
      fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
          match self {
              ModelType::Mistral(m) => m.generate(prompt, config),
+             ModelType::Llama(m) => m.generate(prompt, config),
          }
      }
 
@@ -25,6 +27,7 @@ impl ModelType {
      ) -> candle_core::Result<String> {
          match self {
              ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
+             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
          }
      }
 
@@ -32,12 +35,37 @@ impl ModelType {
      fn model_name(&self) -> &str {
          match self {
              ModelType::Mistral(m) => m.model_name(),
+             ModelType::Llama(m) => m.model_name(),
          }
      }
 
      fn clear_cache(&mut self) {
          match self {
              ModelType::Mistral(m) => m.clear_cache(),
+             ModelType::Llama(m) => m.clear_cache(),
+         }
+     }
+
+     fn apply_chat_template(&self, messages: &[serde_json::Value]) -> candle_core::Result<String> {
+         match self {
+             ModelType::Mistral(_) => {
+                 // For now, use a simple template for Mistral
+                 // In the future, we could implement proper Mistral chat templating
+                 let mut prompt = String::new();
+                 for message in messages {
+                     let role = message["role"].as_str().unwrap_or("");
+                     let content = message["content"].as_str().unwrap_or("");
+                     match role {
+                         "system" => prompt.push_str(&format!("System: {}\n\n", content)),
+                         "user" => prompt.push_str(&format!("User: {}\n\n", content)),
+                         "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
+                         _ => {}
+                     }
+                 }
+                 prompt.push_str("Assistant: ");
+                 Ok(prompt)
+             },
+             ModelType::Llama(m) => m.apply_chat_template(messages),
          }
      }
  }
@@ -180,10 +208,16 @@ impl LLM {
              })
              .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
              ModelType::Mistral(mistral)
+         } else if model_lower.contains("llama") || model_lower.contains("meta-llama") {
+             let llama = rt.block_on(async {
+                 RustLlama::from_pretrained(&model_id, candle_device).await
+             })
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+             ModelType::Llama(llama)
          } else {
              return Err(Error::new(
                  magnus::exception::runtime_error(),
-                 format!("Unsupported model type: {}. Currently only Mistral models are supported.", model_id),
+                 format!("Unsupported model type: {}. Currently only Mistral and Llama models are supported.", model_id),
              ));
          };
 
@@ -248,6 +282,41 @@ impl LLM {
          model_ref.clear_cache();
          Ok(())
      }
+
+     /// Apply chat template to messages
+     pub fn apply_chat_template(&self, messages: RArray) -> RbResult<String> {
+         // Convert Ruby array to JSON values
+         let json_messages: Vec<serde_json::Value> = messages
+             .into_iter()
+             .filter_map(|msg| {
+                 if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
+                     let mut json_msg = serde_json::Map::new();
+
+                     if let Some(role) = hash.get(magnus::Symbol::new("role")) {
+                         if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
+                             json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
+                         }
+                     }
+
+                     if let Some(content) = hash.get(magnus::Symbol::new("content")) {
+                         if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
+                             json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
+                         }
+                     }
+
+                     Some(serde_json::Value::Object(json_msg))
+                 } else {
+                     None
+                 }
+             })
+             .collect();
+
+         let model = self.model.lock().unwrap();
+         let model_ref = model.borrow();
+
+         model_ref.apply_chat_template(&json_messages)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to apply chat template: {}", e)))
+     }
  }
 
  // Define a standalone function for from_pretrained that handles variable arguments
@@ -290,6 +359,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
      rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
      rb_llm.define_method("device", method!(LLM::device, 0))?;
      rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
+     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
 
      Ok(())
  }
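A brief usage sketch of the binding changes above. Message hashes use symbol keys because the converter reads `magnus::Symbol::new("role")` and `"content"`; the assumption that the streaming block receives each decoded fragment as a string follows from the `generate_stream` callback, not from documented behavior.

```ruby
messages = [
  { role: "system", content: "You are a concise assistant." },
  { role: "user",   content: "Explain what a KV cache does." }
]

llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: Candle::Device.cpu)

# Mistral models fall back to the plain "System:/User:/Assistant:" template
# defined above; Llama models get the version-specific templates from llama.rs.
llm.chat_stream(messages) { |fragment| print fragment }

# Recreates the KV cache (see clear_kv_cache in llama.rs) between unrelated chats.
llm.clear_cache
```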
data/lib/candle/llm.rb CHANGED
@@ -2,13 +2,13 @@ module Candle
    class LLM
      # Simple chat interface for instruction models
      def chat(messages, **options)
-       prompt = format_messages(messages)
+       prompt = apply_chat_template(messages)
        generate(prompt, **options)
      end
 
      # Streaming chat interface
      def chat_stream(messages, **options, &block)
-       prompt = format_messages(messages)
+       prompt = apply_chat_template(messages)
        generate_stream(prompt, **options, &block)
      end
 
@@ -34,8 +34,8 @@ module Candle
 
      private
 
-     # Format messages into a prompt string
-     # This is a simple implementation - model-specific formatting should be added
+     # Legacy format messages method - kept for backward compatibility
+     # Use apply_chat_template for proper model-specific formatting
      def format_messages(messages)
        formatted = messages.map do |msg|
          case msg[:role]
data/lib/candle/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Candle
-   VERSION = "1.0.0.pre.4"
+   VERSION = "1.0.0.pre.5"
  end
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: red-candle
  version: !ruby/object:Gem::Version
-   version: 1.0.0.pre.4
+   version: 1.0.0.pre.5
  platform: ruby
  authors:
  - Christopher Petersen
@@ -48,6 +48,7 @@ files:
  - ext/candle/rustfmt.toml
  - ext/candle/src/lib.rs
  - ext/candle/src/llm/generation_config.rs
+ - ext/candle/src/llm/llama.rs
  - ext/candle/src/llm/mistral.rs
  - ext/candle/src/llm/mod.rs
  - ext/candle/src/llm/text_generation.rs