red-candle 1.0.0.pre.6 → 1.0.0.pre.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +159 -0
- data/Rakefile +1 -3
- data/ext/candle/src/llm/gemma.rs +16 -79
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +16 -79
- data/ext/candle/src/llm/mistral.rs +16 -89
- data/ext/candle/src/llm/mod.rs +58 -0
- data/ext/candle/src/llm/quantized_gguf.rs +496 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ruby/embedding_model.rs +0 -1
- data/ext/candle/src/ruby/llm.rs +79 -36
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/version.rb +1 -1
- metadata +3 -2
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,7 +1,7 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;

-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
 use crate::ruby::{Result as RbResult, Device as RbDevice};

 // Use an enum to handle different model types instead of trait objects
@@ -10,6 +10,7 @@ enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
     Gemma(RustGemma),
+    QuantizedGGUF(RustQuantizedGGUF),
 }

 impl ModelType {
@@ -18,6 +19,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
             ModelType::Gemma(m) => m.generate(prompt, config),
+            ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
         }
     }

@@ -31,15 +33,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
             ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
-
-    }
-
-    #[allow(dead_code)]
-    fn model_name(&self) -> &str {
-        match self {
-            ModelType::Mistral(m) => m.model_name(),
-            ModelType::Llama(m) => m.model_name(),
-            ModelType::Gemma(m) => m.model_name(),
+            ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
         }
     }

@@ -48,6 +42,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
             ModelType::Gemma(m) => m.clear_cache(),
+            ModelType::QuantizedGGUF(m) => m.clear_cache(),
         }
     }

@@ -72,6 +67,7 @@ impl ModelType {
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
             ModelType::Gemma(m) => m.apply_chat_template(messages),
+            ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -144,6 +140,12 @@ impl GenerationConfig {
             }
         }

+        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                config.debug_tokens = v;
+            }
+        }
+
         Ok(Self { inner: config })
     }

@@ -185,6 +187,10 @@ impl GenerationConfig {
     pub fn include_prompt(&self) -> bool {
         self.inner.include_prompt
     }
+
+    pub fn debug_tokens(&self) -> bool {
+        self.inner.debug_tokens
+    }
 }

 #[derive(Clone, Debug)]
@@ -206,31 +212,51 @@ impl LLM {
         let rt = tokio::runtime::Runtime::new()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;

-        // Determine model type from ID and
+        // Determine model type from ID and whether it's quantized
         let model_lower = model_id.to_lowercase();
-        let
-
-
-
-
-
-
-
-        }
-
-
-
-
-            RustGemma::from_pretrained(&model_id, candle_device).await
+        let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+        let model = if is_quantized {
+            // Extract tokenizer source if provided in model_id
+            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                let (id, _tok) = model_id.split_at(pos);
+                (id.to_string(), Some(&model_id[pos+2..]))
+            } else {
+                (model_id.clone(), None)
+            };
+
+            // Use unified GGUF loader for all quantized models
+            let gguf_model = rt.block_on(async {
+                RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
             })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-        ModelType::
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+            ModelType::QuantizedGGUF(gguf_model)
         } else {
-
-
-
-
+            // Load non-quantized models
+            if model_lower.contains("mistral") {
+                let mistral = rt.block_on(async {
+                    RustMistral::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Mistral(mistral)
+            } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                let llama = rt.block_on(async {
+                    RustLlama::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Llama(llama)
+            } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                let gemma = rt.block_on(async {
+                    RustGemma::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Gemma(gemma)
+            } else {
+                return Err(Error::new(
+                    magnus::exception::runtime_error(),
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                ));
+            }
         };

         Ok(Self {
@@ -246,7 +272,10 @@ impl LLM {
             .map(|c| c.inner.clone())
             .unwrap_or_default();

-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();

         model_ref.generate(&prompt, &config)
@@ -266,7 +295,10 @@ impl LLM {
         }
         let block = block.unwrap();

-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();

         let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -289,7 +321,14 @@ impl LLM {

     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> RbResult<()> {
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => {
+                // If the mutex is poisoned, we can still recover the data
+                // This happens when another thread panicked while holding the lock
+                poisoned.into_inner()
+            }
+        };
         let mut model_ref = model.borrow_mut();
         model_ref.clear_cache();
         Ok(())
@@ -323,7 +362,10 @@ impl LLM {
         })
         .collect();

-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let model_ref = model.borrow();

         model_ref.apply_chat_template(&json_messages)
@@ -363,6 +405,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+    rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;

     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
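The `debug_tokens` flag introduced above is parsed from keyword arguments in the native `GenerationConfig` constructor and exposed through a new `debug_tokens` reader registered in `init_llm`. A minimal Ruby sketch of how the option might be passed through; the other option names and the constructor shape are assumptions based on the existing API, not shown in this diff:

```ruby
# Sketch only: option names other than debug_tokens are assumed, not taken from this diff.
require "candle"

config = Candle::GenerationConfig.new(
  max_length: 64,      # assumed existing option
  temperature: 0.7,    # assumed existing option
  debug_tokens: true   # new in 1.0.0.pre.7
)

config.debug_tokens  # => true, via the reader added in init_llm
```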
data/lib/candle/llm.rb
CHANGED
@@ -1,5 +1,65 @@
 module Candle
   class LLM
+    # Tokenizer registry for automatic detection
+    TOKENIZER_REGISTRY = {
+      # Exact model matches
+      "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
+      "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
+      "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
+      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+
+      # Pattern-based fallbacks (evaluated in order)
+      :patterns => [
+        # Mistral models
+        [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
+        [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
+        [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
+
+        # Llama models
+        [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
+        [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
+        [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
+        [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
+        [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
+        [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
+
+        # Gemma models
+        [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
+        [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
+        [/gemma.*?7b/i, "google/gemma-7b"],
+        [/gemma.*?2b/i, "google/gemma-2b"]
+      ]
+    }
+
+    # Allow users to register custom tokenizer mappings
+    def self.register_tokenizer(model_pattern, tokenizer_id)
+      if model_pattern.is_a?(String)
+        TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
+      elsif model_pattern.is_a?(Regexp)
+        TOKENIZER_REGISTRY[:patterns] ||= []
+        TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
+      else
+        raise ArgumentError, "model_pattern must be a String or Regexp"
+      end
+    end
+
+    # Guess the tokenizer for a model
+    def self.guess_tokenizer(model_id)
+      # Check exact matches first
+      return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
+
+      # Check patterns
+      if patterns = TOKENIZER_REGISTRY[:patterns]
+        patterns.each do |pattern, tokenizer|
+          return tokenizer if model_id.match?(pattern)
+        end
+      end
+
+      # Default: try removing common GGUF suffixes
+      base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
+      base_model
+    end
+
     # Simple chat interface for instruction models
     def chat(messages, **options)
       prompt = apply_chat_template(messages)
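The registry and lookup helpers above drive tokenizer auto-detection for GGUF repositories: exact matches first, then the regex patterns, then suffix stripping. A short sketch of how they can be exercised; the second repository name is invented to illustrate the pattern fallback:

```ruby
# Exact-match lookup against TOKENIZER_REGISTRY
Candle::LLM.guess_tokenizer("TheBloke/Mistral-7B-Instruct-v0.2-GGUF")
# => "mistralai/Mistral-7B-Instruct-v0.2"

# Hypothetical repo: no exact match, falls through to the /tinyllama/i pattern
Candle::LLM.guess_tokenizer("someone/tinyllama-chat-Q4_K_M-GGUF")
# => "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# User-registered Regexp mappings are unshifted, so they are checked before built-in patterns
Candle::LLM.register_tokenizer(/my-finetune/i, "my-org/my-finetune-base")
```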
@@ -28,8 +88,37 @@ module Candle
       end
     end

-    def self.from_pretrained(model_id, device: Candle::Device.cpu)
-
+    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+      model_str = if gguf_file
+        "#{model_id}@#{gguf_file}"
+      else
+        model_id
+      end
+
+      # Handle GGUF models that need tokenizer
+      if model_str.downcase.include?("gguf") && tokenizer.nil?
+        # Try to load without tokenizer first
+        begin
+          _from_pretrained(model_str, device)
+        rescue => e
+          if e.message.include?("No tokenizer found")
+            # Auto-detect tokenizer
+            detected_tokenizer = guess_tokenizer(model_id)
+            warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
+            model_str = "#{model_str}@@#{detected_tokenizer}"
+            _from_pretrained(model_str, device)
+          else
+            raise e
+          end
+        end
+      elsif tokenizer
+        # User specified tokenizer
+        model_str = "#{model_str}@@#{tokenizer}"
+        _from_pretrained(model_str, device)
+      else
+        # Non-GGUF model or GGUF with embedded tokenizer
+        _from_pretrained(model_str, device)
+      end
     end

     private
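Put together, the expanded `from_pretrained` assembles a `model_id@gguf_file@@tokenizer` string for the native `_from_pretrained` loader, with the `@@` tokenizer suffix added either explicitly or via auto-detection. An illustrative sketch; the GGUF filenames are examples and are not verified against the actual repositories:

```ruby
require "candle"

# Non-GGUF model: behaviour is unchanged
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Quantized GGUF repo: select a specific weight file
llm = Candle::LLM.from_pretrained(
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
  gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"  # example filename
)

# Override the tokenizer instead of relying on auto-detection
llm = Candle::LLM.from_pretrained(
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
  gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf",  # example filename
  tokenizer: "mistralai/Mistral-7B-Instruct-v0.2",
  device: Candle::Device.cpu
)
```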
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.0.0.pre.6
+  version: 1.0.0.pre.7
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-07-
+date: 2025-07-13 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -52,6 +52,7 @@ files:
 - ext/candle/src/llm/llama.rs
 - ext/candle/src/llm/mistral.rs
 - ext/candle/src/llm/mod.rs
+- ext/candle/src/llm/quantized_gguf.rs
 - ext/candle/src/llm/text_generation.rs
 - ext/candle/src/reranker.rs
 - ext/candle/src/ruby/device.rs