red-candle 1.0.0.pre.5 → 1.0.0.pre.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +164 -4
- data/Rakefile +1 -3
- data/ext/candle/src/llm/gemma.rs +277 -0
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +16 -79
- data/ext/candle/src/llm/mistral.rs +16 -89
- data/ext/candle/src/llm/mod.rs +59 -0
- data/ext/candle/src/llm/quantized_gguf.rs +496 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ruby/embedding_model.rs +0 -1
- data/ext/candle/src/ruby/llm.rs +84 -29
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
data/ext/candle/src/ruby/llm.rs
CHANGED
```diff
@@ -1,7 +1,7 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;
 
-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
 use crate::ruby::{Result as RbResult, Device as RbDevice};
 
 // Use an enum to handle different model types instead of trait objects
@@ -9,6 +9,8 @@ use crate::ruby::{Result as RbResult, Device as RbDevice};
 enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
+    Gemma(RustGemma),
+    QuantizedGGUF(RustQuantizedGGUF),
 }
 
 impl ModelType {
@@ -16,6 +18,8 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
+            ModelType::Gemma(m) => m.generate(prompt, config),
+            ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
         }
     }
 
@@ -28,14 +32,8 @@
         match self {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
-        }
-    }
-
-    #[allow(dead_code)]
-    fn model_name(&self) -> &str {
-        match self {
-            ModelType::Mistral(m) => m.model_name(),
-            ModelType::Llama(m) => m.model_name(),
+            ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
+            ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
         }
     }
 
@@ -43,6 +41,8 @@
         match self {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
+            ModelType::Gemma(m) => m.clear_cache(),
+            ModelType::QuantizedGGUF(m) => m.clear_cache(),
         }
     }
 
@@ -66,6 +66,8 @@
                 Ok(prompt)
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
+            ModelType::Gemma(m) => m.apply_chat_template(messages),
+            ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -138,6 +140,12 @@ impl GenerationConfig {
             }
         }
 
+        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                config.debug_tokens = v;
+            }
+        }
+
         Ok(Self { inner: config })
     }
 
@@ -179,6 +187,10 @@ impl GenerationConfig {
     pub fn include_prompt(&self) -> bool {
         self.inner.include_prompt
     }
+
+    pub fn debug_tokens(&self) -> bool {
+        self.inner.debug_tokens
+    }
 }
 
 #[derive(Clone, Debug)]
```
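The two hunks above thread a new `debug_tokens` flag through `GenerationConfig`: the constructor reads it from the kwargs hash, and a getter exposes it (the Ruby binding for the getter is added in the `init_llm` hunk further down). A minimal Ruby-side sketch, assuming the class is exposed as `Candle::GenerationConfig` and that its constructor accepts an options hash, as the `kwargs.get` lookup implies; neither assumption is shown in this diff:

```ruby
# Sketch under the assumptions above, not taken from the gem's docs.
config = Candle::GenerationConfig.new(debug_tokens: true)
config.debug_tokens # => true, via the reader bound in init_llm below
```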
```diff
@@ -200,25 +212,51 @@ impl LLM {
         let rt = tokio::runtime::Runtime::new()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
 
-        // Determine model type from ID and
+        // Determine model type from ID and whether it's quantized
        let model_lower = model_id.to_lowercase();
-        let
-
-
-
-
-
-
-
-
+        let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+        let model = if is_quantized {
+            // Extract tokenizer source if provided in model_id
+            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                let (id, _tok) = model_id.split_at(pos);
+                (id.to_string(), Some(&model_id[pos+2..]))
+            } else {
+                (model_id.clone(), None)
+            };
+
+            // Use unified GGUF loader for all quantized models
+            let gguf_model = rt.block_on(async {
+                RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
             })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+            ModelType::QuantizedGGUF(gguf_model)
         } else {
-
-
-
-
+            // Load non-quantized models
+            if model_lower.contains("mistral") {
+                let mistral = rt.block_on(async {
+                    RustMistral::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Mistral(mistral)
+            } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                let llama = rt.block_on(async {
+                    RustLlama::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Llama(llama)
+            } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                let gemma = rt.block_on(async {
+                    RustGemma::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Gemma(gemma)
+            } else {
+                return Err(Error::new(
+                    magnus::exception::runtime_error(),
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                ));
+            }
         };
 
         Ok(Self {
@@ -234,7 +272,10 @@ impl LLM {
             .map(|c| c.inner.clone())
             .unwrap_or_default();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         model_ref.generate(&prompt, &config)
@@ -254,7 +295,10 @@ impl LLM {
         }
         let block = block.unwrap();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -277,7 +321,14 @@ impl LLM {
 
     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> RbResult<()> {
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => {
+                // If the mutex is poisoned, we can still recover the data
+                // This happens when another thread panicked while holding the lock
+                poisoned.into_inner()
+            }
+        };
         let mut model_ref = model.borrow_mut();
         model_ref.clear_cache();
         Ok(())
@@ -311,7 +362,10 @@ impl LLM {
             })
             .collect();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let model_ref = model.borrow();
 
         model_ref.apply_chat_template(&json_messages)
@@ -351,6 +405,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+    rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
```
data/lib/candle/llm.rb
CHANGED
```diff
@@ -1,5 +1,65 @@
 module Candle
   class LLM
+    # Tokenizer registry for automatic detection
+    TOKENIZER_REGISTRY = {
+      # Exact model matches
+      "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
+      "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
+      "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
+      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+
+      # Pattern-based fallbacks (evaluated in order)
+      :patterns => [
+        # Mistral models
+        [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
+        [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
+        [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
+
+        # Llama models
+        [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
+        [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
+        [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
+        [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
+        [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
+        [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
+
+        # Gemma models
+        [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
+        [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
+        [/gemma.*?7b/i, "google/gemma-7b"],
+        [/gemma.*?2b/i, "google/gemma-2b"]
+      ]
+    }
+
+    # Allow users to register custom tokenizer mappings
+    def self.register_tokenizer(model_pattern, tokenizer_id)
+      if model_pattern.is_a?(String)
+        TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
+      elsif model_pattern.is_a?(Regexp)
+        TOKENIZER_REGISTRY[:patterns] ||= []
+        TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
+      else
+        raise ArgumentError, "model_pattern must be a String or Regexp"
+      end
+    end
+
+    # Guess the tokenizer for a model
+    def self.guess_tokenizer(model_id)
+      # Check exact matches first
+      return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
+
+      # Check patterns
+      if patterns = TOKENIZER_REGISTRY[:patterns]
+        patterns.each do |pattern, tokenizer|
+          return tokenizer if model_id.match?(pattern)
+        end
+      end
+
+      # Default: try removing common GGUF suffixes
+      base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
+      base_model
+    end
+
     # Simple chat interface for instruction models
     def chat(messages, **options)
       prompt = apply_chat_template(messages)
```
```diff
@@ -28,8 +88,37 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.cpu)
-      _from_pretrained(model_id, device)
+    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+      model_str = if gguf_file
+        "#{model_id}@#{gguf_file}"
+      else
+        model_id
+      end
+
+      # Handle GGUF models that need tokenizer
+      if model_str.downcase.include?("gguf") && tokenizer.nil?
+        # Try to load without tokenizer first
+        begin
+          _from_pretrained(model_str, device)
+        rescue => e
+          if e.message.include?("No tokenizer found")
+            # Auto-detect tokenizer
+            detected_tokenizer = guess_tokenizer(model_id)
+            warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
+            model_str = "#{model_str}@@#{detected_tokenizer}"
+            _from_pretrained(model_str, device)
+          else
+            raise e
+          end
+        end
+      elsif tokenizer
+        # User specified tokenizer
+        model_str = "#{model_str}@@#{tokenizer}"
+        _from_pretrained(model_str, device)
+      else
+        # Non-GGUF model or GGUF with embedded tokenizer
+        _from_pretrained(model_str, device)
+      end
     end
 
     private
```
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
```diff
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.0.0.pre.5
+  version: 1.0.0.pre.7
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-07-
+date: 2025-07-13 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -47,10 +47,12 @@ files:
 - ext/candle/extconf.rb
 - ext/candle/rustfmt.toml
 - ext/candle/src/lib.rs
+- ext/candle/src/llm/gemma.rs
 - ext/candle/src/llm/generation_config.rs
 - ext/candle/src/llm/llama.rs
 - ext/candle/src/llm/mistral.rs
 - ext/candle/src/llm/mod.rs
+- ext/candle/src/llm/quantized_gguf.rs
 - ext/candle/src/llm/text_generation.rs
 - ext/candle/src/reranker.rs
 - ext/candle/src/ruby/device.rs
```