red-candle 1.0.0.pre.5 → 1.0.0.pre.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -180,87 +180,14 @@ impl Mistral {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-                let token_text = self.tokenizer.token_to_piece(next_token)?;
-                cb(&token_text);
-            }
-
-            // Check stop conditions
-            if text_gen.should_stop(next_token, config.max_length) {
-                break;
-            }
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-
-        // For incremental decoding
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            // Ensure input tensor is contiguous for Metal backend
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos)?;
-
-            // The model returns logits of shape [batch_size, seq_len, vocab_size]
-            // We need to get the logits for the last token only
-            let logits = logits.squeeze(0)?; // Remove batch dimension
-            let logits = if logits.dims().len() == 2 {
-                // If we still have [seq_len, vocab_size], take the last token
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                // Already [vocab_size]
-                logits
-            };
-
-            // Convert to F32 for sampling if needed
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                // Decode all generated tokens so far
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                // Only emit the new text since last callback
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
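
The hunk above replaces the old per-token piece emission with two streaming modes: when config.debug_tokens is set, the callback receives a "[token_id:piece]" tag for each token; otherwise it receives only the text newly produced by the latest token via decode_incremental. A minimal self-contained sketch of the difference, using a stubbed token-to-piece table in place of the real tokenizer (the pieces map and token values are illustrative, not part of red-candle):

use std::collections::HashMap;

fn main() {
    // Stub vocabulary standing in for the tokenizer (illustrative only).
    let pieces: HashMap<u32, &str> = HashMap::from([(10, "Hel"), (11, "lo"), (12, " world")]);
    let generated = [10u32, 11, 12];

    // Debug mode: each callback invocation receives a "[id:piece]" tag.
    let debug_chunks: Vec<String> = generated
        .iter()
        .map(|&t| format!("[{}:{}]", t, pieces[&t]))
        .collect();
    assert_eq!(debug_chunks, ["[10:Hel]", "[11:lo]", "[12: world]"]);

    // Normal mode: each invocation receives only the newly decoded text, so
    // concatenating the chunks reproduces the full output.
    let mut decoded = String::new();
    for &t in &generated {
        let new_text = pieces[&t]; // what decode_incremental would emit here
        decoded.push_str(new_text);
    }
    assert_eq!(decoded, "Hello world");
}
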
@@ -270,12 +197,7 @@ impl Mistral {
             }
 
             // Check stop sequences
-            let generated_text = if callback.is_some() {
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
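
With the streaming text buffer gone, the stop-sequence check above always re-decodes the generated suffix and tests it against config.stop_sequences. A minimal sketch of that kind of check, assuming it reduces to a substring test over the decoded text (the helper name and semantics below are assumptions, not red-candle's check_stop_sequences):

fn hits_stop_sequence(generated_text: &str, stop_sequences: &[String]) -> bool {
    // Stop as soon as any configured stop sequence appears in the decoded output.
    stop_sequences.iter().any(|s| generated_text.contains(s.as_str()))
}

fn main() {
    let stops = vec!["</s>".to_string(), "\n\n".to_string()];
    assert!(hits_stop_sequence("The answer is 42.</s>", &stops));
    assert!(!hits_stop_sequence("The answer is 42.", &stops));
}
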
@@ -297,7 +219,12 @@ impl TextGenerator for Mistral {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-        self.tokenizer.decode(&output_tokens, true)
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -307,7 +234,7 @@ impl TextGenerator for Mistral {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 
@@ -3,11 +3,14 @@ use tokenizers::Tokenizer;
 
 pub mod mistral;
 pub mod llama;
+pub mod gemma;
 pub mod generation_config;
 pub mod text_generation;
+pub mod quantized_gguf;
 
 pub use generation_config::GenerationConfig;
 pub use text_generation::TextGeneration;
+pub use quantized_gguf::QuantizedGGUF;
 
 /// Trait for text generation models
 pub trait TextGenerator: Send + Sync {
@@ -66,4 +69,60 @@ impl TokenizerWrapper {
             .map(|s| s.to_string())
             .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
     }
+
+    /// Decode a single token for streaming output
+    pub fn decode_token(&self, token: u32) -> CandleResult<String> {
+        // Decode the single token properly
+        self.decode(&[token], true)
+    }
+
+    /// Decode tokens incrementally for streaming
+    /// This is more efficient than decoding single tokens
+    pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
+        if new_tokens_start >= all_tokens.len() {
+            return Ok(String::new());
+        }
+
+        // Decode all tokens up to this point
+        let full_text = self.decode(all_tokens, true)?;
+
+        // If we're at the start, return everything
+        if new_tokens_start == 0 {
+            return Ok(full_text);
+        }
+
+        // Otherwise, decode up to the previous token and return the difference
+        let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
+
+        // Find the common prefix between the two strings to handle cases where
+        // the tokenizer might produce slightly different text when decoding
+        // different token sequences
+        let common_prefix_len = full_text
+            .char_indices()
+            .zip(previous_text.chars())
+            .take_while(|((_, c1), c2)| c1 == c2)
+            .count();
+
+        // Find the byte position of the character boundary
+        let byte_pos = full_text
+            .char_indices()
+            .nth(common_prefix_len)
+            .map(|(pos, _)| pos)
+            .unwrap_or(full_text.len());
+
+        // Return only the new portion
+        Ok(full_text[byte_pos..].to_string())
+    }
+
+    /// Format tokens with debug information
+    pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
+        let mut result = String::new();
+
+        for &token in tokens {
+            let token_piece = self.token_to_piece(token)?;
+            result.push_str(&format!("[{}:{}]", token, token_piece));
+        }
+
+        Ok(result)
+    }
 }
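
The decode_incremental method added above diffs the current full decode against the previous decode character by character, then converts the shared-prefix length into a byte offset so the returned slice always starts on a valid UTF-8 boundary. A self-contained sketch of that prefix-diff step, mirroring the logic above with plain strings standing in for tokenizer output:

fn new_portion(full_text: &str, previous_text: &str) -> String {
    // Count how many leading characters the two decodes share.
    let common_prefix_len = full_text
        .chars()
        .zip(previous_text.chars())
        .take_while(|(c1, c2)| c1 == c2)
        .count();

    // Convert that character count into a byte offset on a char boundary.
    let byte_pos = full_text
        .char_indices()
        .nth(common_prefix_len)
        .map(|(pos, _)| pos)
        .unwrap_or(full_text.len());

    // Return only the newly appended text.
    full_text[byte_pos..].to_string()
}

fn main() {
    // Multi-byte characters ("ï") make raw byte slicing risky; diffing by
    // character keeps the cut on a valid boundary.
    assert_eq!(new_portion("naïve approach", "naïve"), " approach");
    assert_eq!(new_portion("naïve", "naïve"), "");
}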