red-candle 1.0.2 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +38 -3
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +316 -0
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +11 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +301 -0
- data/ext/candle/src/llm/quantized_gguf.rs +173 -9
- data/ext/candle/src/llm/qwen.rs +245 -0
- data/ext/candle/src/llm/text_generation.rs +183 -26
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +119 -55
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/llm.rb +203 -2
- data/lib/candle/version.rb +1 -1
- metadata +14 -4
data/ext/candle/src/reranker.rs
CHANGED
@@ -4,7 +4,6 @@ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
|
|
4
4
|
use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
|
5
5
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
6
6
|
use tokenizers::{EncodeInput, Tokenizer};
|
7
|
-
use std::thread;
|
8
7
|
use crate::ruby::{Device, Result};
|
9
8
|
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
10
9
|
|
@@ -24,8 +23,7 @@ impl Reranker {
|
|
24
23
|
}
|
25
24
|
|
26
25
|
fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
|
27
|
-
let
|
28
|
-
let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
|
26
|
+
let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
|
29
27
|
let api = Api::new()?;
|
30
28
|
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
31
29
|
|
@@ -44,7 +42,7 @@ impl Reranker {
|
|
44
42
|
|
45
43
|
// Load model weights
|
46
44
|
let vb = unsafe {
|
47
|
-
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &
|
45
|
+
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
|
48
46
|
};
|
49
47
|
|
50
48
|
// Load BERT model
|
@@ -57,17 +55,49 @@ impl Reranker {
|
|
57
55
|
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
58
56
|
|
59
57
|
Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
|
60
|
-
});
|
58
|
+
})();
|
61
59
|
|
62
|
-
match
|
63
|
-
Ok(
|
60
|
+
match result {
|
61
|
+
Ok((model, tokenizer, pooler, classifier)) => {
|
64
62
|
Ok(Self { model, tokenizer, pooler, classifier, device })
|
65
63
|
}
|
66
|
-
|
67
|
-
Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
|
64
|
+
Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
|
68
65
|
}
|
69
66
|
}
|
70
67
|
|
68
|
+
/// Extract CLS embeddings from the model output, handling Metal device workarounds
|
69
|
+
fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
|
70
|
+
let cls_embeddings = if self.device.is_metal() {
|
71
|
+
// Metal has issues with tensor indexing, use a different approach
|
72
|
+
let (batch_size, seq_len, hidden_size) = embeddings.dims3()
|
73
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
74
|
+
|
75
|
+
// Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
|
76
|
+
let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
|
77
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
78
|
+
|
79
|
+
// Extract CLS tokens (first token of each sequence)
|
80
|
+
let mut cls_vecs = Vec::new();
|
81
|
+
for i in 0..batch_size {
|
82
|
+
let start_idx = i * seq_len;
|
83
|
+
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
84
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
85
|
+
cls_vecs.push(cls_vec);
|
86
|
+
}
|
87
|
+
|
88
|
+
// Stack the CLS vectors
|
89
|
+
Tensor::cat(&cls_vecs, 0)
|
90
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
91
|
+
} else {
|
92
|
+
embeddings.i((.., 0))
|
93
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
94
|
+
};
|
95
|
+
|
96
|
+
// Ensure tensor is contiguous for downstream operations
|
97
|
+
cls_embeddings.contiguous()
|
98
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
|
99
|
+
}
|
100
|
+
|
71
101
|
pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
|
72
102
|
// Create query-document pair for cross-encoder
|
73
103
|
let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
|
@@ -131,37 +161,7 @@ impl Reranker {
|
|
131
161
|
let pooled_embeddings = match pooling_method.as_str() {
|
132
162
|
"pooler" => {
|
133
163
|
// Extract [CLS] token and apply pooler (dense + tanh)
|
134
|
-
|
135
|
-
let cls_embeddings = if self.device.is_metal() {
|
136
|
-
// Metal has issues with tensor indexing, use a different approach
|
137
|
-
let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
|
138
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
139
|
-
|
140
|
-
// Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
|
141
|
-
let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
|
142
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
143
|
-
|
144
|
-
// Extract CLS tokens (first token of each sequence)
|
145
|
-
let mut cls_vecs = Vec::new();
|
146
|
-
for i in 0..batch_size {
|
147
|
-
let start_idx = i * _seq_len;
|
148
|
-
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
149
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
150
|
-
cls_vecs.push(cls_vec);
|
151
|
-
}
|
152
|
-
|
153
|
-
// Stack the CLS vectors
|
154
|
-
Tensor::cat(&cls_vecs, 0)
|
155
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
156
|
-
.contiguous()
|
157
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
|
158
|
-
} else {
|
159
|
-
embeddings.i((.., 0))
|
160
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
161
|
-
};
|
162
|
-
// Ensure tensor is contiguous before linear layer
|
163
|
-
let cls_embeddings = cls_embeddings.contiguous()
|
164
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
|
164
|
+
let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
|
165
165
|
let pooled = self.pooler.forward(&cls_embeddings)
|
166
166
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
|
167
167
|
pooled.tanh()
|
@@ -169,34 +169,7 @@ impl Reranker {
|
|
169
169
|
},
|
170
170
|
"cls" => {
|
171
171
|
// Just use the [CLS] token embeddings directly (no pooler layer)
|
172
|
-
|
173
|
-
let cls_embeddings = if self.device.is_metal() {
|
174
|
-
// Use same approach as pooler method
|
175
|
-
let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
|
176
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
177
|
-
|
178
|
-
let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
|
179
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
180
|
-
|
181
|
-
let mut cls_vecs = Vec::new();
|
182
|
-
for i in 0..batch_size {
|
183
|
-
let start_idx = i * _seq_len;
|
184
|
-
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
185
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
186
|
-
cls_vecs.push(cls_vec);
|
187
|
-
}
|
188
|
-
|
189
|
-
Tensor::cat(&cls_vecs, 0)
|
190
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
191
|
-
.contiguous()
|
192
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
|
193
|
-
} else {
|
194
|
-
embeddings.i((.., 0))
|
195
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
196
|
-
};
|
197
|
-
// Ensure contiguous for classifier
|
198
|
-
cls_embeddings.contiguous()
|
199
|
-
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
|
172
|
+
self.extract_cls_embeddings(&embeddings)?
|
200
173
|
},
|
201
174
|
"mean" => {
|
202
175
|
// Mean pooling across all tokens
|
@@ -162,6 +162,10 @@ impl Device {
|
|
162
162
|
pub fn __str__(&self) -> String {
|
163
163
|
self.__repr__()
|
164
164
|
}
|
165
|
+
|
166
|
+
pub fn __eq__(&self, other: &Device) -> bool {
|
167
|
+
self == other
|
168
|
+
}
|
165
169
|
}
|
166
170
|
|
167
171
|
impl magnus::TryConvert for Device {
|
@@ -193,5 +197,6 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
193
197
|
rb_device.define_singleton_method("default", function!(default_device, 0))?;
|
194
198
|
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
195
199
|
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
200
|
+
rb_device.define_method("==", method!(Device::__eq__, 1))?;
|
196
201
|
Ok(())
|
197
202
|
}
|
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,15 +1,18 @@
|
|
1
1
|
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
|
2
2
|
use std::cell::RefCell;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
|
-
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
|
5
|
+
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, phi::Phi as RustPhi, QuantizedGGUF as RustQuantizedGGUF};
|
5
6
|
use crate::ruby::{Result, Device};
|
7
|
+
use crate::ruby::structured::StructuredConstraint;
|
6
8
|
|
7
9
|
// Use an enum to handle different model types instead of trait objects
|
8
|
-
#[derive(Debug)]
|
9
10
|
enum ModelType {
|
10
11
|
Mistral(RustMistral),
|
11
12
|
Llama(RustLlama),
|
12
13
|
Gemma(RustGemma),
|
14
|
+
Qwen(RustQwen),
|
15
|
+
Phi(RustPhi),
|
13
16
|
QuantizedGGUF(RustQuantizedGGUF),
|
14
17
|
}
|
15
18
|
|
@@ -19,6 +22,8 @@ impl ModelType {
|
|
19
22
|
ModelType::Mistral(m) => m.generate(prompt, config),
|
20
23
|
ModelType::Llama(m) => m.generate(prompt, config),
|
21
24
|
ModelType::Gemma(m) => m.generate(prompt, config),
|
25
|
+
ModelType::Qwen(m) => m.generate(prompt, config),
|
26
|
+
ModelType::Phi(m) => m.generate(prompt, config),
|
22
27
|
ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
|
23
28
|
}
|
24
29
|
}
|
@@ -33,6 +38,8 @@ impl ModelType {
|
|
33
38
|
ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
|
34
39
|
ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
|
35
40
|
ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
|
41
|
+
ModelType::Qwen(m) => m.generate_stream(prompt, config, callback),
|
42
|
+
ModelType::Phi(m) => m.generate_stream(prompt, config, callback),
|
36
43
|
ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
|
37
44
|
}
|
38
45
|
}
|
@@ -42,6 +49,8 @@ impl ModelType {
|
|
42
49
|
ModelType::Mistral(m) => m.clear_cache(),
|
43
50
|
ModelType::Llama(m) => m.clear_cache(),
|
44
51
|
ModelType::Gemma(m) => m.clear_cache(),
|
52
|
+
ModelType::Qwen(m) => m.clear_cache(),
|
53
|
+
ModelType::Phi(m) => m.clear_cache(),
|
45
54
|
ModelType::QuantizedGGUF(m) => m.clear_cache(),
|
46
55
|
}
|
47
56
|
}
|
@@ -67,11 +76,33 @@ impl ModelType {
|
|
67
76
|
},
|
68
77
|
ModelType::Llama(m) => m.apply_chat_template(messages),
|
69
78
|
ModelType::Gemma(m) => m.apply_chat_template(messages),
|
79
|
+
ModelType::Qwen(m) => m.apply_chat_template(messages),
|
80
|
+
ModelType::Phi(m) => m.apply_chat_template(messages),
|
70
81
|
ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
|
71
82
|
}
|
72
83
|
}
|
73
84
|
}
|
74
85
|
|
86
|
+
// Macro to extract parameters from Ruby hash to reduce boilerplate
|
87
|
+
macro_rules! extract_param {
|
88
|
+
// Basic parameter extraction
|
89
|
+
($kwargs:expr, $config:expr, $param:ident) => {
|
90
|
+
if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
|
91
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
92
|
+
$config.$param = v;
|
93
|
+
}
|
94
|
+
}
|
95
|
+
};
|
96
|
+
// Optional parameter extraction (wraps in Some)
|
97
|
+
($kwargs:expr, $config:expr, $param:ident, optional) => {
|
98
|
+
if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
|
99
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
100
|
+
$config.$param = Some(v);
|
101
|
+
}
|
102
|
+
}
|
103
|
+
};
|
104
|
+
}
|
105
|
+
|
75
106
|
#[derive(Clone, Debug)]
|
76
107
|
#[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
|
77
108
|
pub struct GenerationConfig {
|
@@ -82,55 +113,20 @@ impl GenerationConfig {
|
|
82
113
|
pub fn new(kwargs: RHash) -> Result<Self> {
|
83
114
|
let mut config = RustGenerationConfig::default();
|
84
115
|
|
85
|
-
// Extract
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
|
99
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
100
|
-
config.top_p = Some(v);
|
101
|
-
}
|
102
|
-
}
|
103
|
-
|
104
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
|
105
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
106
|
-
config.top_k = Some(v);
|
107
|
-
}
|
108
|
-
}
|
109
|
-
|
110
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
|
111
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
112
|
-
config.repetition_penalty = v;
|
113
|
-
}
|
114
|
-
}
|
115
|
-
|
116
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
|
117
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
118
|
-
config.repetition_penalty_last_n = v;
|
119
|
-
}
|
120
|
-
}
|
121
|
-
|
122
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
|
123
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
124
|
-
config.seed = v;
|
125
|
-
}
|
126
|
-
}
|
127
|
-
|
128
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
|
129
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
130
|
-
config.include_prompt = v;
|
131
|
-
}
|
132
|
-
}
|
116
|
+
// Extract basic parameters using macro
|
117
|
+
extract_param!(kwargs, config, max_length);
|
118
|
+
extract_param!(kwargs, config, temperature);
|
119
|
+
extract_param!(kwargs, config, top_p, optional);
|
120
|
+
extract_param!(kwargs, config, top_k, optional);
|
121
|
+
extract_param!(kwargs, config, repetition_penalty);
|
122
|
+
extract_param!(kwargs, config, repetition_penalty_last_n);
|
123
|
+
extract_param!(kwargs, config, seed);
|
124
|
+
extract_param!(kwargs, config, include_prompt);
|
125
|
+
extract_param!(kwargs, config, debug_tokens);
|
126
|
+
extract_param!(kwargs, config, stop_on_constraint_satisfaction);
|
127
|
+
extract_param!(kwargs, config, stop_on_match);
|
133
128
|
|
129
|
+
// Handle special cases that need custom logic
|
134
130
|
if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
|
135
131
|
if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
|
136
132
|
config.stop_sequences = arr
|
@@ -140,9 +136,9 @@ impl GenerationConfig {
|
|
140
136
|
}
|
141
137
|
}
|
142
138
|
|
143
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("
|
144
|
-
if let Ok(
|
145
|
-
config.
|
139
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
|
140
|
+
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
141
|
+
config.constraint = Some(Arc::clone(&constraint.index));
|
146
142
|
}
|
147
143
|
}
|
148
144
|
|
@@ -191,9 +187,23 @@ impl GenerationConfig {
|
|
191
187
|
pub fn debug_tokens(&self) -> bool {
|
192
188
|
self.inner.debug_tokens
|
193
189
|
}
|
190
|
+
|
191
|
+
pub fn stop_on_constraint_satisfaction(&self) -> bool {
|
192
|
+
self.inner.stop_on_constraint_satisfaction
|
193
|
+
}
|
194
|
+
|
195
|
+
pub fn stop_on_match(&self) -> bool {
|
196
|
+
self.inner.stop_on_match
|
197
|
+
}
|
198
|
+
|
199
|
+
pub fn constraint(&self) -> Option<StructuredConstraint> {
|
200
|
+
self.inner.constraint.as_ref().map(|c| StructuredConstraint {
|
201
|
+
index: Arc::clone(c),
|
202
|
+
})
|
203
|
+
}
|
194
204
|
}
|
195
205
|
|
196
|
-
#[derive(Clone
|
206
|
+
#[derive(Clone)]
|
197
207
|
#[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
|
198
208
|
pub struct LLM {
|
199
209
|
model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
|
@@ -251,10 +261,22 @@ impl LLM {
|
|
251
261
|
})
|
252
262
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
253
263
|
ModelType::Gemma(gemma)
|
264
|
+
} else if model_lower.contains("qwen") {
|
265
|
+
let qwen = rt.block_on(async {
|
266
|
+
RustQwen::from_pretrained(&model_id, candle_device).await
|
267
|
+
})
|
268
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
269
|
+
ModelType::Qwen(qwen)
|
270
|
+
} else if model_lower.contains("phi") {
|
271
|
+
let phi = rt.block_on(async {
|
272
|
+
RustPhi::from_pretrained(&model_id, candle_device).await
|
273
|
+
})
|
274
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
275
|
+
ModelType::Phi(phi)
|
254
276
|
} else {
|
255
277
|
return Err(Error::new(
|
256
278
|
magnus::exception::runtime_error(),
|
257
|
-
format!("Unsupported model type: {}. Currently Mistral, Llama, and
|
279
|
+
format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id),
|
258
280
|
));
|
259
281
|
}
|
260
282
|
};
|
@@ -332,9 +354,47 @@ impl LLM {
|
|
332
354
|
ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
333
355
|
ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
334
356
|
ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
357
|
+
ModelType::Qwen(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
358
|
+
ModelType::Phi(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
335
359
|
ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
336
360
|
}
|
337
361
|
}
|
362
|
+
|
363
|
+
/// Get the EOS token string for this model
|
364
|
+
pub fn eos_token(&self) -> Result<String> {
|
365
|
+
let (eos_token_id, tokenizer_clone) = {
|
366
|
+
let model = match self.model.lock() {
|
367
|
+
Ok(guard) => guard,
|
368
|
+
Err(poisoned) => poisoned.into_inner(),
|
369
|
+
};
|
370
|
+
let model_ref = model.borrow();
|
371
|
+
|
372
|
+
// Get both EOS token ID and tokenizer clone in one lock scope
|
373
|
+
let eos_id = match &*model_ref {
|
374
|
+
ModelType::Mistral(m) => m.eos_token_id(),
|
375
|
+
ModelType::Llama(m) => m.eos_token_id(),
|
376
|
+
ModelType::Gemma(m) => m.eos_token_id(),
|
377
|
+
ModelType::Qwen(m) => m.eos_token_id(),
|
378
|
+
ModelType::Phi(m) => m.eos_token_id(),
|
379
|
+
ModelType::QuantizedGGUF(m) => m.eos_token_id(),
|
380
|
+
};
|
381
|
+
|
382
|
+
let tokenizer = match &*model_ref {
|
383
|
+
ModelType::Mistral(m) => m.tokenizer().clone(),
|
384
|
+
ModelType::Llama(m) => m.tokenizer().clone(),
|
385
|
+
ModelType::Gemma(m) => m.tokenizer().clone(),
|
386
|
+
ModelType::Qwen(m) => m.tokenizer().clone(),
|
387
|
+
ModelType::Phi(m) => m.tokenizer().clone(),
|
388
|
+
ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
|
389
|
+
};
|
390
|
+
|
391
|
+
(eos_id, tokenizer)
|
392
|
+
}; // Lock is released here
|
393
|
+
|
394
|
+
// Convert ID to string using the tokenizer
|
395
|
+
let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
|
396
|
+
tokenizer_wrapper.id_to_token(eos_token_id as i64)
|
397
|
+
}
|
338
398
|
|
339
399
|
/// Clear the model's cache (e.g., KV cache for transformers)
|
340
400
|
pub fn clear_cache(&self) -> Result<()> {
|
@@ -423,6 +483,9 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
423
483
|
rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
|
424
484
|
rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
|
425
485
|
rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
|
486
|
+
rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
|
487
|
+
rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
|
488
|
+
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
426
489
|
|
427
490
|
let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
|
428
491
|
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
@@ -431,6 +494,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
431
494
|
rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
|
432
495
|
rb_llm.define_method("device", method!(LLM::device, 0))?;
|
433
496
|
rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
|
497
|
+
rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
|
434
498
|
rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
|
435
499
|
rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
|
436
500
|
|
data/ext/candle/src/ruby/mod.rs
CHANGED
@@ -0,0 +1,47 @@
|
|
1
|
+
use magnus::{Error, Module, RModule, function, Object};
|
2
|
+
use std::sync::Arc;
|
3
|
+
|
4
|
+
use crate::structured::{SchemaProcessor, VocabularyAdapter, Index};
|
5
|
+
use crate::ruby::{Result, tokenizer::Tokenizer};
|
6
|
+
|
7
|
+
/// Ruby wrapper for structured generation constraints
|
8
|
+
#[derive(Clone, Debug)]
|
9
|
+
#[magnus::wrap(class = "Candle::StructuredConstraint", mark, free_immediately)]
|
10
|
+
pub struct StructuredConstraint {
|
11
|
+
pub(crate) index: Arc<Index>,
|
12
|
+
}
|
13
|
+
|
14
|
+
impl StructuredConstraint {
|
15
|
+
/// Create a constraint from a JSON schema
|
16
|
+
pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
|
17
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
18
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
19
|
+
|
20
|
+
let processor = SchemaProcessor::new();
|
21
|
+
let index = processor.process_schema(&schema, &vocabulary)
|
22
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process schema: {}", e)))?;
|
23
|
+
|
24
|
+
Ok(Self { index })
|
25
|
+
}
|
26
|
+
|
27
|
+
/// Create a constraint from a regex pattern
|
28
|
+
pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
|
29
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
30
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
31
|
+
|
32
|
+
let processor = SchemaProcessor::new();
|
33
|
+
let index = processor.process_regex(&pattern, &vocabulary)
|
34
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process regex: {}", e)))?;
|
35
|
+
|
36
|
+
Ok(Self { index })
|
37
|
+
}
|
38
|
+
}
|
39
|
+
|
40
|
+
pub fn init_structured(rb_candle: RModule) -> Result<()> {
|
41
|
+
let class = rb_candle.define_class("StructuredConstraint", magnus::class::object())?;
|
42
|
+
|
43
|
+
class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
|
44
|
+
class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
|
45
|
+
|
46
|
+
Ok(())
|
47
|
+
}
|
@@ -0,0 +1,130 @@
|
|
1
|
+
#[cfg(test)]
|
2
|
+
mod integration_tests {
|
3
|
+
use super::super::*;
|
4
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
5
|
+
use std::sync::Arc;
|
6
|
+
|
7
|
+
#[tokio::test]
|
8
|
+
async fn test_schema_processor_with_vocabulary() {
|
9
|
+
// This test requires a tokenizer to create a vocabulary
|
10
|
+
let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
|
11
|
+
|
12
|
+
if let Ok(tokenizer) = tokenizer_result {
|
13
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
14
|
+
|
15
|
+
// Create vocabulary from tokenizer
|
16
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
17
|
+
.expect("Should create vocabulary");
|
18
|
+
|
19
|
+
// Create schema processor
|
20
|
+
let processor = SchemaProcessor::new();
|
21
|
+
|
22
|
+
// Test with a simple JSON schema
|
23
|
+
let schema = r#"{
|
24
|
+
"type": "object",
|
25
|
+
"properties": {
|
26
|
+
"name": {"type": "string"},
|
27
|
+
"age": {"type": "integer"}
|
28
|
+
},
|
29
|
+
"required": ["name", "age"]
|
30
|
+
}"#;
|
31
|
+
|
32
|
+
// Process schema into Index
|
33
|
+
let index_result = processor.process_schema(schema, &vocabulary);
|
34
|
+
assert!(index_result.is_ok(), "Should process schema successfully");
|
35
|
+
|
36
|
+
// Test caching - second call should use cache
|
37
|
+
let index2_result = processor.process_schema(schema, &vocabulary);
|
38
|
+
assert!(index2_result.is_ok(), "Should retrieve from cache");
|
39
|
+
|
40
|
+
// Both should be the same Arc
|
41
|
+
let index1 = index_result.unwrap();
|
42
|
+
let index2 = index2_result.unwrap();
|
43
|
+
assert!(Arc::ptr_eq(&index1, &index2), "Should return cached Index");
|
44
|
+
|
45
|
+
// Check cache stats
|
46
|
+
let (size, _) = processor.cache_stats();
|
47
|
+
assert_eq!(size, 1, "Cache should have one entry");
|
48
|
+
} else {
|
49
|
+
eprintln!("Skipping integration test - couldn't load tokenizer");
|
50
|
+
}
|
51
|
+
}
|
52
|
+
|
53
|
+
#[tokio::test]
|
54
|
+
async fn test_regex_processing() {
|
55
|
+
let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
|
56
|
+
|
57
|
+
if let Ok(tokenizer) = tokenizer_result {
|
58
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
59
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
60
|
+
.expect("Should create vocabulary");
|
61
|
+
|
62
|
+
let processor = SchemaProcessor::new();
|
63
|
+
|
64
|
+
// Test with a simple regex pattern
|
65
|
+
let email_regex = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";
|
66
|
+
|
67
|
+
let index_result = processor.process_regex(email_regex, &vocabulary);
|
68
|
+
assert!(index_result.is_ok(), "Should process regex successfully");
|
69
|
+
|
70
|
+
// Test different regex
|
71
|
+
let phone_regex = r"\d{3}-\d{3}-\d{4}";
|
72
|
+
let phone_index_result = processor.process_regex(phone_regex, &vocabulary);
|
73
|
+
assert!(phone_index_result.is_ok(), "Should process phone regex");
|
74
|
+
|
75
|
+
// Cache should have both
|
76
|
+
let (size, _) = processor.cache_stats();
|
77
|
+
assert_eq!(size, 2, "Cache should have two entries");
|
78
|
+
|
79
|
+
// Clear cache
|
80
|
+
processor.clear_cache();
|
81
|
+
let (size, _) = processor.cache_stats();
|
82
|
+
assert_eq!(size, 0, "Cache should be empty after clear");
|
83
|
+
}
|
84
|
+
}
|
85
|
+
|
86
|
+
#[test]
|
87
|
+
fn test_various_json_schemas() {
|
88
|
+
let _processor = SchemaProcessor::new();
|
89
|
+
|
90
|
+
// Array schema
|
91
|
+
let array_schema = serde_json::json!({
|
92
|
+
"type": "array",
|
93
|
+
"items": {"type": "string"}
|
94
|
+
});
|
95
|
+
|
96
|
+
// Process as a full schema instead of testing private method
|
97
|
+
// This would need a mock vocabulary in a real test
|
98
|
+
// For now, just verify the schema is valid JSON
|
99
|
+
let json_str = serde_json::to_string(&array_schema).unwrap();
|
100
|
+
assert!(!json_str.is_empty(), "Should serialize array schema");
|
101
|
+
|
102
|
+
// Nested object schema
|
103
|
+
let nested_schema = serde_json::json!({
|
104
|
+
"type": "object",
|
105
|
+
"properties": {
|
106
|
+
"user": {
|
107
|
+
"type": "object",
|
108
|
+
"properties": {
|
109
|
+
"id": {"type": "integer"},
|
110
|
+
"email": {"type": "string", "format": "email"}
|
111
|
+
}
|
112
|
+
}
|
113
|
+
}
|
114
|
+
});
|
115
|
+
|
116
|
+
// Verify nested schema is valid
|
117
|
+
let json_str = serde_json::to_string(&nested_schema).unwrap();
|
118
|
+
assert!(json_str.contains("properties"), "Should have nested properties");
|
119
|
+
|
120
|
+
// Schema with enum
|
121
|
+
let enum_schema = serde_json::json!({
|
122
|
+
"type": "string",
|
123
|
+
"enum": ["red", "green", "blue"]
|
124
|
+
});
|
125
|
+
|
126
|
+
// Verify enum schema is valid
|
127
|
+
let json_str = serde_json::to_string(&enum_schema).unwrap();
|
128
|
+
assert!(json_str.contains("enum"), "Should have enum values");
|
129
|
+
}
|
130
|
+
}
|
@@ -0,0 +1,31 @@
|
|
1
|
+
/// Structured generation support using Outlines
|
2
|
+
///
|
3
|
+
/// This module provides functionality to constrain language model generation
|
4
|
+
/// to follow specific patterns, such as JSON schemas or regular expressions.
|
5
|
+
|
6
|
+
pub mod vocabulary_adapter;
|
7
|
+
pub mod schema_processor;
|
8
|
+
|
9
|
+
pub use vocabulary_adapter::VocabularyAdapter;
|
10
|
+
pub use schema_processor::SchemaProcessor;
|
11
|
+
|
12
|
+
// Re-export commonly used types from outlines-core
|
13
|
+
pub use outlines_core::prelude::Index;
|
14
|
+
pub use outlines_core::vocabulary::Vocabulary;
|
15
|
+
|
16
|
+
#[cfg(test)]
|
17
|
+
mod vocabulary_adapter_simple_test;
|
18
|
+
|
19
|
+
#[cfg(test)]
|
20
|
+
mod integration_test;
|
21
|
+
|
22
|
+
#[cfg(test)]
|
23
|
+
mod tests {
|
24
|
+
use super::*;
|
25
|
+
|
26
|
+
#[test]
|
27
|
+
fn test_module_imports() {
|
28
|
+
// Ensure all exports are available
|
29
|
+
let _ = VocabularyAdapter;
|
30
|
+
}
|
31
|
+
}
|