red-candle 1.0.0.pre.5 → 1.0.0.pre.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +164 -4
- data/Rakefile +1 -3
- data/ext/candle/src/llm/gemma.rs +277 -0
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +16 -79
- data/ext/candle/src/llm/mistral.rs +16 -89
- data/ext/candle/src/llm/mod.rs +59 -0
- data/ext/candle/src/llm/quantized_gguf.rs +496 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ruby/embedding_model.rs +0 -1
- data/ext/candle/src/ruby/llm.rs +84 -29
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
@@ -180,87 +180,14 @@ impl Mistral {
|
|
180
180
|
|
181
181
|
// Stream callback
|
182
182
|
if let Some(ref mut cb) = callback {
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
// Check stop sequences
|
193
|
-
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
194
|
-
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
195
|
-
break;
|
196
|
-
}
|
197
|
-
}
|
198
|
-
|
199
|
-
Ok(if config.include_prompt {
|
200
|
-
all_tokens
|
201
|
-
} else {
|
202
|
-
all_tokens[start_gen..].to_vec()
|
203
|
-
})
|
204
|
-
}
|
205
|
-
|
206
|
-
fn generate_tokens_decoded(
|
207
|
-
&mut self,
|
208
|
-
prompt_tokens: Vec<u32>,
|
209
|
-
config: &GenerationConfig,
|
210
|
-
mut callback: Option<impl FnMut(&str)>,
|
211
|
-
) -> CandleResult<Vec<u32>> {
|
212
|
-
let mut text_gen = TextGeneration::from_config(config);
|
213
|
-
text_gen.set_eos_token_id(self.eos_token_id);
|
214
|
-
text_gen.set_tokens(prompt_tokens.clone());
|
215
|
-
|
216
|
-
let mut all_tokens = prompt_tokens.clone();
|
217
|
-
let start_gen = all_tokens.len();
|
218
|
-
|
219
|
-
// For incremental decoding
|
220
|
-
let mut previously_decoded = String::new();
|
221
|
-
|
222
|
-
for index in 0..config.max_length {
|
223
|
-
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
224
|
-
let start_pos = all_tokens.len().saturating_sub(context_size);
|
225
|
-
let ctxt = &all_tokens[start_pos..];
|
226
|
-
|
227
|
-
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
228
|
-
// Ensure input tensor is contiguous for Metal backend
|
229
|
-
let input = input.contiguous()?;
|
230
|
-
let logits = self.model.forward(&input, start_pos)?;
|
231
|
-
|
232
|
-
// The model returns logits of shape [batch_size, seq_len, vocab_size]
|
233
|
-
// We need to get the logits for the last token only
|
234
|
-
let logits = logits.squeeze(0)?; // Remove batch dimension
|
235
|
-
let logits = if logits.dims().len() == 2 {
|
236
|
-
// If we still have [seq_len, vocab_size], take the last token
|
237
|
-
let seq_len = logits.dim(0)?;
|
238
|
-
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
239
|
-
} else {
|
240
|
-
// Already [vocab_size]
|
241
|
-
logits
|
242
|
-
};
|
243
|
-
|
244
|
-
// Convert to F32 for sampling if needed
|
245
|
-
let logits = logits.to_dtype(DType::F32)?;
|
246
|
-
|
247
|
-
let next_token = text_gen.sample_next_token(
|
248
|
-
&logits,
|
249
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
250
|
-
)?;
|
251
|
-
|
252
|
-
all_tokens.push(next_token);
|
253
|
-
|
254
|
-
// Stream callback with incremental decoding
|
255
|
-
if let Some(ref mut cb) = callback {
|
256
|
-
// Decode all generated tokens so far
|
257
|
-
let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
258
|
-
|
259
|
-
// Only emit the new text since last callback
|
260
|
-
if current_decoded.len() > previously_decoded.len() {
|
261
|
-
let new_text = &current_decoded[previously_decoded.len()..];
|
262
|
-
cb(new_text);
|
263
|
-
previously_decoded = current_decoded;
|
183
|
+
if config.debug_tokens {
|
184
|
+
// In debug mode, only show debug tokens
|
185
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
186
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
187
|
+
} else {
|
188
|
+
// Normal mode: use incremental decoding for proper text
|
189
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
190
|
+
cb(&decoded_text);
|
264
191
|
}
|
265
192
|
}
|
266
193
|
|
@@ -270,12 +197,7 @@ impl Mistral {
|
|
270
197
|
}
|
271
198
|
|
272
199
|
// Check stop sequences
|
273
|
-
let generated_text =
|
274
|
-
previously_decoded.clone()
|
275
|
-
} else {
|
276
|
-
self.tokenizer.decode(&all_tokens[start_gen..], true)?
|
277
|
-
};
|
278
|
-
|
200
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
279
201
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
280
202
|
break;
|
281
203
|
}
|
@@ -297,7 +219,12 @@ impl TextGenerator for Mistral {
|
|
297
219
|
) -> CandleResult<String> {
|
298
220
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
299
221
|
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
300
|
-
|
222
|
+
|
223
|
+
if config.debug_tokens {
|
224
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
225
|
+
} else {
|
226
|
+
self.tokenizer.decode(&output_tokens, true)
|
227
|
+
}
|
301
228
|
}
|
302
229
|
|
303
230
|
fn generate_stream(
|
@@ -307,7 +234,7 @@ impl TextGenerator for Mistral {
|
|
307
234
|
mut callback: impl FnMut(&str),
|
308
235
|
) -> CandleResult<String> {
|
309
236
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
310
|
-
let output_tokens = self.
|
237
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
311
238
|
self.tokenizer.decode(&output_tokens, true)
|
312
239
|
}
|
313
240
|
|
data/ext/candle/src/llm/mod.rs
CHANGED
@@ -3,11 +3,14 @@ use tokenizers::Tokenizer;
|
|
3
3
|
|
4
4
|
pub mod mistral;
|
5
5
|
pub mod llama;
|
6
|
+
pub mod gemma;
|
6
7
|
pub mod generation_config;
|
7
8
|
pub mod text_generation;
|
9
|
+
pub mod quantized_gguf;
|
8
10
|
|
9
11
|
pub use generation_config::GenerationConfig;
|
10
12
|
pub use text_generation::TextGeneration;
|
13
|
+
pub use quantized_gguf::QuantizedGGUF;
|
11
14
|
|
12
15
|
/// Trait for text generation models
|
13
16
|
pub trait TextGenerator: Send + Sync {
|
@@ -66,4 +69,60 @@ impl TokenizerWrapper {
|
|
66
69
|
.map(|s| s.to_string())
|
67
70
|
.ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
|
68
71
|
}
|
72
|
+
|
73
|
+
/// Decode a single token for streaming output
|
74
|
+
pub fn decode_token(&self, token: u32) -> CandleResult<String> {
|
75
|
+
// Decode the single token properly
|
76
|
+
self.decode(&[token], true)
|
77
|
+
}
|
78
|
+
|
79
|
+
/// Decode tokens incrementally for streaming
|
80
|
+
/// This is more efficient than decoding single tokens
|
81
|
+
pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
|
82
|
+
if new_tokens_start >= all_tokens.len() {
|
83
|
+
return Ok(String::new());
|
84
|
+
}
|
85
|
+
|
86
|
+
// Decode all tokens up to this point
|
87
|
+
let full_text = self.decode(all_tokens, true)?;
|
88
|
+
|
89
|
+
// If we're at the start, return everything
|
90
|
+
if new_tokens_start == 0 {
|
91
|
+
return Ok(full_text);
|
92
|
+
}
|
93
|
+
|
94
|
+
// Otherwise, decode up to the previous token and return the difference
|
95
|
+
let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
|
96
|
+
|
97
|
+
// Find the common prefix between the two strings to handle cases where
|
98
|
+
// the tokenizer might produce slightly different text when decoding
|
99
|
+
// different token sequences
|
100
|
+
let common_prefix_len = full_text
|
101
|
+
.char_indices()
|
102
|
+
.zip(previous_text.chars())
|
103
|
+
.take_while(|((_, c1), c2)| c1 == c2)
|
104
|
+
.count();
|
105
|
+
|
106
|
+
// Find the byte position of the character boundary
|
107
|
+
let byte_pos = full_text
|
108
|
+
.char_indices()
|
109
|
+
.nth(common_prefix_len)
|
110
|
+
.map(|(pos, _)| pos)
|
111
|
+
.unwrap_or(full_text.len());
|
112
|
+
|
113
|
+
// Return only the new portion
|
114
|
+
Ok(full_text[byte_pos..].to_string())
|
115
|
+
}
|
116
|
+
|
117
|
+
/// Format tokens with debug information
|
118
|
+
pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
|
119
|
+
let mut result = String::new();
|
120
|
+
|
121
|
+
for &token in tokens {
|
122
|
+
let token_piece = self.token_to_piece(token)?;
|
123
|
+
result.push_str(&format!("[{}:{}]", token, token_piece));
|
124
|
+
}
|
125
|
+
|
126
|
+
Ok(result)
|
127
|
+
}
|
69
128
|
}
|