red-candle 1.0.0.pre.4 → 1.0.0.pre.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +9 -2
- data/Rakefile +0 -2
- data/ext/candle/src/llm/llama.rs +402 -0
- data/ext/candle/src/llm/mod.rs +1 -0
- data/ext/candle/src/ruby/llm.rs +72 -2
- data/lib/candle/llm.rb +4 -4
- data/lib/candle/version.rb +1 -1
- metadata +2 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 91a4c43a1a12d6d8960f1a1d190c9bfe8ea60db75f687233012d09b8c90b5020
+  data.tar.gz: 3f6ce143cd38856365231baebe25a188e7d5824d930ecdae79c1660b3ad6c787
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 273a01c438b085509a433602097b5ad4bcdb3420fc19ebe84c4fd37bf43ae5e6e1040701ecaad9e2864c0cda31d6128d4b2df6c395290994c17f135620282be6
+  data.tar.gz: 9193d01d8bfc704b982c839f9aebac23217ae5d56b5afae6228daae73b60e219b049fe8fc8d5ad13657ceb110476489c68bbb9aa6397688bf70a976a8d62d41e
data/README.md
CHANGED
@@ -45,6 +45,11 @@ results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
 
 Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
 
+### Supported Models
+
+- **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
+- **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
+
 > ### ⚠️ Huggingface login warning
 >
 > Many models, including the one below, require you to agree to the terms. You'll need to:
@@ -62,7 +67,9 @@ device = Candle::Device.cpu # CPU (default)
 device = Candle::Device.metal # Apple GPU (Metal)
 device = Candle::Device.cuda # NVIDIA GPU (CUDA)
 
-# Load a model
+# Load a Llama model
+llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
+# Or a Mistral model
 llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
 
 # Generate text
@@ -86,7 +93,7 @@ response = llm.chat(messages)
 ```ruby
 # CPU works for all models
 device = Candle::Device.cpu
-llm = Candle::LLM.from_pretrained("
+llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
 
 # Metal
 device = Candle::Device.metal
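
The README snippets above document the new multi-model loading. As a hedged end-to-end sketch of the chat interface this release describes (assumptions: `require "candle"` loads the gem, the model download succeeds, and the model id is one of the README examples):

```ruby
require "candle" # assumption: the gem's entry point is lib/candle.rb

device = Candle::Device.cpu
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)

# Messages are role/content hashes; the Rust bindings added in this
# release read the symbol keys :role and :content.
messages = [
  { role: "system", content: "You are a concise assistant." },
  { role: "user", content: "What can red-candle do?" }
]

puts llm.chat(messages)
```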
data/Rakefile
CHANGED
data/ext/candle/src/llm/llama.rs
ADDED
@@ -0,0 +1,402 @@
+use candle_core::{DType, Device, Result as CandleResult, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::llama::{Config, LlamaConfig, Llama as LlamaModel, Cache};
+use hf_hub::{api::tokio::Api, Repo};
+use tokenizers::Tokenizer;
+
+use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+#[derive(Debug)]
+pub struct Llama {
+    model: LlamaModel,
+    tokenizer: TokenizerWrapper,
+    device: Device,
+    model_id: String,
+    eos_token_id: u32,
+    cache: Cache,
+    config: Config,
+}
+
+impl Llama {
+    /// Clear the KV cache between generations
+    pub fn clear_kv_cache(&mut self) {
+        // Since Cache doesn't expose a reset method and kvs is private,
+        // we'll recreate the cache to clear it
+        // This is a workaround until candle provides a proper reset method
+        if let Ok(new_cache) = Cache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device) {
+            self.cache = new_cache;
+        }
+    }
+
+    /// Load a Llama model from HuggingFace Hub
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        let api = Api::new()
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+        let repo = api.repo(Repo::model(model_id.to_string()));
+
+        // Download model files
+        let config_filename = repo
+            .get("config.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+        let tokenizer_filename = repo
+            .get("tokenizer.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+        // Try different file patterns for model weights
+        let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+            vec![single_file]
+        } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
+            vec![consolidated_file]
+        } else {
+            // Try to find sharded model files
+            let mut sharded_files = Vec::new();
+            let mut index = 1;
+            loop {
+                // Try common shard counts for Llama models
+                let mut found = false;
+                for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
+                    let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                    if let Ok(file) = repo.get(&filename).await {
+                        sharded_files.push(file);
+                        found = true;
+                        break;
+                    }
+                }
+                if !found {
+                    break;
+                }
+                index += 1;
+            }
+
+            if sharded_files.is_empty() {
+                return Err(candle_core::Error::Msg(
+                    "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
+                ));
+            }
+            sharded_files
+        };
+
+        // Load config
+        let llama_config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config_filename)?)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+        let config = llama_config.into_config(false); // Don't use flash attention for now
+
+        // Load tokenizer
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        // Determine EOS token ID based on model type
+        let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
+            // Llama 3 uses different special tokens
+            {
+                let vocab = tokenizer.get_vocab(true);
+                vocab.get("<|eot_id|>")
+                    .or_else(|| vocab.get("<|end_of_text|>"))
+                    .copied()
+                    .unwrap_or(128009) // Default Llama 3 EOS
+            }
+        } else {
+            // Llama 2 and earlier
+            tokenizer
+                .get_vocab(true)
+                .get("</s>")
+                .copied()
+                .unwrap_or(2)
+        };
+
+        // Load model weights
+        let vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+        };
+
+        let model = LlamaModel::load(vb, &config)?;
+        let cache = Cache::new(true, DType::F32, &config, &device)?;
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id: model_id.to_string(),
+            eos_token_id,
+            cache,
+            config,
+        })
+    }
+
+    /// Create from existing components (useful for testing)
+    pub fn new(
+        model: LlamaModel,
+        tokenizer: Tokenizer,
+        device: Device,
+        model_id: String,
+        config: &Config,
+    ) -> CandleResult<Self> {
+        let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
+            {
+                let vocab = tokenizer.get_vocab(true);
+                vocab.get("<|eot_id|>")
+                    .or_else(|| vocab.get("<|end_of_text|>"))
+                    .copied()
+                    .unwrap_or(128009)
+            }
+        } else {
+            tokenizer
+                .get_vocab(true)
+                .get("</s>")
+                .copied()
+                .unwrap_or(2)
+        };
+
+        let cache = Cache::new(true, DType::F32, config, &device)?;
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id,
+            eos_token_id,
+            cache,
+            config: config.clone(),
+        })
+    }
+
+    fn generate_tokens(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
+
+            let logits = logits.squeeze(0)?;
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback
+            if let Some(ref mut cb) = callback {
+                let token_text = self.tokenizer.token_to_piece(next_token)?;
+                cb(&token_text);
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+
+    fn generate_tokens_decoded(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+        let mut previously_decoded = String::new();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
+
+            let logits = logits.squeeze(0)?;
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback with incremental decoding
+            if let Some(ref mut cb) = callback {
+                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                if current_decoded.len() > previously_decoded.len() {
+                    let new_text = &current_decoded[previously_decoded.len()..];
+                    cb(new_text);
+                    previously_decoded = current_decoded;
+                }
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = if callback.is_some() {
+                previously_decoded.clone()
+            } else {
+                self.tokenizer.decode(&all_tokens[start_gen..], true)?
+            };
+
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+
+    /// Apply chat template based on Llama version
+    pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let is_llama3 = self.model_id.contains("Llama-3") || self.model_id.contains("llama-3");
+
+        if is_llama3 {
+            self.apply_llama3_template(messages)
+        } else {
+            self.apply_llama2_template(messages)
+        }
+    }
+
+    fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+        let mut system_message = String::new();
+
+        for (i, message) in messages.iter().enumerate() {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "system" => {
+                    system_message = content.to_string();
+                }
+                "user" => {
+                    if i == 1 || (i == 0 && system_message.is_empty()) {
+                        // First user message
+                        if !system_message.is_empty() {
+                            prompt.push_str(&format!("<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
+                        } else {
+                            prompt.push_str(&format!("<s>[INST] {} [/INST]", content));
+                        }
+                    } else {
+                        prompt.push_str(&format!(" [INST] {} [/INST]", content));
+                    }
+                }
+                "assistant" => {
+                    prompt.push_str(&format!(" {} </s>", content));
+                }
+                _ => {}
+            }
+        }
+
+        Ok(prompt)
+    }
+
+    fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        prompt.push_str("<|begin_of_text|>");
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
+        }
+
+        prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
+
+        Ok(prompt)
+    }
+}
+
+impl TextGenerator for Llama {
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        mut callback: impl FnMut(&str),
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_id
+    }
+
+    fn device(&self) -> &Device {
+        &self.device
+    }
+
+    fn clear_cache(&mut self) {
+        self.clear_kv_cache();
+    }
+}
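
The `apply_chat_template` logic above picks a format from the model id: Llama 2 style wraps turns in `<s>[INST] ... [/INST]` with an optional `<<SYS>>` block, while Llama 3 style uses `<|start_header_id|>`/`<|eot_id|>` markers and ends with an open assistant header. An illustrative sketch of the expected prompts through the Ruby binding added later in this diff (whitespace follows the Rust format strings above; `llm` is assumed to be a loaded `Candle::LLM`):

```ruby
# Assumes llm was loaded via Candle::LLM.from_pretrained as shown earlier.
messages = [
  { role: "system", content: "Be brief." },
  { role: "user", content: "Hi" }
]

prompt = llm.apply_chat_template(messages)

# Llama 2-style model ids produce roughly:
#   "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST]"
#
# Model ids containing "Llama-3"/"llama-3" produce roughly:
#   "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nBe brief.<|eot_id|>" \
#   "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>" \
#   "<|start_header_id|>assistant<|end_header_id|>\n\n"
```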
data/ext/candle/src/llm/mod.rs
CHANGED
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,19 +1,21 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;
 
-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama};
 use crate::ruby::{Result as RbResult, Device as RbDevice};
 
 // Use an enum to handle different model types instead of trait objects
 #[derive(Debug)]
 enum ModelType {
     Mistral(RustMistral),
+    Llama(RustLlama),
 }
 
 impl ModelType {
     fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
         match self {
             ModelType::Mistral(m) => m.generate(prompt, config),
+            ModelType::Llama(m) => m.generate(prompt, config),
         }
     }
 
@@ -25,6 +27,7 @@ impl ModelType {
     ) -> candle_core::Result<String> {
         match self {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
+            ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
         }
     }
 
@@ -32,12 +35,37 @@ impl ModelType {
     fn model_name(&self) -> &str {
         match self {
             ModelType::Mistral(m) => m.model_name(),
+            ModelType::Llama(m) => m.model_name(),
         }
     }
 
     fn clear_cache(&mut self) {
         match self {
             ModelType::Mistral(m) => m.clear_cache(),
+            ModelType::Llama(m) => m.clear_cache(),
+        }
+    }
+
+    fn apply_chat_template(&self, messages: &[serde_json::Value]) -> candle_core::Result<String> {
+        match self {
+            ModelType::Mistral(_) => {
+                // For now, use a simple template for Mistral
+                // In the future, we could implement proper Mistral chat templating
+                let mut prompt = String::new();
+                for message in messages {
+                    let role = message["role"].as_str().unwrap_or("");
+                    let content = message["content"].as_str().unwrap_or("");
+                    match role {
+                        "system" => prompt.push_str(&format!("System: {}\n\n", content)),
+                        "user" => prompt.push_str(&format!("User: {}\n\n", content)),
+                        "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
+                        _ => {}
+                    }
+                }
+                prompt.push_str("Assistant: ");
+                Ok(prompt)
+            },
+            ModelType::Llama(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -180,10 +208,16 @@ impl LLM {
             })
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
             ModelType::Mistral(mistral)
+        } else if model_lower.contains("llama") || model_lower.contains("meta-llama") {
+            let llama = rt.block_on(async {
+                RustLlama::from_pretrained(&model_id, candle_device).await
+            })
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+            ModelType::Llama(llama)
         } else {
             return Err(Error::new(
                 magnus::exception::runtime_error(),
-                format!("Unsupported model type: {}. Currently only Mistral models are supported.", model_id),
+                format!("Unsupported model type: {}. Currently only Mistral and Llama models are supported.", model_id),
             ));
         };
 
@@ -248,6 +282,41 @@ impl LLM {
         model_ref.clear_cache();
         Ok(())
     }
+
+    /// Apply chat template to messages
+    pub fn apply_chat_template(&self, messages: RArray) -> RbResult<String> {
+        // Convert Ruby array to JSON values
+        let json_messages: Vec<serde_json::Value> = messages
+            .into_iter()
+            .filter_map(|msg| {
+                if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
+                    let mut json_msg = serde_json::Map::new();
+
+                    if let Some(role) = hash.get(magnus::Symbol::new("role")) {
+                        if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
+                            json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
+                        }
+                    }
+
+                    if let Some(content) = hash.get(magnus::Symbol::new("content")) {
+                        if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
+                            json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
+                        }
+                    }
+
+                    Some(serde_json::Value::Object(json_msg))
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        let model = self.model.lock().unwrap();
+        let model_ref = model.borrow();
+
+        model_ref.apply_chat_template(&json_messages)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to apply chat template: {}", e)))
+    }
 }
 
 // Define a standalone function for from_pretrained that handles variable arguments
@@ -290,6 +359,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
+    rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
 
     Ok(())
 }
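
Taken together, these binding changes route any model id containing "llama" to the new Rust Llama loader and expose `apply_chat_template` to Ruby with one argument. A hedged usage sketch (error handling omitted; gated `meta-llama` repos additionally require accepting their license on Hugging Face):

```ruby
device = Candle::Device.cpu

# Dispatched to ModelType::Llama because the id contains "llama"
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)

prompt = llm.apply_chat_template([
  { role: "user", content: "Tell me a joke." }
])

puts llm.generate(prompt)
```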
data/lib/candle/llm.rb
CHANGED
@@ -2,13 +2,13 @@ module Candle
   class LLM
     # Simple chat interface for instruction models
     def chat(messages, **options)
-      prompt =
+      prompt = apply_chat_template(messages)
       generate(prompt, **options)
     end
 
     # Streaming chat interface
     def chat_stream(messages, **options, &block)
-      prompt =
+      prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
 
@@ -34,8 +34,8 @@ module Candle
 
     private
 
-    #
-    #
+    # Legacy format messages method - kept for backward compatibility
+    # Use apply_chat_template for proper model-specific formatting
     def format_messages(messages)
       formatted = messages.map do |msg|
         case msg[:role]
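
With `chat` and `chat_stream` now delegating prompt construction to `apply_chat_template`, a streaming call reduces to a sketch like this (assumption based on the Rust streaming callback above: the block receives decoded text chunks as they are generated):

```ruby
# Assumes llm was loaded via Candle::LLM.from_pretrained as shown earlier.
llm.chat_stream([{ role: "user", content: "Count to three." }]) do |chunk|
  print chunk
end
puts
```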
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.0.0.pre.
+  version: 1.0.0.pre.5
 platform: ruby
 authors:
 - Christopher Petersen
@@ -48,6 +48,7 @@ files:
 - ext/candle/rustfmt.toml
 - ext/candle/src/lib.rs
 - ext/candle/src/llm/generation_config.rs
+- ext/candle/src/llm/llama.rs
 - ext/candle/src/llm/mistral.rs
 - ext/candle/src/llm/mod.rs
 - ext/candle/src/llm/text_generation.rs