red-candle 1.0.0.pre.6 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +481 -4
  4. data/Rakefile +1 -3
  5. data/ext/candle/src/lib.rs +6 -3
  6. data/ext/candle/src/llm/gemma.rs +21 -79
  7. data/ext/candle/src/llm/generation_config.rs +3 -0
  8. data/ext/candle/src/llm/llama.rs +21 -79
  9. data/ext/candle/src/llm/mistral.rs +21 -89
  10. data/ext/candle/src/llm/mod.rs +3 -33
  11. data/ext/candle/src/llm/quantized_gguf.rs +501 -0
  12. data/ext/candle/src/llm/text_generation.rs +0 -4
  13. data/ext/candle/src/ner.rs +423 -0
  14. data/ext/candle/src/reranker.rs +24 -21
  15. data/ext/candle/src/ruby/device.rs +6 -6
  16. data/ext/candle/src/ruby/dtype.rs +4 -4
  17. data/ext/candle/src/ruby/embedding_model.rs +36 -34
  18. data/ext/candle/src/ruby/llm.rs +110 -49
  19. data/ext/candle/src/ruby/mod.rs +1 -2
  20. data/ext/candle/src/ruby/tensor.rs +66 -66
  21. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  22. data/ext/candle/src/ruby/utils.rs +6 -24
  23. data/ext/candle/src/tokenizer/loader.rs +108 -0
  24. data/ext/candle/src/tokenizer/mod.rs +103 -0
  25. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  26. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  27. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  28. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  29. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  30. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  31. data/lib/candle/build_info.rb +2 -0
  32. data/lib/candle/device_utils.rb +2 -0
  33. data/lib/candle/llm.rb +91 -2
  34. data/lib/candle/ner.rb +345 -0
  35. data/lib/candle/reranker.rb +1 -1
  36. data/lib/candle/tensor.rb +2 -0
  37. data/lib/candle/tokenizer.rb +139 -0
  38. data/lib/candle/version.rb +4 -2
  39. data/lib/candle.rb +2 -0
  40. metadata +127 -3
  41. data/ext/candle/src/ruby/qtensor.rs +0 -69

data/ext/candle/src/llm/gemma.rs

@@ -21,6 +21,11 @@ impl Gemma {
         self.model.clear_kv_cache();
     }
 
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
     /// Load a Gemma model from HuggingFace Hub
     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
         let api = Api::new()
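
The new `tokenizer()` accessor exposes the model's `TokenizerWrapper` directly (the same accessor is added to Llama and Mistral below). A minimal caller-side sketch, assuming crate-internal code where `Gemma` is in scope; `prompt_token_count` is a hypothetical helper, not part of this release:

    use candle_core::Result as CandleResult;

    // Hypothetical helper: inspect how many tokens a prompt uses via the new accessor.
    // encode(text, add_special_tokens) is the TokenizerWrapper method used throughout this diff.
    fn prompt_token_count(model: &Gemma, prompt: &str) -> CandleResult<usize> {
        let ids: Vec<u32> = model.tokenizer().encode(prompt, true)?;
        Ok(ids.len())
    }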
@@ -169,77 +174,14 @@ impl Gemma {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-                let token_text = self.tokenizer.token_to_piece(next_token)?;
-                cb(&token_text);
-            }
-
-            // Check stop conditions
-            if text_gen.should_stop(next_token, config.max_length) {
-                break;
-            }
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos)?;
-
-            let logits = logits.squeeze(0)?;
-            let logits = if logits.dims().len() == 2 {
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                logits
-            };
-
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -249,12 +191,7 @@ impl Gemma {
             }
 
             // Check stop sequences
-            let generated_text = if callback.is_some() {
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
@@ -312,7 +249,12 @@ impl TextGenerator for Gemma {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-        self.tokenizer.decode(&output_tokens, true)
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -322,7 +264,7 @@ impl TextGenerator for Gemma {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 

data/ext/candle/src/llm/generation_config.rs

@@ -21,6 +21,8 @@ pub struct GenerationConfig {
     pub stop_sequences: Vec<String>,
     /// Whether to return the prompt in the output
     pub include_prompt: bool,
+    /// Whether to show raw tokens during generation (for debugging)
+    pub debug_tokens: bool,
 }
 
 /// Generate a random seed based on current time

@@ -43,6 +45,7 @@ impl Default for GenerationConfig {
             seed: random_seed(),
             stop_sequences: vec![],
             include_prompt: false,
+            debug_tokens: false,
         }
     }
 }
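
A minimal sketch of enabling the new flag from Rust; all other fields are filled in by the `Default` impl above, and the stop sequence shown is an illustrative value only:

    // With debug_tokens = true, streaming callbacks receive "[token_id:piece]" strings
    // and generate() returns format_tokens_with_debug() output instead of plain decoded text.
    let config = GenerationConfig {
        debug_tokens: true,
        stop_sequences: vec!["</s>".to_string()], // illustrative, not prescribed by this diff
        ..Default::default()
    };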

data/ext/candle/src/llm/llama.rs

@@ -28,6 +28,11 @@ impl Llama {
         }
     }
 
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
     /// Load a Llama model from HuggingFace Hub
     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
         let api = Api::new()
@@ -205,77 +210,14 @@ impl Llama {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-                let token_text = self.tokenizer.token_to_piece(next_token)?;
-                cb(&token_text);
-            }
-
-            // Check stop conditions
-            if text_gen.should_stop(next_token, config.max_length) {
-                break;
-            }
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
-
-            let logits = logits.squeeze(0)?;
-            let logits = if logits.dims().len() == 2 {
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                logits
-            };
-
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -285,12 +227,7 @@ impl Llama {
             }
 
             // Check stop sequences
-            let generated_text = if callback.is_some() {
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
@@ -374,7 +311,12 @@ impl TextGenerator for Llama {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-        self.tokenizer.decode(&output_tokens, true)
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -384,7 +326,7 @@ impl TextGenerator for Llama {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 

data/ext/candle/src/llm/mistral.rs

@@ -21,6 +21,11 @@ impl Mistral {
         self.model.clear_kv_cache();
     }
 
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
     /// Load a Mistral model from HuggingFace Hub
     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
         let api = Api::new()
@@ -180,87 +185,14 @@ impl Mistral {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-                let token_text = self.tokenizer.token_to_piece(next_token)?;
-                cb(&token_text);
-            }
-
-            // Check stop conditions
-            if text_gen.should_stop(next_token, config.max_length) {
-                break;
-            }
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-
-        // For incremental decoding
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            // Ensure input tensor is contiguous for Metal backend
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos)?;
-
-            // The model returns logits of shape [batch_size, seq_len, vocab_size]
-            // We need to get the logits for the last token only
-            let logits = logits.squeeze(0)?; // Remove batch dimension
-            let logits = if logits.dims().len() == 2 {
-                // If we still have [seq_len, vocab_size], take the last token
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                // Already [vocab_size]
-                logits
-            };
-
-            // Convert to F32 for sampling if needed
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                // Decode all generated tokens so far
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                // Only emit the new text since last callback
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -270,12 +202,7 @@ impl Mistral {
             }
 
             // Check stop sequences
-            let generated_text = if callback.is_some() {
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
@@ -297,7 +224,12 @@ impl TextGenerator for Mistral {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-        self.tokenizer.decode(&output_tokens, true)
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -307,7 +239,7 @@ impl TextGenerator for Mistral {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 

data/ext/candle/src/llm/mod.rs

@@ -1,14 +1,16 @@
 use candle_core::{Device, Result as CandleResult};
-use tokenizers::Tokenizer;
 
 pub mod mistral;
 pub mod llama;
 pub mod gemma;
 pub mod generation_config;
 pub mod text_generation;
+pub mod quantized_gguf;
 
 pub use generation_config::GenerationConfig;
 pub use text_generation::TextGeneration;
+pub use quantized_gguf::QuantizedGGUF;
+pub use crate::tokenizer::TokenizerWrapper;
 
 /// Trait for text generation models
 pub trait TextGenerator: Send + Sync {

@@ -35,36 +37,4 @@ pub trait TextGenerator: Send + Sync {
 
     /// Clear any cached state (like KV cache)
     fn clear_cache(&mut self);
-}
-
-/// Common structure for managing tokenizer
-#[derive(Debug)]
-pub struct TokenizerWrapper {
-    tokenizer: Tokenizer,
-}
-
-impl TokenizerWrapper {
-    pub fn new(tokenizer: Tokenizer) -> Self {
-        Self { tokenizer }
-    }
-
-    pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
-        let encoding = self.tokenizer
-            .encode(text, add_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
-        Ok(encoding.get_ids().to_vec())
-    }
-
-    pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
-        self.tokenizer
-            .decode(tokens, skip_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
-    }
-
-    pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
-        self.tokenizer
-            .id_to_token(token)
-            .map(|s| s.to_string())
-            .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
-    }
 }
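
`TokenizerWrapper` now lives in the new `crate::tokenizer` module (see `tokenizer/mod.rs` and `tokenizer/loader.rs` in the file list) and is re-exported here, so existing `crate::llm::TokenizerWrapper` imports keep compiling. A crate-internal sketch of that equivalence, assuming no other changes to the type's visibility:

    use crate::tokenizer::TokenizerWrapper;          // new home of the wrapper
    use crate::llm::TokenizerWrapper as Reexported;  // the `pub use` added above

    // Compile-time check only: this type-checks because both paths name the same type.
    fn _same_type(t: &TokenizerWrapper) -> &Reexported {
        t
    }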