red-candle 1.0.0.pre.6 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +481 -4
  4. data/Rakefile +1 -3
  5. data/ext/candle/src/lib.rs +6 -3
  6. data/ext/candle/src/llm/gemma.rs +21 -79
  7. data/ext/candle/src/llm/generation_config.rs +3 -0
  8. data/ext/candle/src/llm/llama.rs +21 -79
  9. data/ext/candle/src/llm/mistral.rs +21 -89
  10. data/ext/candle/src/llm/mod.rs +3 -33
  11. data/ext/candle/src/llm/quantized_gguf.rs +501 -0
  12. data/ext/candle/src/llm/text_generation.rs +0 -4
  13. data/ext/candle/src/ner.rs +423 -0
  14. data/ext/candle/src/reranker.rs +24 -21
  15. data/ext/candle/src/ruby/device.rs +6 -6
  16. data/ext/candle/src/ruby/dtype.rs +4 -4
  17. data/ext/candle/src/ruby/embedding_model.rs +36 -34
  18. data/ext/candle/src/ruby/llm.rs +110 -49
  19. data/ext/candle/src/ruby/mod.rs +1 -2
  20. data/ext/candle/src/ruby/tensor.rs +66 -66
  21. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  22. data/ext/candle/src/ruby/utils.rs +6 -24
  23. data/ext/candle/src/tokenizer/loader.rs +108 -0
  24. data/ext/candle/src/tokenizer/mod.rs +103 -0
  25. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  26. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  27. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  28. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  29. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  30. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  31. data/lib/candle/build_info.rb +2 -0
  32. data/lib/candle/device_utils.rb +2 -0
  33. data/lib/candle/llm.rb +91 -2
  34. data/lib/candle/ner.rb +345 -0
  35. data/lib/candle/reranker.rb +1 -1
  36. data/lib/candle/tensor.rb +2 -0
  37. data/lib/candle/tokenizer.rb +139 -0
  38. data/lib/candle/version.rb +4 -2
  39. data/lib/candle.rb +2 -0
  40. metadata +127 -3
  41. data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/llm/quantized_gguf.rs
@@ -0,0 +1,501 @@
+use candle_core::{DType, Device, Result as CandleResult, Tensor};
+use candle_core::quantized::gguf_file;
+use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
+use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
+use hf_hub::api::tokio::{Api, ApiRepo};
+use tokenizers::Tokenizer;
+use std::io::Seek;
+
+use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+/// Unified GGUF model that can load any GGUF file and detect the architecture
+#[derive(Debug)]
+pub struct QuantizedGGUF {
+    model: ModelType,
+    tokenizer: TokenizerWrapper,
+    device: Device,
+    model_id: String,
+    eos_token_id: u32,
+    architecture: String,
+    _chat_template: Option<String>,
+}
+
+#[derive(Debug)]
+enum ModelType {
+    Llama(QuantizedLlamaModel),
+    Gemma(QuantizedGemmaModel),
+    // Mistral uses Llama loader due to tensor naming compatibility
+}
+
+impl QuantizedGGUF {
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
+    /// Load a quantized model from a GGUF file
+    pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
+        // Check if user specified an exact GGUF filename
+        let (actual_model_id, gguf_file) = if let Some(pos) = model_id.find('@') {
+            let (id, filename) = model_id.split_at(pos);
+            (id, Some(&filename[1..]))
+        } else {
+            (model_id, None)
+        };
+
+        let api = Api::new()
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+        let repo = api.model(actual_model_id.to_string());
+
+        // Download GGUF file
+        let gguf_filename = if let Some(filename) = gguf_file {
+            // User specified exact filename
+            repo.get(filename).await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download GGUF file '{}': {}", filename, e)))?
+                .to_string_lossy().to_string()
+        } else {
+            // Let Ruby handle the search, for now just try a common name
+            return Err(candle_core::Error::Msg(
+                "Please specify a GGUF filename using gguf_file parameter".to_string()
+            ));
+        };
+
+        // Read GGUF metadata to determine architecture
+        let mut file = std::fs::File::open(&gguf_filename)?;
+        let content = gguf_file::Content::read(&mut file)?;
+
+        // Detect architecture from metadata
+        let architecture = Self::detect_architecture(&content, actual_model_id)?;
+
+        // For Gemma 3 models, we might need to adjust the architecture
+        let architecture = if actual_model_id.contains("gemma-3") || actual_model_id.contains("gemma3") {
+            "gemma3".to_string()
+        } else {
+            architecture
+        };
+
+        // Download tokenizer - either from specified source or with fallback
+        let tokenizer_filename = if let Some(source) = tokenizer_source {
+            Self::download_tokenizer_from_source(&api, source).await?
+        } else {
+            Self::download_tokenizer(&api, &repo, actual_model_id, &architecture).await?
+        };
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        // Determine EOS token based on architecture and model
+        let eos_token_id = Self::determine_eos_token(&tokenizer, &architecture, actual_model_id);
+
+        // Load the appropriate model based on architecture
+        file.seek(std::io::SeekFrom::Start(0))?;
+        let content = gguf_file::Content::read(&mut file)?;
+
+        let model = match architecture.as_str() {
+            "llama" | "mistral" => {
+                // Both use the same GGUF format with llama.cpp tensor names
+                let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                ModelType::Llama(model)
+            }
+            "gemma" | "gemma2" | "gemma3" => {
+                // Try Gemma-specific loader first, fall back to Llama if it fails
+                match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
+                    Ok(model) => ModelType::Gemma(model),
+                    Err(e) if e.to_string().contains("gemma3.attention.head_count") => {
+                        // This might be an older Gemma GGUF that uses llama format
+                        // Note: Some Gemma GGUF files may not be compatible
+                        file.seek(std::io::SeekFrom::Start(0))?;
+                        let content = gguf_file::Content::read(&mut file)?;
+                        let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                        ModelType::Llama(model)
+                    }
+                    Err(e) => return Err(e),
+                }
+            }
+            _ => {
+                return Err(candle_core::Error::Msg(format!(
+                    "Unsupported architecture: {}. Supported: llama, mistral, gemma",
+                    architecture
+                )));
+            }
+        };
+
+        // Detect chat template (for now, use defaults based on architecture)
+        let chat_template = Self::detect_chat_template(&tokenizer, &architecture, actual_model_id);
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id: actual_model_id.to_string(),
+            eos_token_id,
+            architecture: architecture.clone(),
+            _chat_template: chat_template,
+        })
+    }
+
+    /// Detect architecture from GGUF metadata or model name
+    fn detect_architecture(content: &gguf_file::Content, model_id: &str) -> CandleResult<String> {
+        // First try to get from metadata
+        if let Some(gguf_file::Value::String(arch)) = content.metadata.get("general.architecture") {
+            return Ok(arch.clone());
+        }
+
+        // Fallback to model name detection
+        let model_lower = model_id.to_lowercase();
+        if model_lower.contains("llama") || model_lower.contains("tinyllama") {
+            Ok("llama".to_string())
+        } else if model_lower.contains("mistral") {
+            Ok("mistral".to_string())
+        } else if model_lower.contains("gemma") {
+            Ok("gemma".to_string())
+        } else {
+            Err(candle_core::Error::Msg(
+                "Could not determine model architecture from metadata or name".to_string()
+            ))
+        }
+    }
+
+    /// Download tokenizer from a specific source
+    async fn download_tokenizer_from_source(
+        api: &Api,
+        source: &str
+    ) -> CandleResult<std::path::PathBuf> {
+        // Check if it's a local file path
+        if source.ends_with(".json") && std::path::Path::new(source).exists() {
+            return Ok(std::path::PathBuf::from(source));
+        }
+
+        // Otherwise treat it as a HuggingFace repo
+        let repo = api.model(source.to_string());
+
+        // Try tokenizer.json first
+        if let Ok(path) = repo.get("tokenizer.json").await {
+            return Ok(path);
+        }
+
+        // Try tokenizer.model (for models that use sentencepiece)
+        if let Ok(path) = repo.get("tokenizer.model").await {
+            return Ok(path);
+        }
+
+        Err(candle_core::Error::Msg(format!(
+            "Failed to find tokenizer in specified source: {}",
+            source
+        )))
+    }
+
+    /// Download tokenizer with architecture-specific fallbacks
+    async fn download_tokenizer(
+        _api: &Api,
+        repo: &ApiRepo,
+        model_id: &str,
+        _architecture: &str
+    ) -> CandleResult<std::path::PathBuf> {
+        // First try to get tokenizer.json from the GGUF repo
+        if let Ok(path) = repo.get("tokenizer.json").await {
+            return Ok(path);
+        }
+
+        // Try tokenizer.model (for models that use sentencepiece)
+        if let Ok(path) = repo.get("tokenizer.model").await {
+            return Ok(path);
+        }
+
+        // If no tokenizer found in GGUF repo, return error
+        // Ruby will handle the fallback logic
+        Err(candle_core::Error::Msg(format!(
+            "No tokenizer found in GGUF repository {}. Please specify a tokenizer source.",
+            model_id
+        )))
+    }
+
+    /// Determine EOS token based on architecture and model
+    fn determine_eos_token(tokenizer: &Tokenizer, architecture: &str, model_id: &str) -> u32 {
+        let vocab = tokenizer.get_vocab(true);
+
+        match architecture {
+            "llama" | "mistral" => {
+                // Check if it's Llama 3
+                if model_id.contains("Llama-3") || model_id.contains("llama-3") {
+                    vocab.get("<|eot_id|>")
+                        .or_else(|| vocab.get("<|end_of_text|>"))
+                        .copied()
+                        .unwrap_or(128009)
+                } else {
+                    // Llama 2 and Mistral
+                    vocab.get("</s>")
+                        .copied()
+                        .unwrap_or(2)
+                }
+            }
+            "gemma" => {
+                vocab.get("<eos>")
+                    .or_else(|| vocab.get("<end_of_turn>"))
+                    .copied()
+                    .unwrap_or(1)
+            }
+            _ => 2, // Default
+        }
+    }
+
+    /// Detect chat template based on model
+    fn detect_chat_template(_tokenizer: &Tokenizer, _architecture: &str, _model_id: &str) -> Option<String> {
+        // For now, return None and handle templates in apply_chat_template
+        // In the future, this could read from tokenizer config
+        None
+    }
+
+    /// Apply chat template based on detected architecture
+    pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        // Check model name since Mistral GGUF reports as llama architecture
+        let model_lower = self.model_id.to_lowercase();
+
+        if model_lower.contains("mistral") {
+            self.apply_mistral_template(messages)
+        } else if model_lower.contains("gemma") {
+            // Always use Gemma template for Gemma models, regardless of loader used
+            self.apply_gemma_template(messages)
+        } else {
+            match self.architecture.as_str() {
+                "llama" => {
+                    if self.model_id.contains("Llama-3") || self.model_id.contains("llama-3") {
+                        self.apply_llama3_template(messages)
+                    } else {
+                        self.apply_llama2_template(messages)
+                    }
+                }
+                "gemma" => {
+                    self.apply_gemma_template(messages)
+                }
+                _ => Ok(self.apply_generic_template(messages))
+            }
+        }
+    }
+
+    fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+        let mut system_message = String::new();
+
+        for (i, message) in messages.iter().enumerate() {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "system" => {
+                    system_message = content.to_string();
+                }
+                "user" => {
+                    if i == 1 || (i == 0 && system_message.is_empty()) {
+                        if !system_message.is_empty() {
+                            prompt.push_str(&format!("<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
+                        } else {
+                            prompt.push_str(&format!("<s>[INST] {} [/INST]", content));
+                        }
+                    } else {
+                        prompt.push_str(&format!(" [INST] {} [/INST]", content));
+                    }
+                }
+                "assistant" => {
+                    prompt.push_str(&format!(" {} </s>", content));
+                }
+                _ => {}
+            }
+        }
+
+        Ok(prompt)
+    }
+
+    fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+        prompt.push_str("<|begin_of_text|>");
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+            prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
+        }
+
+        prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
+        Ok(prompt)
+    }
+
+    fn apply_mistral_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "user" => prompt.push_str(&format!("[INST] {} [/INST]", content)),
+                "assistant" => prompt.push_str(&format!(" {}</s>", content)),
+                "system" => prompt.push_str(&format!("[INST] {} [/INST]\n", content)),
+                _ => {}
+            }
+        }
+
+        Ok(prompt)
+    }
+
+    fn apply_gemma_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "system" => {
+                    prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
+                }
+                "user" => {
+                    if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
+                        prompt.push_str("<start_of_turn>user\n");
+                    }
+                    prompt.push_str(&format!("{}<end_of_turn>\n", content));
+                }
+                "assistant" | "model" => {
+                    prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
+                }
+                _ => {}
+            }
+        }
+
+        prompt.push_str("<start_of_turn>model\n");
+        Ok(prompt)
+    }
+
+    fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+            prompt.push_str(&format!("{}: {}\n", role, content));
+        }
+
+        prompt.push_str("assistant: ");
+        prompt
+    }
+
+    /// Clear the KV cache between generations
+    pub fn clear_kv_cache(&mut self) {
+        // Quantized models manage cache internally
+    }
+
+    fn generate_tokens(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let input = input.contiguous()?;
+
+            let logits = match &mut self.model {
+                ModelType::Llama(model) => model.forward(&input, start_pos)?,
+                ModelType::Gemma(model) => model.forward(&input, start_pos)?,
+            };
+
+            let logits = logits.squeeze(0)?;
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback
+            if let Some(ref mut cb) = callback {
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
+                }
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+}
+
+impl TextGenerator for QuantizedGGUF {
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
+    }
+
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        mut callback: impl FnMut(&str),
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_id
+    }
+
+    fn device(&self) -> &Device {
+        &self.device
+    }
+
+    fn clear_cache(&mut self) {
+        // Quantized models manage cache internally
+    }
+}
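
Editorial note (not part of the diff): the sketch below shows how the public methods added above fit together on the Rust side, based only on the signatures in this file. It assumes the module is exposed as crate::llm::quantized_gguf, that an async runtime (e.g. tokio) drives the future, and it uses a placeholder "repo@file.gguf" id; it is an illustration, not the gem's actual call site.

use candle_core::{Device, Result as CandleResult};

use crate::llm::quantized_gguf::QuantizedGGUF; // assumed module path
use crate::llm::{GenerationConfig, TextGenerator};

// Minimal end-to-end sketch: load a GGUF model, build a chat prompt, generate.
async fn run_quantized(config: &GenerationConfig) -> CandleResult<String> {
    // "repo@filename" selects an exact GGUF file; without it the loader returns
    // the "Please specify a GGUF filename" error seen above.
    let mut model = QuantizedGGUF::from_pretrained(
        "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF@tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", // placeholder
        Device::Cpu,
        None, // fall back to tokenizer.json / tokenizer.model from the same repo
    )
    .await?;

    // Architecture-aware chat templating (Llama 2/3, Mistral, Gemma, or generic).
    let messages = vec![serde_json::json!({ "role": "user", "content": "Hello!" })];
    let prompt = model.apply_chat_template(&messages)?;

    // Sampling behaviour (max_length, repetition penalty, stop sequences, ...)
    // comes from the caller-supplied GenerationConfig.
    model.generate(&prompt, config)
}
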
data/ext/candle/src/llm/text_generation.rs
@@ -1,13 +1,10 @@
 use candle_core::{Result as CandleResult, Tensor};
 use candle_transformers::generation::LogitsProcessor;
-use rand::{rngs::StdRng, SeedableRng};
 
 use super::GenerationConfig;
 
 /// Helper struct for text generation process
 pub struct TextGeneration {
-    #[allow(dead_code)]
-    rng: StdRng,
     logits_processor: LogitsProcessor,
     tokens: Vec<u32>,
     eos_token_id: Option<u32>,
@@ -25,7 +22,6 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
 
         Self {
-            rng: StdRng::seed_from_u64(seed),
             logits_processor,
             tokens: Vec::new(),
             eos_token_id: None,