red-candle 1.0.0.pre.6 → 1.0.0.pre.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
  use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
  use std::cell::RefCell;

- use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma};
+ use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
  use crate::ruby::{Result as RbResult, Device as RbDevice};

  // Use an enum to handle different model types instead of trait objects
@@ -10,6 +10,7 @@ enum ModelType {
      Mistral(RustMistral),
      Llama(RustLlama),
      Gemma(RustGemma),
+     QuantizedGGUF(RustQuantizedGGUF),
  }

  impl ModelType {
@@ -18,6 +19,7 @@ impl ModelType {
              ModelType::Mistral(m) => m.generate(prompt, config),
              ModelType::Llama(m) => m.generate(prompt, config),
              ModelType::Gemma(m) => m.generate(prompt, config),
+             ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
          }
      }

@@ -31,15 +33,7 @@ impl ModelType {
              ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
              ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
              ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
-         }
-     }
-
-     #[allow(dead_code)]
-     fn model_name(&self) -> &str {
-         match self {
-             ModelType::Mistral(m) => m.model_name(),
-             ModelType::Llama(m) => m.model_name(),
-             ModelType::Gemma(m) => m.model_name(),
+             ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
          }
      }

@@ -48,6 +42,7 @@ impl ModelType {
              ModelType::Mistral(m) => m.clear_cache(),
              ModelType::Llama(m) => m.clear_cache(),
              ModelType::Gemma(m) => m.clear_cache(),
+             ModelType::QuantizedGGUF(m) => m.clear_cache(),
          }
      }

@@ -72,6 +67,7 @@ impl ModelType {
              },
              ModelType::Llama(m) => m.apply_chat_template(messages),
              ModelType::Gemma(m) => m.apply_chat_template(messages),
+             ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
          }
      }
  }
@@ -144,6 +140,12 @@ impl GenerationConfig {
              }
          }

+         if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+             if let Ok(v) = TryConvert::try_convert(value) {
+                 config.debug_tokens = v;
+             }
+         }
+
          Ok(Self { inner: config })
      }

@@ -185,6 +187,10 @@ impl GenerationConfig {
      pub fn include_prompt(&self) -> bool {
          self.inner.include_prompt
      }
+
+     pub fn debug_tokens(&self) -> bool {
+         self.inner.debug_tokens
+     }
  }

  #[derive(Clone, Debug)]
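Note: combined with the kwargs handling above, debug_tokens becomes an ordinary GenerationConfig option on the Ruby side. A minimal usage sketch, assuming GenerationConfig.new accepts keyword arguments as the kwargs parsing suggests (how the config is then threaded into generate is outside this diff):

    # Sketch only: constructor kwargs inferred from the kwargs.get calls above.
    config = Candle::GenerationConfig.new(debug_tokens: true)
    config.debug_tokens  # => true, via the reader registered in init_llm below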
@@ -206,31 +212,51 @@ impl LLM {
          let rt = tokio::runtime::Runtime::new()
              .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;

-         // Determine model type from ID and load appropriately
+         // Determine model type from ID and whether it's quantized
          let model_lower = model_id.to_lowercase();
-         let model = if model_lower.contains("mistral") {
-             let mistral = rt.block_on(async {
-                 RustMistral::from_pretrained(&model_id, candle_device).await
-             })
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-             ModelType::Mistral(mistral)
-         } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
-             let llama = rt.block_on(async {
-                 RustLlama::from_pretrained(&model_id, candle_device).await
-             })
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-             ModelType::Llama(llama)
-         } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
-             let gemma = rt.block_on(async {
-                 RustGemma::from_pretrained(&model_id, candle_device).await
+         let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+         let model = if is_quantized {
+             // Extract tokenizer source if provided in model_id
+             let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                 let (id, _tok) = model_id.split_at(pos);
+                 (id.to_string(), Some(&model_id[pos+2..]))
+             } else {
+                 (model_id.clone(), None)
+             };
+
+             // Use unified GGUF loader for all quantized models
+             let gguf_model = rt.block_on(async {
+                 RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
              })
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-             ModelType::Gemma(gemma)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+             ModelType::QuantizedGGUF(gguf_model)
          } else {
-             return Err(Error::new(
-                 magnus::exception::runtime_error(),
-                 format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
-             ));
+             // Load non-quantized models
+             if model_lower.contains("mistral") {
+                 let mistral = rt.block_on(async {
+                     RustMistral::from_pretrained(&model_id, candle_device).await
+                 })
+                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                 ModelType::Mistral(mistral)
+             } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                 let llama = rt.block_on(async {
+                     RustLlama::from_pretrained(&model_id, candle_device).await
+                 })
+                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                 ModelType::Llama(llama)
+             } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                 let gemma = rt.block_on(async {
+                     RustGemma::from_pretrained(&model_id, candle_device).await
+                 })
+                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                 ModelType::Gemma(gemma)
+             } else {
+                 return Err(Error::new(
+                     magnus::exception::runtime_error(),
+                     format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                 ));
+             }
          };

          Ok(Self {
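Note: a model id is routed to the quantized loader when its lowercased form contains "gguf", "-q4", "-q5", or "-q8", and anything after a literal "@@" is split off and forwarded to RustQuantizedGGUF::from_pretrained as the tokenizer source. The Ruby wrapper in data/lib/candle/llm.rb (see the hunks below) assembles that string; a sketch of the internal format, with an illustrative GGUF filename rather than one taken from this diff:

    # "<repo>[@<gguf_file>][@@<tokenizer_repo>]" -- the single-"@" gguf_file part is
    # presumably parsed inside quantized_gguf.rs, which is not shown in this diff.
    "TheBloke/Llama-2-7B-Chat-GGUF@llama-2-7b-chat.Q4_K_M.gguf@@meta-llama/Llama-2-7b-chat-hf"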
@@ -246,7 +272,10 @@ impl LLM {
              .map(|c| c.inner.clone())
              .unwrap_or_default();

-         let model = self.model.lock().unwrap();
+         let model = match self.model.lock() {
+             Ok(guard) => guard,
+             Err(poisoned) => poisoned.into_inner(),
+         };
          let mut model_ref = model.borrow_mut();

          model_ref.generate(&prompt, &config)
@@ -266,7 +295,10 @@ impl LLM {
          }
          let block = block.unwrap();

-         let model = self.model.lock().unwrap();
+         let model = match self.model.lock() {
+             Ok(guard) => guard,
+             Err(poisoned) => poisoned.into_inner(),
+         };
          let mut model_ref = model.borrow_mut();

          let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -289,7 +321,14 @@ impl LLM {

      /// Clear the model's cache (e.g., KV cache for transformers)
      pub fn clear_cache(&self) -> RbResult<()> {
-         let model = self.model.lock().unwrap();
+         let model = match self.model.lock() {
+             Ok(guard) => guard,
+             Err(poisoned) => {
+                 // If the mutex is poisoned, we can still recover the data
+                 // This happens when another thread panicked while holding the lock
+                 poisoned.into_inner()
+             }
+         };
          let mut model_ref = model.borrow_mut();
          model_ref.clear_cache();
          Ok(())
@@ -323,7 +362,10 @@ impl LLM {
              })
              .collect();

-         let model = self.model.lock().unwrap();
+         let model = match self.model.lock() {
+             Ok(guard) => guard,
+             Err(poisoned) => poisoned.into_inner(),
+         };
          let model_ref = model.borrow();

          model_ref.apply_chat_template(&json_messages)
@@ -363,6 +405,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
      rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
      rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
      rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+     rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;

      let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
      rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
data/lib/candle/llm.rb CHANGED
@@ -1,5 +1,65 @@
  module Candle
    class LLM
+     # Tokenizer registry for automatic detection
+     TOKENIZER_REGISTRY = {
+       # Exact model matches
+       "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
+       "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
+       "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
+       "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+
+       # Pattern-based fallbacks (evaluated in order)
+       :patterns => [
+         # Mistral models
+         [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
+         [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
+         [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
+
+         # Llama models
+         [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
+         [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
+         [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
+         [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
+         [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
+         [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
+
+         # Gemma models
+         [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
+         [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
+         [/gemma.*?7b/i, "google/gemma-7b"],
+         [/gemma.*?2b/i, "google/gemma-2b"]
+       ]
+     }
+
+     # Allow users to register custom tokenizer mappings
+     def self.register_tokenizer(model_pattern, tokenizer_id)
+       if model_pattern.is_a?(String)
+         TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
+       elsif model_pattern.is_a?(Regexp)
+         TOKENIZER_REGISTRY[:patterns] ||= []
+         TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
+       else
+         raise ArgumentError, "model_pattern must be a String or Regexp"
+       end
+     end
+
+     # Guess the tokenizer for a model
+     def self.guess_tokenizer(model_id)
+       # Check exact matches first
+       return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
+
+       # Check patterns
+       if patterns = TOKENIZER_REGISTRY[:patterns]
+         patterns.each do |pattern, tokenizer|
+           return tokenizer if model_id.match?(pattern)
+         end
+       end
+
+       # Default: try removing common GGUF suffixes
+       base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
+       base_model
+     end
+
      # Simple chat interface for instruction models
      def chat(messages, **options)
        prompt = apply_chat_template(messages)
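Note: the registry gives exact repo matches priority over the regex fallbacks, and register_tokenizer prepends custom regex patterns so they are consulted ahead of the built-ins. For example, based on the entries above (the second repo name is illustrative):

    Candle::LLM.guess_tokenizer("TheBloke/Llama-2-7B-Chat-GGUF")
    # => "meta-llama/Llama-2-7b-chat-hf"        (exact-match entry)
    Candle::LLM.guess_tokenizer("SomeOrg/TinyLlama-1.1B-Chat-GGUF")
    # => "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   (pattern fallback)

    # Custom mappings (hypothetical names) win over the built-in patterns:
    Candle::LLM.register_tokenizer(/my-finetune/i, "my-org/my-finetune-tokenizer")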
@@ -28,8 +88,37 @@ module Candle
        end
      end

-     def self.from_pretrained(model_id, device: Candle::Device.cpu)
-       _from_pretrained(model_id, device)
+     def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+       model_str = if gguf_file
+         "#{model_id}@#{gguf_file}"
+       else
+         model_id
+       end
+
+       # Handle GGUF models that need tokenizer
+       if model_str.downcase.include?("gguf") && tokenizer.nil?
+         # Try to load without tokenizer first
+         begin
+           _from_pretrained(model_str, device)
+         rescue => e
+           if e.message.include?("No tokenizer found")
+             # Auto-detect tokenizer
+             detected_tokenizer = guess_tokenizer(model_id)
+             warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
+             model_str = "#{model_str}@@#{detected_tokenizer}"
+             _from_pretrained(model_str, device)
+           else
+             raise e
+           end
+         end
+       elsif tokenizer
+         # User specified tokenizer
+         model_str = "#{model_str}@@#{tokenizer}"
+         _from_pretrained(model_str, device)
+       else
+         # Non-GGUF model or GGUF with embedded tokenizer
+         _from_pretrained(model_str, device)
+       end
      end

      private
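Note: with the expanded signature, GGUF loading is normally a one-liner; the tokenizer is auto-detected via guess_tokenizer only when the native loader raises "No tokenizer found", and passing tokenizer: explicitly skips detection altogether. A usage sketch built from the new keyword arguments and the registry entries above:

    llm = Candle::LLM.from_pretrained(
      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
      device: Candle::Device.cpu,
      tokenizer: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )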
@@ -1,3 +1,3 @@
  module Candle
-   VERSION = "1.0.0.pre.6"
+   VERSION = "1.0.0.pre.7"
  end
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: red-candle
  version: !ruby/object:Gem::Version
-   version: 1.0.0.pre.6
+   version: 1.0.0.pre.7
  platform: ruby
  authors:
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2025-07-10 00:00:00.000000000 Z
+ date: 2025-07-13 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rb_sys
@@ -52,6 +52,7 @@ files:
  - ext/candle/src/llm/llama.rs
  - ext/candle/src/llm/mistral.rs
  - ext/candle/src/llm/mod.rs
+ - ext/candle/src/llm/quantized_gguf.rs
  - ext/candle/src/llm/text_generation.rs
  - ext/candle/src/reranker.rs
  - ext/candle/src/ruby/device.rs