red-candle 1.0.0.pre.5 → 1.0.0.pre.7

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;
 
-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
 use crate::ruby::{Result as RbResult, Device as RbDevice};
 
 // Use an enum to handle different model types instead of trait objects
@@ -9,6 +9,8 @@ use crate::ruby::{Result as RbResult, Device as RbDevice};
 enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
+    Gemma(RustGemma),
+    QuantizedGGUF(RustQuantizedGGUF),
 }
 
 impl ModelType {
@@ -16,6 +18,8 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
+            ModelType::Gemma(m) => m.generate(prompt, config),
+            ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
         }
     }
 
@@ -28,14 +32,8 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
-        }
-    }
-
-    #[allow(dead_code)]
-    fn model_name(&self) -> &str {
-        match self {
-            ModelType::Mistral(m) => m.model_name(),
-            ModelType::Llama(m) => m.model_name(),
+            ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
+            ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
         }
     }
 
@@ -43,6 +41,8 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
+            ModelType::Gemma(m) => m.clear_cache(),
+            ModelType::QuantizedGGUF(m) => m.clear_cache(),
         }
     }
 
@@ -66,6 +66,8 @@ impl ModelType {
                 Ok(prompt)
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
+            ModelType::Gemma(m) => m.apply_chat_template(messages),
+            ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -138,6 +140,12 @@ impl GenerationConfig {
             }
         }
 
+        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                config.debug_tokens = v;
+            }
+        }
+
         Ok(Self { inner: config })
     }
 
@@ -179,6 +187,10 @@ impl GenerationConfig {
     pub fn include_prompt(&self) -> bool {
         self.inner.include_prompt
     }
+
+    pub fn debug_tokens(&self) -> bool {
+        self.inner.debug_tokens
+    }
 }
 
 #[derive(Clone, Debug)]
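The two hunks above add a debug_tokens option to GenerationConfig: the flag is read from the Ruby kwargs hash and exposed through a debug_tokens reader (registered on the Ruby class further down in init_llm). A minimal Ruby sketch, assuming GenerationConfig.new accepts keyword arguments the way the kwargs parsing suggests; the constructor itself is not part of this diff:

    # Hypothetical usage; only the "debug_tokens" kwarg and its reader appear in this diff.
    config = Candle::GenerationConfig.new(debug_tokens: true)
    config.debug_tokens  # => true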
@@ -200,25 +212,51 @@ impl LLM {
         let rt = tokio::runtime::Runtime::new()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
 
-        // Determine model type from ID and load appropriately
+        // Determine model type from ID and whether it's quantized
         let model_lower = model_id.to_lowercase();
-        let model = if model_lower.contains("mistral") {
-            let mistral = rt.block_on(async {
-                RustMistral::from_pretrained(&model_id, candle_device).await
-            })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::Mistral(mistral)
-        } else if model_lower.contains("llama") || model_lower.contains("meta-llama") {
-            let llama = rt.block_on(async {
-                RustLlama::from_pretrained(&model_id, candle_device).await
+        let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+        let model = if is_quantized {
+            // Extract tokenizer source if provided in model_id
+            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                let (id, _tok) = model_id.split_at(pos);
+                (id.to_string(), Some(&model_id[pos+2..]))
+            } else {
+                (model_id.clone(), None)
+            };
+
+            // Use unified GGUF loader for all quantized models
+            let gguf_model = rt.block_on(async {
+                RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
             })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::Llama(llama)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+            ModelType::QuantizedGGUF(gguf_model)
         } else {
-            return Err(Error::new(
-                magnus::exception::runtime_error(),
-                format!("Unsupported model type: {}. Currently only Mistral and Llama models are supported.", model_id),
-            ));
+            // Load non-quantized models
+            if model_lower.contains("mistral") {
+                let mistral = rt.block_on(async {
+                    RustMistral::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Mistral(mistral)
+            } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                let llama = rt.block_on(async {
+                    RustLlama::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Llama(llama)
+            } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                let gemma = rt.block_on(async {
+                    RustGemma::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Gemma(gemma)
+            } else {
+                return Err(Error::new(
+                    magnus::exception::runtime_error(),
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                ));
+            }
         };
 
         Ok(Self {
@@ -234,7 +272,10 @@ impl LLM {
             .map(|c| c.inner.clone())
             .unwrap_or_default();
 
-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         model_ref.generate(&prompt, &config)
@@ -254,7 +295,10 @@ impl LLM {
         }
         let block = block.unwrap();
 
-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -277,7 +321,14 @@ impl LLM {
 
     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> RbResult<()> {
-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => {
+                // If the mutex is poisoned, we can still recover the data
+                // This happens when another thread panicked while holding the lock
+                poisoned.into_inner()
+            }
+        };
         let mut model_ref = model.borrow_mut();
         model_ref.clear_cache();
         Ok(())
@@ -311,7 +362,10 @@ impl LLM {
             })
             .collect();
 
-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let model_ref = model.borrow();
 
         model_ref.apply_chat_template(&json_messages)
@@ -351,6 +405,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+    rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
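Taken together, the loader changes above mean the native _from_pretrained now receives a single model string and routes it: ids containing "gguf", "-q4", "-q5", or "-q8" go to the unified QuantizedGGUF loader, an optional tokenizer source can be appended after "@@", and Mistral, Llama/TinyLlama, and Gemma ids fall through to the full-precision loaders. A sketch of the strings the Ruby layer (next file) builds; the GGUF filename is illustrative:

    # Plain model id: routed by substring match ("mistral", "llama", "gemma").
    "google/gemma-2b"
    # "#{model_id}@#{gguf_file}": detected as quantized, loaded via QuantizedGGUF.
    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF@mistral-7b-instruct-v0.2.Q4_K_M.gguf"
    # "...@@tokenizer": everything after "@@" is passed along as tokenizer_source.
    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF@mistral-7b-instruct-v0.2.Q4_K_M.gguf@@mistralai/Mistral-7B-Instruct-v0.2"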
data/lib/candle/llm.rb CHANGED
@@ -1,5 +1,65 @@
 module Candle
   class LLM
+    # Tokenizer registry for automatic detection
+    TOKENIZER_REGISTRY = {
+      # Exact model matches
+      "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
+      "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
+      "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
+      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+
+      # Pattern-based fallbacks (evaluated in order)
+      :patterns => [
+        # Mistral models
+        [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
+        [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
+        [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
+
+        # Llama models
+        [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
+        [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
+        [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
+        [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
+        [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
+        [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
+
+        # Gemma models
+        [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
+        [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
+        [/gemma.*?7b/i, "google/gemma-7b"],
+        [/gemma.*?2b/i, "google/gemma-2b"]
+      ]
+    }
+
+    # Allow users to register custom tokenizer mappings
+    def self.register_tokenizer(model_pattern, tokenizer_id)
+      if model_pattern.is_a?(String)
+        TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
+      elsif model_pattern.is_a?(Regexp)
+        TOKENIZER_REGISTRY[:patterns] ||= []
+        TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
+      else
+        raise ArgumentError, "model_pattern must be a String or Regexp"
+      end
+    end
+
+    # Guess the tokenizer for a model
+    def self.guess_tokenizer(model_id)
+      # Check exact matches first
+      return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
+
+      # Check patterns
+      if patterns = TOKENIZER_REGISTRY[:patterns]
+        patterns.each do |pattern, tokenizer|
+          return tokenizer if model_id.match?(pattern)
+        end
+      end
+
+      # Default: try removing common GGUF suffixes
+      base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
+      base_model
+    end
+
     # Simple chat interface for instruction models
     def chat(messages, **options)
       prompt = apply_chat_template(messages)
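The registry above resolves a tokenizer in three steps: an exact repo-id match, then the ordered regex patterns (custom registrations are unshifted to the front), then a fallback that strips common GGUF suffixes from the id. A short Ruby sketch; the "my-org" ids are hypothetical:

    Candle::LLM.guess_tokenizer("TheBloke/Llama-2-7B-Chat-GGUF")
    # => "meta-llama/Llama-2-7b-chat-hf"       (exact match)
    Candle::LLM.guess_tokenizer("my-org/TinyLlama-1.1B-Chat-GGUF")
    # => "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  (pattern /tinyllama/i)
    Candle::LLM.register_tokenizer(/my-finetune/i, "mistralai/Mistral-7B-v0.1")
    Candle::LLM.guess_tokenizer("my-org/my-finetune-GGUF")
    # => "mistralai/Mistral-7B-v0.1"           (custom pattern, checked first)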
@@ -28,8 +88,37 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.cpu)
-      _from_pretrained(model_id, device)
+    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+      model_str = if gguf_file
+        "#{model_id}@#{gguf_file}"
+      else
+        model_id
+      end
+
+      # Handle GGUF models that need tokenizer
+      if model_str.downcase.include?("gguf") && tokenizer.nil?
+        # Try to load without tokenizer first
+        begin
+          _from_pretrained(model_str, device)
+        rescue => e
+          if e.message.include?("No tokenizer found")
+            # Auto-detect tokenizer
+            detected_tokenizer = guess_tokenizer(model_id)
+            warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
+            model_str = "#{model_str}@@#{detected_tokenizer}"
+            _from_pretrained(model_str, device)
+          else
+            raise e
+          end
+        end
+      elsif tokenizer
+        # User specified tokenizer
+        model_str = "#{model_str}@@#{tokenizer}"
+        _from_pretrained(model_str, device)
+      else
+        # Non-GGUF model or GGUF with embedded tokenizer
+        _from_pretrained(model_str, device)
+      end
     end
 
     private
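A usage sketch of the new from_pretrained keywords, assuming the gem is loaded via require "candle"; the repo ids come from the registry above and the GGUF filename is illustrative:

    require "candle"

    # Full-precision model: same call as before, device defaults to Candle::Device.cpu.
    llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Quantized GGUF repo: name the file; if the repo ships no tokenizer it is
    # auto-detected with a warning, or can be pinned explicitly as shown here.
    llm = Candle::LLM.from_pretrained(
      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
      gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
      tokenizer: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )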
@@ -1,3 +1,3 @@
 module Candle
-  VERSION = "1.0.0.pre.5"
+  VERSION = "1.0.0.pre.7"
 end
metadata CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.0.0.pre.5
+  version: 1.0.0.pre.7
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-07-09 00:00:00.000000000 Z
+date: 2025-07-13 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -47,10 +47,12 @@ files:
 - ext/candle/extconf.rb
 - ext/candle/rustfmt.toml
 - ext/candle/src/lib.rs
+- ext/candle/src/llm/gemma.rs
 - ext/candle/src/llm/generation_config.rs
 - ext/candle/src/llm/llama.rs
 - ext/candle/src/llm/mistral.rs
 - ext/candle/src/llm/mod.rs
+- ext/candle/src/llm/quantized_gguf.rs
 - ext/candle/src/llm/text_generation.rs
 - ext/candle/src/reranker.rs
 - ext/candle/src/ruby/device.rs