red-candle 1.0.2 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +36 -2
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +123 -0
- data/ext/candle/src/llm/generation_config.rs +5 -0
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +285 -0
- data/ext/candle/src/llm/quantized_gguf.rs +155 -4
- data/ext/candle/src/llm/qwen.rs +229 -0
- data/ext/candle/src/llm/text_generation.rs +66 -2
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +42 -4
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/llm.rb +109 -3
- data/lib/candle/version.rb +1 -1
- metadata +14 -4
@@ -0,0 +1,229 @@
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
2
|
+
use candle_transformers::models::qwen2::{Config, Model as QwenModel};
|
3
|
+
use hf_hub::api::tokio::Api;
|
4
|
+
use tokenizers::Tokenizer;
|
5
|
+
|
6
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
7
|
+
|
8
|
+
/// Qwen model wrapper for text generation
|
9
|
+
#[derive(Debug)]
|
10
|
+
pub struct Qwen {
|
11
|
+
model: QwenModel,
|
12
|
+
tokenizer: TokenizerWrapper,
|
13
|
+
device: Device,
|
14
|
+
model_id: String,
|
15
|
+
eos_token_id: u32,
|
16
|
+
}
|
17
|
+
|
18
|
+
impl Qwen {
|
19
|
+
/// Get the tokenizer
|
20
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
21
|
+
&self.tokenizer
|
22
|
+
}
|
23
|
+
|
24
|
+
/// Clear the KV cache between generations
|
25
|
+
pub fn clear_kv_cache(&mut self) {
|
26
|
+
self.model.clear_kv_cache();
|
27
|
+
}
|
28
|
+
|
29
|
+
/// Load a Qwen model from HuggingFace
|
30
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
31
|
+
let api = Api::new()
|
32
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
33
|
+
|
34
|
+
let repo = api.model(model_id.to_string());
|
35
|
+
|
36
|
+
// Download configuration
|
37
|
+
let config_filename = repo.get("config.json").await
|
38
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
39
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
40
|
+
let config: Config = serde_json::from_str(&config_str)
|
41
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
42
|
+
|
43
|
+
// Download tokenizer
|
44
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
46
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)
|
47
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
|
48
|
+
|
49
|
+
// Determine EOS token
|
50
|
+
let vocab = tokenizer.get_vocab(true);
|
51
|
+
let eos_token_id = vocab.get("<|endoftext|>")
|
52
|
+
.or_else(|| vocab.get("<|im_end|>"))
|
53
|
+
.or_else(|| vocab.get("</s>"))
|
54
|
+
.copied()
|
55
|
+
.unwrap_or(151643); // Default Qwen3 EOS token
|
56
|
+
|
57
|
+
// Download model weights
|
58
|
+
let mut filenames = vec![];
|
59
|
+
let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
|
60
|
+
else if model_id.contains("14b") || model_id.contains("14B") { 3 }
|
61
|
+
else { 1 };
|
62
|
+
|
63
|
+
if num_shards == 1 {
|
64
|
+
// Single file model
|
65
|
+
let filename = repo.get("model.safetensors").await
|
66
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
|
67
|
+
filenames.push(filename);
|
68
|
+
} else {
|
69
|
+
// Sharded model
|
70
|
+
for shard_idx in 1..=num_shards {
|
71
|
+
let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
|
72
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
|
73
|
+
filenames.push(filename);
|
74
|
+
}
|
75
|
+
}
|
76
|
+
|
77
|
+
// Load the model
|
78
|
+
let vb = unsafe {
|
79
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
|
80
|
+
};
|
81
|
+
|
82
|
+
let model = QwenModel::new(&config, vb)?;
|
83
|
+
|
84
|
+
Ok(Self {
|
85
|
+
model,
|
86
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
87
|
+
device,
|
88
|
+
model_id: model_id.to_string(),
|
89
|
+
eos_token_id,
|
90
|
+
})
|
91
|
+
}
|
92
|
+
|
93
|
+
/// Apply Qwen chat template to messages
|
94
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
95
|
+
let mut prompt = String::new();
|
96
|
+
|
97
|
+
for message in messages {
|
98
|
+
let role = message["role"].as_str().unwrap_or("");
|
99
|
+
let content = message["content"].as_str().unwrap_or("");
|
100
|
+
|
101
|
+
match role {
|
102
|
+
"system" => {
|
103
|
+
prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
|
104
|
+
}
|
105
|
+
"user" => {
|
106
|
+
prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
|
107
|
+
}
|
108
|
+
"assistant" => {
|
109
|
+
prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
|
110
|
+
}
|
111
|
+
_ => {}
|
112
|
+
}
|
113
|
+
}
|
114
|
+
|
115
|
+
// Add generation prompt
|
116
|
+
prompt.push_str("<|im_start|>assistant\n");
|
117
|
+
|
118
|
+
Ok(prompt)
|
119
|
+
}
|
120
|
+
|
121
|
+
fn generate_tokens(
|
122
|
+
&mut self,
|
123
|
+
prompt_tokens: Vec<u32>,
|
124
|
+
config: &GenerationConfig,
|
125
|
+
mut callback: Option<impl FnMut(&str)>,
|
126
|
+
) -> CandleResult<Vec<u32>> {
|
127
|
+
let mut text_gen = TextGeneration::from_config(config);
|
128
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
129
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
130
|
+
|
131
|
+
let mut all_tokens = prompt_tokens.clone();
|
132
|
+
let start_gen = all_tokens.len();
|
133
|
+
|
134
|
+
for index in 0..config.max_length {
|
135
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
136
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
137
|
+
let ctxt = &all_tokens[start_pos..];
|
138
|
+
|
139
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
140
|
+
let logits = self.model.forward(&input, start_pos, None)?;
|
141
|
+
let logits = logits.squeeze(0)?;
|
142
|
+
|
143
|
+
// Handle different output shapes
|
144
|
+
let logits = if logits.dims().len() == 2 {
|
145
|
+
let seq_len = logits.dim(0)?;
|
146
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
147
|
+
} else {
|
148
|
+
logits
|
149
|
+
};
|
150
|
+
|
151
|
+
let logits = logits.to_dtype(DType::F32)?;
|
152
|
+
|
153
|
+
let next_token = text_gen.sample_next_token(
|
154
|
+
&logits,
|
155
|
+
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
156
|
+
)?;
|
157
|
+
|
158
|
+
all_tokens.push(next_token);
|
159
|
+
|
160
|
+
// Stream callback
|
161
|
+
if let Some(ref mut cb) = callback {
|
162
|
+
if config.debug_tokens {
|
163
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
164
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
165
|
+
} else {
|
166
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
167
|
+
cb(&decoded_text);
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
// Check stop conditions
|
172
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
173
|
+
break;
|
174
|
+
}
|
175
|
+
|
176
|
+
// Check stop sequences
|
177
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
178
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
179
|
+
break;
|
180
|
+
}
|
181
|
+
}
|
182
|
+
|
183
|
+
Ok(if config.include_prompt {
|
184
|
+
all_tokens
|
185
|
+
} else {
|
186
|
+
all_tokens[start_gen..].to_vec()
|
187
|
+
})
|
188
|
+
}
|
189
|
+
}
|
190
|
+
|
191
|
+
impl TextGenerator for Qwen {
|
192
|
+
fn generate(
|
193
|
+
&mut self,
|
194
|
+
prompt: &str,
|
195
|
+
config: &GenerationConfig,
|
196
|
+
) -> CandleResult<String> {
|
197
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
198
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
199
|
+
|
200
|
+
if config.debug_tokens {
|
201
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
202
|
+
} else {
|
203
|
+
self.tokenizer.decode(&output_tokens, true)
|
204
|
+
}
|
205
|
+
}
|
206
|
+
|
207
|
+
fn generate_stream(
|
208
|
+
&mut self,
|
209
|
+
prompt: &str,
|
210
|
+
config: &GenerationConfig,
|
211
|
+
mut callback: impl FnMut(&str),
|
212
|
+
) -> CandleResult<String> {
|
213
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
214
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
215
|
+
self.tokenizer.decode(&output_tokens, true)
|
216
|
+
}
|
217
|
+
|
218
|
+
fn model_name(&self) -> &str {
|
219
|
+
&self.model_id
|
220
|
+
}
|
221
|
+
|
222
|
+
fn device(&self) -> &Device {
|
223
|
+
&self.device
|
224
|
+
}
|
225
|
+
|
226
|
+
fn clear_cache(&mut self) {
|
227
|
+
self.clear_kv_cache();
|
228
|
+
}
|
229
|
+
}
|
@@ -1,13 +1,17 @@
|
|
1
1
|
use candle_core::{Result as CandleResult, Tensor};
|
2
2
|
use candle_transformers::generation::LogitsProcessor;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
5
|
use super::GenerationConfig;
|
6
|
+
use crate::structured::Index;
|
5
7
|
|
6
8
|
/// Helper struct for text generation process
|
7
9
|
pub struct TextGeneration {
|
8
10
|
logits_processor: LogitsProcessor,
|
9
11
|
tokens: Vec<u32>,
|
10
12
|
eos_token_id: Option<u32>,
|
13
|
+
constraint: Option<Arc<Index>>,
|
14
|
+
constraint_state: Option<u32>,
|
11
15
|
}
|
12
16
|
|
13
17
|
impl TextGeneration {
|
@@ -25,18 +29,27 @@ impl TextGeneration {
|
|
25
29
|
logits_processor,
|
26
30
|
tokens: Vec::new(),
|
27
31
|
eos_token_id: None,
|
32
|
+
constraint: None,
|
33
|
+
constraint_state: None,
|
28
34
|
}
|
29
35
|
}
|
30
36
|
|
31
37
|
pub fn from_config(config: &GenerationConfig) -> Self {
|
32
|
-
Self::new(
|
38
|
+
let mut text_gen = Self::new(
|
33
39
|
config.seed,
|
34
40
|
Some(config.temperature),
|
35
41
|
config.top_p,
|
36
42
|
config.top_k,
|
37
43
|
config.repetition_penalty,
|
38
44
|
config.repetition_penalty_last_n,
|
39
|
-
)
|
45
|
+
);
|
46
|
+
|
47
|
+
// Set constraint if provided
|
48
|
+
if let Some(ref constraint) = config.constraint {
|
49
|
+
text_gen.set_constraint(Arc::clone(constraint));
|
50
|
+
}
|
51
|
+
|
52
|
+
text_gen
|
40
53
|
}
|
41
54
|
|
42
55
|
pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
|
@@ -55,6 +68,36 @@ impl TextGeneration {
|
|
55
68
|
self.tokens.push(token);
|
56
69
|
}
|
57
70
|
|
71
|
+
/// Install a constraint and position the state at the index's initial state.
pub fn set_constraint(&mut self, constraint: Arc<Index>) {
    let start = constraint.initial_state();
    self.constraint = Some(constraint);
    self.constraint_state = Some(start);
}
|
76
|
+
|
77
|
+
/// Apply constraints to logits by masking disallowed tokens
|
78
|
+
fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
|
79
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
80
|
+
let device = logits.device();
|
81
|
+
let vocab_size = logits.dims1()?;
|
82
|
+
|
83
|
+
// Get allowed tokens from the constraint index for current state
|
84
|
+
if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
|
85
|
+
// Create a mask where allowed tokens have value 0 and others have -inf
|
86
|
+
let mut mask = vec![f32::NEG_INFINITY; vocab_size];
|
87
|
+
for &token_id in &allowed_tokens {
|
88
|
+
if (token_id as usize) < vocab_size {
|
89
|
+
mask[token_id as usize] = 0.0;
|
90
|
+
}
|
91
|
+
}
|
92
|
+
|
93
|
+
// Apply mask to logits
|
94
|
+
let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
|
95
|
+
*logits = logits.add(&mask_tensor)?;
|
96
|
+
}
|
97
|
+
}
|
98
|
+
Ok(())
|
99
|
+
}
|
100
|
+
|
58
101
|
/// Apply repetition penalty to logits
|
59
102
|
pub fn apply_repetition_penalty(
|
60
103
|
&self,
|
@@ -103,10 +146,18 @@ impl TextGeneration {
|
|
103
146
|
self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
|
104
147
|
}
|
105
148
|
|
149
|
+
// Apply constraints if active
|
150
|
+
self.apply_constraints(&mut logits)?;
|
151
|
+
|
106
152
|
// Sample token
|
107
153
|
let next_token = self.logits_processor.sample(&logits)?;
|
108
154
|
self.tokens.push(next_token);
|
109
155
|
|
156
|
+
// Update constraint state if active
|
157
|
+
if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
|
158
|
+
self.constraint_state = constraint_index.next_state(&current_state, &next_token);
|
159
|
+
}
|
160
|
+
|
110
161
|
Ok(next_token)
|
111
162
|
}
|
112
163
|
|
@@ -122,6 +173,19 @@ impl TextGeneration {
|
|
122
173
|
}
|
123
174
|
}
|
124
175
|
|
176
|
+
// Check if we've reached a final state in constraint
|
177
|
+
// A state is considered final if it has no allowed tokens
|
178
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
179
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
180
|
+
if allowed.is_empty() {
|
181
|
+
return true;
|
182
|
+
}
|
183
|
+
} else {
|
184
|
+
// None means no tokens allowed - we're done
|
185
|
+
return true;
|
186
|
+
}
|
187
|
+
}
|
188
|
+
|
125
189
|
false
|
126
190
|
}
|
127
191
|
|
@@ -162,6 +162,10 @@ impl Device {
|
|
162
162
|
pub fn __str__(&self) -> String {
|
163
163
|
self.__repr__()
|
164
164
|
}
|
165
|
+
|
166
|
+
/// Value equality between two devices; registered as Ruby's `==`.
pub fn __eq__(&self, other: &Device) -> bool {
    *self == *other
}
|
165
169
|
}
|
166
170
|
|
167
171
|
impl magnus::TryConvert for Device {
|
@@ -193,5 +197,6 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
193
197
|
rb_device.define_singleton_method("default", function!(default_device, 0))?;
|
194
198
|
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
195
199
|
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
200
|
+
rb_device.define_method("==", method!(Device::__eq__, 1))?;
|
196
201
|
Ok(())
|
197
202
|
}
|
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,15 +1,18 @@
|
|
1
1
|
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
|
2
2
|
use std::cell::RefCell;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
|
-
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
|
5
|
+
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, phi::Phi as RustPhi, QuantizedGGUF as RustQuantizedGGUF};
|
5
6
|
use crate::ruby::{Result, Device};
|
7
|
+
use crate::ruby::structured::StructuredConstraint;
|
6
8
|
|
7
9
|
// Use an enum to handle different model types instead of trait objects
|
8
|
-
#[derive(Debug)]
|
9
10
|
enum ModelType {
|
10
11
|
Mistral(RustMistral),
|
11
12
|
Llama(RustLlama),
|
12
13
|
Gemma(RustGemma),
|
14
|
+
Qwen(RustQwen),
|
15
|
+
Phi(RustPhi),
|
13
16
|
QuantizedGGUF(RustQuantizedGGUF),
|
14
17
|
}
|
15
18
|
|
@@ -19,6 +22,8 @@ impl ModelType {
|
|
19
22
|
ModelType::Mistral(m) => m.generate(prompt, config),
|
20
23
|
ModelType::Llama(m) => m.generate(prompt, config),
|
21
24
|
ModelType::Gemma(m) => m.generate(prompt, config),
|
25
|
+
ModelType::Qwen(m) => m.generate(prompt, config),
|
26
|
+
ModelType::Phi(m) => m.generate(prompt, config),
|
22
27
|
ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
|
23
28
|
}
|
24
29
|
}
|
@@ -33,6 +38,8 @@ impl ModelType {
|
|
33
38
|
ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
|
34
39
|
ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
|
35
40
|
ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
|
41
|
+
ModelType::Qwen(m) => m.generate_stream(prompt, config, callback),
|
42
|
+
ModelType::Phi(m) => m.generate_stream(prompt, config, callback),
|
36
43
|
ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
|
37
44
|
}
|
38
45
|
}
|
@@ -42,6 +49,8 @@ impl ModelType {
|
|
42
49
|
ModelType::Mistral(m) => m.clear_cache(),
|
43
50
|
ModelType::Llama(m) => m.clear_cache(),
|
44
51
|
ModelType::Gemma(m) => m.clear_cache(),
|
52
|
+
ModelType::Qwen(m) => m.clear_cache(),
|
53
|
+
ModelType::Phi(m) => m.clear_cache(),
|
45
54
|
ModelType::QuantizedGGUF(m) => m.clear_cache(),
|
46
55
|
}
|
47
56
|
}
|
@@ -67,6 +76,8 @@ impl ModelType {
|
|
67
76
|
},
|
68
77
|
ModelType::Llama(m) => m.apply_chat_template(messages),
|
69
78
|
ModelType::Gemma(m) => m.apply_chat_template(messages),
|
79
|
+
ModelType::Qwen(m) => m.apply_chat_template(messages),
|
80
|
+
ModelType::Phi(m) => m.apply_chat_template(messages),
|
70
81
|
ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
|
71
82
|
}
|
72
83
|
}
|
@@ -146,6 +157,13 @@ impl GenerationConfig {
|
|
146
157
|
}
|
147
158
|
}
|
148
159
|
|
160
|
+
// Handle constraint parameter
|
161
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
|
162
|
+
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
163
|
+
config.constraint = Some(Arc::clone(&constraint.index));
|
164
|
+
}
|
165
|
+
}
|
166
|
+
|
149
167
|
Ok(Self { inner: config })
|
150
168
|
}
|
151
169
|
|
@@ -191,9 +209,14 @@ impl GenerationConfig {
|
|
191
209
|
pub fn debug_tokens(&self) -> bool {
|
192
210
|
self.inner.debug_tokens
|
193
211
|
}
|
212
|
+
pub fn constraint(&self) -> Option<StructuredConstraint> {
|
213
|
+
self.inner.constraint.as_ref().map(|c| StructuredConstraint {
|
214
|
+
index: Arc::clone(c),
|
215
|
+
})
|
216
|
+
}
|
194
217
|
}
|
195
218
|
|
196
|
-
#[derive(Clone
|
219
|
+
#[derive(Clone)]
|
197
220
|
#[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
|
198
221
|
pub struct LLM {
|
199
222
|
model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
|
@@ -251,10 +274,22 @@ impl LLM {
|
|
251
274
|
})
|
252
275
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
253
276
|
ModelType::Gemma(gemma)
|
277
|
+
} else if model_lower.contains("qwen") {
|
278
|
+
let qwen = rt.block_on(async {
|
279
|
+
RustQwen::from_pretrained(&model_id, candle_device).await
|
280
|
+
})
|
281
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
282
|
+
ModelType::Qwen(qwen)
|
283
|
+
} else if model_lower.contains("phi") {
|
284
|
+
let phi = rt.block_on(async {
|
285
|
+
RustPhi::from_pretrained(&model_id, candle_device).await
|
286
|
+
})
|
287
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
288
|
+
ModelType::Phi(phi)
|
254
289
|
} else {
|
255
290
|
return Err(Error::new(
|
256
291
|
magnus::exception::runtime_error(),
|
257
|
-
format!("Unsupported model type: {}. Currently Mistral, Llama, and
|
292
|
+
format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id),
|
258
293
|
));
|
259
294
|
}
|
260
295
|
};
|
@@ -332,6 +367,8 @@ impl LLM {
|
|
332
367
|
ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
333
368
|
ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
334
369
|
ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
370
|
+
ModelType::Qwen(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
371
|
+
ModelType::Phi(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
335
372
|
ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
336
373
|
}
|
337
374
|
}
|
@@ -423,6 +460,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
423
460
|
rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
|
424
461
|
rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
|
425
462
|
rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
|
463
|
+
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
426
464
|
|
427
465
|
let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
|
428
466
|
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
data/ext/candle/src/ruby/mod.rs
CHANGED
@@ -0,0 +1,47 @@
|
|
1
|
+
use magnus::{Error, Module, RModule, function, Object};
|
2
|
+
use std::sync::Arc;
|
3
|
+
|
4
|
+
use crate::structured::{SchemaProcessor, VocabularyAdapter, Index};
|
5
|
+
use crate::ruby::{Result, tokenizer::Tokenizer};
|
6
|
+
|
7
|
+
/// Ruby wrapper for structured generation constraints
|
8
|
+
#[derive(Clone, Debug)]
|
9
|
+
#[magnus::wrap(class = "Candle::StructuredConstraint", mark, free_immediately)]
|
10
|
+
pub struct StructuredConstraint {
|
11
|
+
pub(crate) index: Arc<Index>,
|
12
|
+
}
|
13
|
+
|
14
|
+
impl StructuredConstraint {
|
15
|
+
/// Create a constraint from a JSON schema
|
16
|
+
pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
|
17
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
18
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
19
|
+
|
20
|
+
let processor = SchemaProcessor::new();
|
21
|
+
let index = processor.process_schema(&schema, &vocabulary)
|
22
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process schema: {}", e)))?;
|
23
|
+
|
24
|
+
Ok(Self { index })
|
25
|
+
}
|
26
|
+
|
27
|
+
/// Create a constraint from a regex pattern
|
28
|
+
pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
|
29
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
30
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
31
|
+
|
32
|
+
let processor = SchemaProcessor::new();
|
33
|
+
let index = processor.process_regex(&pattern, &vocabulary)
|
34
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process regex: {}", e)))?;
|
35
|
+
|
36
|
+
Ok(Self { index })
|
37
|
+
}
|
38
|
+
}
|
39
|
+
|
40
|
+
pub fn init_structured(rb_candle: RModule) -> Result<()> {
|
41
|
+
let class = rb_candle.define_class("StructuredConstraint", magnus::class::object())?;
|
42
|
+
|
43
|
+
class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
|
44
|
+
class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
|
45
|
+
|
46
|
+
Ok(())
|
47
|
+
}
|
@@ -0,0 +1,130 @@
|
|
1
|
+
#[cfg(test)]
mod integration_tests {
    use super::super::*;
    use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
    use std::sync::Arc;

    /// Schema processing end to end: a schema compiles into an index, a second
    /// request for the same schema returns the cached `Arc`, and the cache
    /// reports exactly one entry. Skipped (with a message) when the tokenizer
    /// cannot be fetched from the hub.
    #[tokio::test]
    async fn test_schema_processor_with_vocabulary() {
        // This test requires a tokenizer to create a vocabulary
        let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;

        if let Ok(tokenizer) = tokenizer_result {
            let wrapper = TokenizerWrapper::new(tokenizer);

            // Create vocabulary from tokenizer
            let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
                .expect("Should create vocabulary");

            // Create schema processor
            let processor = SchemaProcessor::new();

            // Test with a simple JSON schema
            let schema = r#"{
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"}
                },
                "required": ["name", "age"]
            }"#;

            // Process schema into Index
            let index_result = processor.process_schema(schema, &vocabulary);
            assert!(index_result.is_ok(), "Should process schema successfully");

            // Test caching - second call should use cache
            let index2_result = processor.process_schema(schema, &vocabulary);
            assert!(index2_result.is_ok(), "Should retrieve from cache");

            // Both should be the same Arc
            let index1 = index_result.unwrap();
            let index2 = index2_result.unwrap();
            assert!(Arc::ptr_eq(&index1, &index2), "Should return cached Index");

            // Check cache stats
            let (size, _) = processor.cache_stats();
            assert_eq!(size, 1, "Cache should have one entry");
        } else {
            eprintln!("Skipping integration test - couldn't load tokenizer");
        }
    }

    /// Regex processing: two distinct patterns compile, both are cached, and
    /// clearing the cache empties it. Skipped (with a message) when the
    /// tokenizer cannot be fetched from the hub.
    #[tokio::test]
    async fn test_regex_processing() {
        let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;

        if let Ok(tokenizer) = tokenizer_result {
            let wrapper = TokenizerWrapper::new(tokenizer);
            let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
                .expect("Should create vocabulary");

            let processor = SchemaProcessor::new();

            // Test with a simple regex pattern
            let email_regex = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";

            let index_result = processor.process_regex(email_regex, &vocabulary);
            assert!(index_result.is_ok(), "Should process regex successfully");

            // Test different regex
            let phone_regex = r"\d{3}-\d{3}-\d{4}";
            let phone_index_result = processor.process_regex(phone_regex, &vocabulary);
            assert!(phone_index_result.is_ok(), "Should process phone regex");

            // Cache should have both
            let (size, _) = processor.cache_stats();
            assert_eq!(size, 2, "Cache should have two entries");

            // Clear cache
            processor.clear_cache();
            let (size, _) = processor.cache_stats();
            assert_eq!(size, 0, "Cache should be empty after clear");
        } else {
            // Make the skip visible, matching the schema test, instead of
            // passing silently when the tokenizer is unavailable.
            eprintln!("Skipping integration test - couldn't load tokenizer");
        }
    }

    /// Sanity checks on a few schema shapes. No vocabulary is available here,
    /// so the schemas are only serialized, not compiled into an index.
    #[test]
    fn test_various_json_schemas() {
        let _processor = SchemaProcessor::new();

        // Array schema
        let array_schema = serde_json::json!({
            "type": "array",
            "items": {"type": "string"}
        });

        // Process as a full schema instead of testing private method
        // This would need a mock vocabulary in a real test
        // For now, just verify the schema is valid JSON
        let json_str = serde_json::to_string(&array_schema).unwrap();
        assert!(!json_str.is_empty(), "Should serialize array schema");

        // Nested object schema
        let nested_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "integer"},
                        "email": {"type": "string", "format": "email"}
                    }
                }
            }
        });

        // Verify nested schema is valid
        let json_str = serde_json::to_string(&nested_schema).unwrap();
        assert!(json_str.contains("properties"), "Should have nested properties");

        // Schema with enum
        let enum_schema = serde_json::json!({
            "type": "string",
            "enum": ["red", "green", "blue"]
        });

        // Verify enum schema is valid
        let json_str = serde_json::to_string(&enum_schema).unwrap();
        assert!(json_str.contains("enum"), "Should have enum values");
    }
}
|