red-candle 1.0.2 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +36 -2
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +123 -0
- data/ext/candle/src/llm/generation_config.rs +5 -0
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +285 -0
- data/ext/candle/src/llm/quantized_gguf.rs +155 -4
- data/ext/candle/src/llm/qwen.rs +229 -0
- data/ext/candle/src/llm/text_generation.rs +66 -2
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +42 -4
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/llm.rb +109 -3
- data/lib/candle/version.rb +1 -1
- metadata +14 -4
@@ -0,0 +1,31 @@
|
|
1
|
+
/// Structured generation support using Outlines
|
2
|
+
///
|
3
|
+
/// This module provides functionality to constrain language model generation
|
4
|
+
/// to follow specific patterns, such as JSON schemas or regular expressions.
|
5
|
+
|
6
|
+
pub mod vocabulary_adapter;
|
7
|
+
pub mod schema_processor;
|
8
|
+
|
9
|
+
pub use vocabulary_adapter::VocabularyAdapter;
|
10
|
+
pub use schema_processor::SchemaProcessor;
|
11
|
+
|
12
|
+
// Re-export commonly used types from outlines-core
|
13
|
+
pub use outlines_core::prelude::Index;
|
14
|
+
pub use outlines_core::vocabulary::Vocabulary;
|
15
|
+
|
16
|
+
#[cfg(test)]
|
17
|
+
mod vocabulary_adapter_simple_test;
|
18
|
+
|
19
|
+
#[cfg(test)]
|
20
|
+
mod integration_test;
|
21
|
+
|
22
|
+
#[cfg(test)]
|
23
|
+
mod tests {
|
24
|
+
use super::*;
|
25
|
+
|
26
|
+
#[test]
|
27
|
+
fn test_module_imports() {
|
28
|
+
// Ensure all exports are available
|
29
|
+
let _ = VocabularyAdapter;
|
30
|
+
}
|
31
|
+
}
|
@@ -0,0 +1,215 @@
|
|
1
|
+
use std::collections::HashMap;
|
2
|
+
use std::sync::{Arc, Mutex};
|
3
|
+
use candle_core::Result as CandleResult;
|
4
|
+
use candle_core::Error as CandleError;
|
5
|
+
use outlines_core::prelude::Index;
|
6
|
+
use outlines_core::vocabulary::Vocabulary;
|
7
|
+
use serde_json::Value as JsonValue;
|
8
|
+
use outlines_core::json_schema;
|
9
|
+
|
10
|
+
/// Processes JSON schemas into compiled Index objects for structured generation
|
11
|
+
pub struct SchemaProcessor {
|
12
|
+
/// Cache of compiled Index objects keyed by schema hash
|
13
|
+
cache: Arc<Mutex<HashMap<u64, Arc<Index>>>>,
|
14
|
+
}
|
15
|
+
|
16
|
+
impl SchemaProcessor {
|
17
|
+
/// Create a new schema processor with an empty cache
|
18
|
+
pub fn new() -> Self {
|
19
|
+
Self {
|
20
|
+
cache: Arc::new(Mutex::new(HashMap::new())),
|
21
|
+
}
|
22
|
+
}
|
23
|
+
|
24
|
+
/// Process a JSON schema into a compiled Index
|
25
|
+
///
|
26
|
+
/// # Arguments
|
27
|
+
/// * `schema` - JSON schema as a string
|
28
|
+
/// * `vocabulary` - The tokenizer's vocabulary
|
29
|
+
///
|
30
|
+
/// # Returns
|
31
|
+
/// A compiled Index ready for constrained generation
|
32
|
+
pub fn process_schema(
|
33
|
+
&self,
|
34
|
+
schema: &str,
|
35
|
+
vocabulary: &Vocabulary,
|
36
|
+
) -> CandleResult<Arc<Index>> {
|
37
|
+
// Calculate hash of the schema for caching
|
38
|
+
let schema_hash = self.calculate_hash(schema);
|
39
|
+
|
40
|
+
// Check cache first
|
41
|
+
if let Ok(cache) = self.cache.lock() {
|
42
|
+
if let Some(cached_index) = cache.get(&schema_hash) {
|
43
|
+
return Ok(Arc::clone(cached_index));
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
// Parse the JSON schema
|
48
|
+
let schema_value: JsonValue = serde_json::from_str(schema)
|
49
|
+
.map_err(|e| CandleError::Msg(format!("Invalid JSON schema: {}", e)))?;
|
50
|
+
|
51
|
+
// Convert schema to regex using Outlines
|
52
|
+
let regex = self.schema_to_regex(&schema_value)?;
|
53
|
+
|
54
|
+
// Compile regex into Index
|
55
|
+
let index = self.compile_regex(&regex, vocabulary)?;
|
56
|
+
let index_arc = Arc::new(index);
|
57
|
+
|
58
|
+
// Cache the compiled Index
|
59
|
+
if let Ok(mut cache) = self.cache.lock() {
|
60
|
+
cache.insert(schema_hash, Arc::clone(&index_arc));
|
61
|
+
}
|
62
|
+
|
63
|
+
Ok(index_arc)
|
64
|
+
}
|
65
|
+
|
66
|
+
/// Process a regex pattern directly into an Index
|
67
|
+
///
|
68
|
+
/// # Arguments
|
69
|
+
/// * `regex` - Regular expression pattern
|
70
|
+
/// * `vocabulary` - The tokenizer's vocabulary
|
71
|
+
///
|
72
|
+
/// # Returns
|
73
|
+
/// A compiled Index for the regex pattern
|
74
|
+
pub fn process_regex(
|
75
|
+
&self,
|
76
|
+
regex: &str,
|
77
|
+
vocabulary: &Vocabulary,
|
78
|
+
) -> CandleResult<Arc<Index>> {
|
79
|
+
// Calculate hash for caching
|
80
|
+
let regex_hash = self.calculate_hash(regex);
|
81
|
+
|
82
|
+
// Check cache
|
83
|
+
if let Ok(cache) = self.cache.lock() {
|
84
|
+
if let Some(cached_index) = cache.get(&regex_hash) {
|
85
|
+
return Ok(Arc::clone(cached_index));
|
86
|
+
}
|
87
|
+
}
|
88
|
+
|
89
|
+
// Compile the regex
|
90
|
+
let index = self.compile_regex(regex, vocabulary)?;
|
91
|
+
let index_arc = Arc::new(index);
|
92
|
+
|
93
|
+
// Cache it
|
94
|
+
if let Ok(mut cache) = self.cache.lock() {
|
95
|
+
cache.insert(regex_hash, Arc::clone(&index_arc));
|
96
|
+
}
|
97
|
+
|
98
|
+
Ok(index_arc)
|
99
|
+
}
|
100
|
+
|
101
|
+
/// Convert a JSON schema to a regex pattern
|
102
|
+
fn schema_to_regex(&self, schema: &JsonValue) -> CandleResult<String> {
|
103
|
+
// Use Outlines' built-in JSON schema to regex conversion
|
104
|
+
json_schema::regex_from_value(schema, None)
|
105
|
+
.map_err(|e| CandleError::Msg(format!("Failed to convert schema to regex: {:?}", e)))
|
106
|
+
}
|
107
|
+
|
108
|
+
/// Compile a regex pattern into an Index
|
109
|
+
fn compile_regex(&self, regex: &str, vocabulary: &Vocabulary) -> CandleResult<Index> {
|
110
|
+
// Use Outlines to build the Index from regex
|
111
|
+
Index::new(regex, vocabulary)
|
112
|
+
.map_err(|e| CandleError::Msg(format!("Failed to build index from regex: {:?}", e)))
|
113
|
+
}
|
114
|
+
|
115
|
+
/// Calculate a hash for caching
|
116
|
+
fn calculate_hash(&self, input: &str) -> u64 {
|
117
|
+
use std::collections::hash_map::DefaultHasher;
|
118
|
+
use std::hash::{Hash, Hasher};
|
119
|
+
|
120
|
+
let mut hasher = DefaultHasher::new();
|
121
|
+
input.hash(&mut hasher);
|
122
|
+
hasher.finish()
|
123
|
+
}
|
124
|
+
|
125
|
+
/// Clear the cache
|
126
|
+
pub fn clear_cache(&self) {
|
127
|
+
if let Ok(mut cache) = self.cache.lock() {
|
128
|
+
cache.clear();
|
129
|
+
}
|
130
|
+
}
|
131
|
+
|
132
|
+
/// Get cache statistics
|
133
|
+
pub fn cache_stats(&self) -> (usize, usize) {
|
134
|
+
if let Ok(cache) = self.cache.lock() {
|
135
|
+
let size = cache.len();
|
136
|
+
let capacity = cache.capacity();
|
137
|
+
(size, capacity)
|
138
|
+
} else {
|
139
|
+
(0, 0)
|
140
|
+
}
|
141
|
+
}
|
142
|
+
}
|
143
|
+
|
144
|
+
impl Default for SchemaProcessor {
|
145
|
+
fn default() -> Self {
|
146
|
+
Self::new()
|
147
|
+
}
|
148
|
+
}
|
149
|
+
|
150
|
+
#[cfg(test)]
|
151
|
+
mod tests {
|
152
|
+
use super::*;
|
153
|
+
|
154
|
+
#[test]
|
155
|
+
fn test_schema_processor_creation() {
|
156
|
+
let processor = SchemaProcessor::new();
|
157
|
+
let (size, _) = processor.cache_stats();
|
158
|
+
assert_eq!(size, 0, "Cache should start empty");
|
159
|
+
}
|
160
|
+
|
161
|
+
#[test]
|
162
|
+
fn test_cache_operations() {
|
163
|
+
let processor = SchemaProcessor::new();
|
164
|
+
|
165
|
+
// Initially empty
|
166
|
+
let (size, _) = processor.cache_stats();
|
167
|
+
assert_eq!(size, 0);
|
168
|
+
|
169
|
+
// After clear (should still be empty)
|
170
|
+
processor.clear_cache();
|
171
|
+
let (size, _) = processor.cache_stats();
|
172
|
+
assert_eq!(size, 0);
|
173
|
+
}
|
174
|
+
|
175
|
+
#[test]
|
176
|
+
fn test_schema_to_regex_basic_types() {
|
177
|
+
let processor = SchemaProcessor::new();
|
178
|
+
|
179
|
+
// Test string type
|
180
|
+
let string_schema = serde_json::json!({
|
181
|
+
"type": "string"
|
182
|
+
});
|
183
|
+
let regex = processor.schema_to_regex(&string_schema).unwrap();
|
184
|
+
// Just verify it produces a regex, exact format depends on Outlines
|
185
|
+
assert!(!regex.is_empty(), "String schema should produce a regex");
|
186
|
+
|
187
|
+
// Test number type
|
188
|
+
let number_schema = serde_json::json!({
|
189
|
+
"type": "number"
|
190
|
+
});
|
191
|
+
let regex = processor.schema_to_regex(&number_schema).unwrap();
|
192
|
+
assert!(!regex.is_empty(), "Number schema should produce a regex");
|
193
|
+
|
194
|
+
// Test boolean type
|
195
|
+
let bool_schema = serde_json::json!({
|
196
|
+
"type": "boolean"
|
197
|
+
});
|
198
|
+
let regex = processor.schema_to_regex(&bool_schema).unwrap();
|
199
|
+
assert!(regex.contains("true") && regex.contains("false"), "Boolean regex should contain true/false");
|
200
|
+
}
|
201
|
+
|
202
|
+
#[test]
|
203
|
+
fn test_schema_with_pattern() {
|
204
|
+
let processor = SchemaProcessor::new();
|
205
|
+
|
206
|
+
let schema = serde_json::json!({
|
207
|
+
"type": "string",
|
208
|
+
"pattern": r"^\d{3}-\d{3}-\d{4}$"
|
209
|
+
});
|
210
|
+
|
211
|
+
let regex = processor.schema_to_regex(&schema).unwrap();
|
212
|
+
// Pattern should be included in the generated regex
|
213
|
+
assert!(regex.contains("\\d{3}"), "Should contain digit pattern");
|
214
|
+
}
|
215
|
+
}
|
@@ -0,0 +1,152 @@
|
|
1
|
+
use crate::tokenizer::TokenizerWrapper;
|
2
|
+
use candle_core::Result as CandleResult;
|
3
|
+
use outlines_core::vocabulary::Vocabulary;
|
4
|
+
use std::collections::HashMap;
|
5
|
+
|
6
|
+
/// Adapter to convert between red-candle's TokenizerWrapper and Outlines' Vocabulary
|
7
|
+
pub struct VocabularyAdapter;
|
8
|
+
|
9
|
+
impl VocabularyAdapter {
|
10
|
+
/// Convert a TokenizerWrapper's vocabulary to an Outlines Vocabulary
|
11
|
+
///
|
12
|
+
/// # Arguments
|
13
|
+
/// * `tokenizer` - The tokenizer to extract vocabulary from
|
14
|
+
///
|
15
|
+
/// # Returns
|
16
|
+
/// An Outlines Vocabulary suitable for use with Index construction
|
17
|
+
pub fn from_tokenizer(tokenizer: &TokenizerWrapper) -> CandleResult<Vocabulary> {
|
18
|
+
// Get the vocabulary mapping from the tokenizer
|
19
|
+
let vocab_map: HashMap<String, u32> = tokenizer.inner().get_vocab(true);
|
20
|
+
|
21
|
+
// Try to find EOS token in vocabulary
|
22
|
+
let eos_token_id = vocab_map.get("</s>")
|
23
|
+
.or_else(|| vocab_map.get("<|endoftext|>"))
|
24
|
+
.or_else(|| vocab_map.get("<eos>"))
|
25
|
+
.or_else(|| vocab_map.get("[SEP]"))
|
26
|
+
.copied();
|
27
|
+
|
28
|
+
// Create a sorted list of (token_id, token_string) pairs
|
29
|
+
let mut token_pairs: Vec<(u32, String)> = vocab_map
|
30
|
+
.into_iter()
|
31
|
+
.map(|(token, id)| (id, token))
|
32
|
+
.collect();
|
33
|
+
|
34
|
+
// Sort by token ID to ensure correct indexing
|
35
|
+
token_pairs.sort_by_key(|(id, _)| *id);
|
36
|
+
|
37
|
+
// Find the maximum token ID to determine vocabulary size
|
38
|
+
let max_token_id = token_pairs
|
39
|
+
.last()
|
40
|
+
.map(|(id, _)| *id)
|
41
|
+
.unwrap_or(0);
|
42
|
+
|
43
|
+
// Create vocabulary items in the format expected by Outlines
|
44
|
+
// We need to handle potential gaps in token IDs
|
45
|
+
let mut vocab_items: Vec<(String, Vec<u8>)> = Vec::new();
|
46
|
+
let mut current_id = 0;
|
47
|
+
|
48
|
+
for (token_id, token_string) in token_pairs {
|
49
|
+
// Fill gaps with placeholder tokens
|
50
|
+
while current_id < token_id {
|
51
|
+
vocab_items.push((
|
52
|
+
format!("<unused_{}>", current_id),
|
53
|
+
format!("<unused_{}>", current_id).into_bytes(),
|
54
|
+
));
|
55
|
+
current_id += 1;
|
56
|
+
}
|
57
|
+
|
58
|
+
// Add the actual token
|
59
|
+
// Convert token string to bytes for Outlines
|
60
|
+
vocab_items.push((
|
61
|
+
token_string.clone(),
|
62
|
+
token_string.into_bytes(),
|
63
|
+
));
|
64
|
+
current_id += 1;
|
65
|
+
}
|
66
|
+
|
67
|
+
// Fill any remaining gaps up to a reasonable vocabulary size
|
68
|
+
// This ensures we don't have issues with token IDs beyond our vocabulary
|
69
|
+
while current_id <= max_token_id {
|
70
|
+
vocab_items.push((
|
71
|
+
format!("<unused_{}>", current_id),
|
72
|
+
format!("<unused_{}>", current_id).into_bytes(),
|
73
|
+
));
|
74
|
+
current_id += 1;
|
75
|
+
}
|
76
|
+
|
77
|
+
// Create the Outlines vocabulary
|
78
|
+
// The Vocabulary API expects us to build it token by token
|
79
|
+
let mut vocabulary = Vocabulary::new(
|
80
|
+
eos_token_id.unwrap_or(0) // Use EOS token ID or 0 as default
|
81
|
+
);
|
82
|
+
|
83
|
+
// Insert all tokens into the vocabulary
|
84
|
+
for (idx, (token, bytes)) in vocab_items.into_iter().enumerate() {
|
85
|
+
// Skip inserting the EOS token as it's already set in the vocabulary
|
86
|
+
if Some(idx as u32) == eos_token_id {
|
87
|
+
continue;
|
88
|
+
}
|
89
|
+
|
90
|
+
vocabulary.try_insert(bytes, idx as u32)
|
91
|
+
.map_err(|e| candle_core::Error::Msg(
|
92
|
+
format!("Failed to insert token '{}': {:?}", token, e)
|
93
|
+
))?;
|
94
|
+
}
|
95
|
+
|
96
|
+
Ok(vocabulary)
|
97
|
+
}
|
98
|
+
|
99
|
+
/// Get vocabulary size from a tokenizer
|
100
|
+
pub fn vocab_size(tokenizer: &TokenizerWrapper) -> usize {
|
101
|
+
tokenizer.inner().get_vocab_size(true)
|
102
|
+
}
|
103
|
+
|
104
|
+
/// Extract and validate special tokens
|
105
|
+
pub fn get_special_tokens(tokenizer: &TokenizerWrapper) -> HashMap<String, u32> {
|
106
|
+
let tokenizer_inner = tokenizer.inner();
|
107
|
+
let mut special_tokens = HashMap::new();
|
108
|
+
|
109
|
+
// Get common special tokens if they exist
|
110
|
+
if let Some(_token) = tokenizer_inner.id_to_token(0) {
|
111
|
+
special_tokens.insert("pad_token".to_string(), 0);
|
112
|
+
}
|
113
|
+
|
114
|
+
// Try to find EOS token
|
115
|
+
let vocab = tokenizer_inner.get_vocab(true);
|
116
|
+
if let Some(&eos_id) = vocab.get("</s>")
|
117
|
+
.or_else(|| vocab.get("<|endoftext|>"))
|
118
|
+
.or_else(|| vocab.get("<eos>"))
|
119
|
+
.or_else(|| vocab.get("[SEP]")) {
|
120
|
+
special_tokens.insert("eos_token".to_string(), eos_id);
|
121
|
+
}
|
122
|
+
|
123
|
+
// Try to get BOS token if it exists
|
124
|
+
if let Some(bos_token) = tokenizer_inner.token_to_id("<s>") {
|
125
|
+
special_tokens.insert("bos_token".to_string(), bos_token);
|
126
|
+
} else if let Some(bos_token) = tokenizer_inner.token_to_id("<|startoftext|>") {
|
127
|
+
special_tokens.insert("bos_token".to_string(), bos_token);
|
128
|
+
}
|
129
|
+
|
130
|
+
special_tokens
|
131
|
+
}
|
132
|
+
}
|
133
|
+
|
134
|
+
#[cfg(test)]
|
135
|
+
mod tests {
|
136
|
+
|
137
|
+
#[test]
|
138
|
+
fn test_vocabulary_adapter_creation() {
|
139
|
+
// This test will be implemented once we have a way to create test tokenizers
|
140
|
+
// For now, it serves as a placeholder for the test structure
|
141
|
+
}
|
142
|
+
|
143
|
+
#[test]
|
144
|
+
fn test_special_tokens_extraction() {
|
145
|
+
// Test special token extraction logic
|
146
|
+
}
|
147
|
+
|
148
|
+
#[test]
|
149
|
+
fn test_vocab_size() {
|
150
|
+
// Test vocabulary size calculation
|
151
|
+
}
|
152
|
+
}
|
@@ -0,0 +1,66 @@
|
|
1
|
+
#[cfg(test)]
|
2
|
+
mod real_tests {
|
3
|
+
use super::super::*;
|
4
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
5
|
+
|
6
|
+
#[tokio::test]
|
7
|
+
async fn test_vocabulary_conversion_with_real_outlines() {
|
8
|
+
// This test requires network access to download a tokenizer
|
9
|
+
// It verifies that our adapter works with the real outlines-core crate
|
10
|
+
|
11
|
+
// Load a simple tokenizer
|
12
|
+
let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
|
13
|
+
|
14
|
+
if let Ok(tokenizer) = tokenizer_result {
|
15
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
16
|
+
|
17
|
+
// Convert to Outlines vocabulary
|
18
|
+
let vocab_result = VocabularyAdapter::from_tokenizer(&wrapper);
|
19
|
+
assert!(vocab_result.is_ok(), "Vocabulary conversion should succeed");
|
20
|
+
|
21
|
+
let vocabulary = vocab_result.unwrap();
|
22
|
+
|
23
|
+
// Verify the vocabulary was created
|
24
|
+
// The real Vocabulary doesn't expose a size method directly,
|
25
|
+
// but we can verify it exists and has the correct EOS token
|
26
|
+
assert_eq!(vocabulary.eos_token_id(), 102); // BERT's [SEP] token
|
27
|
+
|
28
|
+
println!("✓ Successfully created Outlines Vocabulary from BERT tokenizer");
|
29
|
+
} else {
|
30
|
+
println!("⚠️ Skipping test - couldn't download tokenizer (likely offline)");
|
31
|
+
}
|
32
|
+
}
|
33
|
+
|
34
|
+
#[test]
|
35
|
+
fn test_vocabulary_adapter_with_mock_data() {
|
36
|
+
// This test doesn't require network access
|
37
|
+
// It uses a mock tokenizer to verify the conversion logic
|
38
|
+
|
39
|
+
use tokenizers::models::wordpiece::WordPiece;
|
40
|
+
use tokenizers::Tokenizer;
|
41
|
+
use std::collections::HashMap;
|
42
|
+
|
43
|
+
// Create a minimal vocabulary
|
44
|
+
let mut vocab = HashMap::new();
|
45
|
+
vocab.insert("[PAD]".to_string(), 0);
|
46
|
+
vocab.insert("[UNK]".to_string(), 1);
|
47
|
+
vocab.insert("[SEP]".to_string(), 2);
|
48
|
+
vocab.insert("hello".to_string(), 3);
|
49
|
+
vocab.insert("world".to_string(), 4);
|
50
|
+
|
51
|
+
let model = WordPiece::from_vocab(vocab);
|
52
|
+
let tokenizer = Tokenizer::new(model);
|
53
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
54
|
+
|
55
|
+
// Convert to Outlines vocabulary
|
56
|
+
let vocab_result = VocabularyAdapter::from_tokenizer(&wrapper);
|
57
|
+
assert!(vocab_result.is_ok(), "Vocabulary conversion should succeed");
|
58
|
+
|
59
|
+
let vocabulary = vocab_result.unwrap();
|
60
|
+
|
61
|
+
// Verify EOS token was found
|
62
|
+
assert_eq!(vocabulary.eos_token_id(), 2); // [SEP] token
|
63
|
+
|
64
|
+
println!("✓ Mock vocabulary conversion successful");
|
65
|
+
}
|
66
|
+
}
|
@@ -0,0 +1,70 @@
|
|
1
|
+
#[cfg(test)]
|
2
|
+
mod simple_tests {
|
3
|
+
use super::super::*;
|
4
|
+
|
5
|
+
#[test]
|
6
|
+
fn test_vocabulary_adapter_basic() {
|
7
|
+
// Create a simple mock tokenizer to test the adapter
|
8
|
+
// This validates that the VocabularyAdapter compiles and can be called
|
9
|
+
|
10
|
+
// Note: Creating a full tokenizer in tests is complex due to the tokenizers crate API
|
11
|
+
// For now, we verify compilation and will rely on integration tests
|
12
|
+
|
13
|
+
// The important thing is that this code compiles, proving our integration works
|
14
|
+
let _adapter = VocabularyAdapter;
|
15
|
+
|
16
|
+
// Test the static methods compile
|
17
|
+
// These would be tested with a real tokenizer in integration tests
|
18
|
+
|
19
|
+
// Test passes if this compiles - no output needed
|
20
|
+
}
|
21
|
+
|
22
|
+
#[test]
|
23
|
+
fn test_outlines_vocabulary_api() {
|
24
|
+
use outlines_core::vocabulary::Vocabulary;
|
25
|
+
|
26
|
+
// Test that we can create a Vocabulary object
|
27
|
+
// Use token ID 2 as EOS (like BERT's [SEP] token)
|
28
|
+
let mut vocab = Vocabulary::new(2);
|
29
|
+
|
30
|
+
// Test inserting tokens
|
31
|
+
let test_tokens = vec![
|
32
|
+
("<pad>".to_string(), "<pad>".as_bytes().to_vec()),
|
33
|
+
("<unk>".to_string(), "<unk>".as_bytes().to_vec()),
|
34
|
+
("<sep>".to_string(), "<sep>".as_bytes().to_vec()), // EOS token at ID 2
|
35
|
+
("hello".to_string(), "hello".as_bytes().to_vec()),
|
36
|
+
("world".to_string(), "world".as_bytes().to_vec()),
|
37
|
+
];
|
38
|
+
|
39
|
+
for (idx, (_token, bytes)) in test_tokens.into_iter().enumerate() {
|
40
|
+
match vocab.try_insert(bytes, idx as u32) {
|
41
|
+
Ok(_) => {},
|
42
|
+
Err(e) => {
|
43
|
+
// It's ok if we can't insert the EOS token
|
44
|
+
if idx != 2 {
|
45
|
+
panic!("Failed to insert token at index {}: {:?}", idx, e);
|
46
|
+
}
|
47
|
+
}
|
48
|
+
}
|
49
|
+
}
|
50
|
+
|
51
|
+
// Test passes - vocabulary API works correctly
|
52
|
+
}
|
53
|
+
|
54
|
+
#[test]
|
55
|
+
fn test_special_token_patterns() {
|
56
|
+
|
57
|
+
// Test that our special token patterns are correct
|
58
|
+
let test_cases = vec![
|
59
|
+
("</s>", "EOS token for many models"),
|
60
|
+
("<|endoftext|>", "GPT-style EOS token"),
|
61
|
+
("<eos>", "Alternative EOS token"),
|
62
|
+
("[SEP]", "BERT-style separator"),
|
63
|
+
("<s>", "BOS token"),
|
64
|
+
("<|startoftext|>", "GPT-style BOS token"),
|
65
|
+
];
|
66
|
+
|
67
|
+
// Just verify the patterns exist - no output needed
|
68
|
+
assert_eq!(test_cases.len(), 6, "Should have 6 special token patterns");
|
69
|
+
}
|
70
|
+
}
|
data/lib/candle/llm.rb
CHANGED
@@ -1,5 +1,67 @@
|
|
1
|
+
require 'json'
|
2
|
+
|
1
3
|
module Candle
|
2
4
|
class LLM
|
5
|
+
# Create a structured constraint from a JSON schema
|
6
|
+
def constraint_from_schema(schema)
|
7
|
+
schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
|
8
|
+
StructuredConstraint.from_schema(schema_str, tokenizer)
|
9
|
+
end
|
10
|
+
|
11
|
+
# Create a structured constraint from a regex pattern
|
12
|
+
def constraint_from_regex(pattern)
|
13
|
+
pattern_str = pattern.is_a?(Regexp) ? pattern.source : pattern.to_s
|
14
|
+
StructuredConstraint.from_regex(pattern_str, tokenizer)
|
15
|
+
end
|
16
|
+
|
17
|
+
# Generate with regex constraint
|
18
|
+
def generate_regex(prompt, pattern:, **options)
|
19
|
+
constraint = constraint_from_regex(pattern)
|
20
|
+
|
21
|
+
# Add common EOS tokens as stop sequences for regex generation
|
22
|
+
stop_sequences = options[:stop_sequences] || []
|
23
|
+
stop_sequences += ["</s>", "<|endoftext|>", "<|im_end|>", "<end>", "\n"] unless options[:no_auto_stop]
|
24
|
+
|
25
|
+
config_opts = options.merge(constraint: constraint, stop_sequences: stop_sequences)
|
26
|
+
config = options[:config] || GenerationConfig.balanced(**config_opts)
|
27
|
+
|
28
|
+
result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
|
29
|
+
|
30
|
+
# Clean up any trailing EOS tokens
|
31
|
+
result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
|
32
|
+
end
|
33
|
+
|
34
|
+
# Generate and parse structured output from a JSON schema
|
35
|
+
def generate_structured(prompt, schema:, **options)
|
36
|
+
constraint = constraint_from_schema(schema)
|
37
|
+
config_opts = options.merge(constraint: constraint)
|
38
|
+
config = options[:config] || GenerationConfig.balanced(**config_opts)
|
39
|
+
|
40
|
+
result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
|
41
|
+
|
42
|
+
# Clean up the result - remove common end-of-sequence tokens
|
43
|
+
# that might appear after valid JSON
|
44
|
+
cleaned_result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '')
|
45
|
+
|
46
|
+
# Try to parse as JSON
|
47
|
+
begin
|
48
|
+
JSON.parse(cleaned_result)
|
49
|
+
rescue JSON::ParserError => e
|
50
|
+
# If cleaning didn't help, try to extract JSON from the result
|
51
|
+
# Look for the first complete JSON object/array
|
52
|
+
if match = cleaned_result.match(/(\{[^{}]*\}|\[[^\[\]]*\])/m)
|
53
|
+
begin
|
54
|
+
return JSON.parse(match[1])
|
55
|
+
rescue JSON::ParserError
|
56
|
+
# Fall through to warning
|
57
|
+
end
|
58
|
+
end
|
59
|
+
|
60
|
+
# Return the raw string if parsing fails
|
61
|
+
warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
|
62
|
+
result
|
63
|
+
end
|
64
|
+
end
|
3
65
|
# Tokenizer registry for automatic detection
|
4
66
|
TOKENIZER_REGISTRY = {
|
5
67
|
# Exact model matches
|
@@ -8,6 +70,18 @@ module Candle
|
|
8
70
|
"TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
|
9
71
|
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
10
72
|
|
73
|
+
# Qwen official GGUF models
|
74
|
+
"Qwen/Qwen3-8B-GGUF" => "Qwen/Qwen3-8B",
|
75
|
+
"Qwen/Qwen3-4B-GGUF" => "Qwen/Qwen3-4B",
|
76
|
+
"Qwen/Qwen3-14B-GGUF" => "Qwen/Qwen3-14B",
|
77
|
+
"Qwen/Qwen3-32B-GGUF" => "Qwen/Qwen3-32B",
|
78
|
+
"Qwen/Qwen3-72B-GGUF" => "Qwen/Qwen3-72B",
|
79
|
+
|
80
|
+
# Phi GGUF models
|
81
|
+
"TheBloke/phi-2-GGUF" => "microsoft/phi-2",
|
82
|
+
"microsoft/phi-4-gguf" => "microsoft/phi-4",
|
83
|
+
"bartowski/Phi-3.5-mini-instruct-GGUF" => "microsoft/Phi-3.5-mini-instruct",
|
84
|
+
|
11
85
|
# Pattern-based fallbacks (evaluated in order)
|
12
86
|
:patterns => [
|
13
87
|
# Mistral models
|
@@ -27,7 +101,31 @@ module Candle
|
|
27
101
|
[/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
|
28
102
|
[/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
|
29
103
|
[/gemma.*?7b/i, "google/gemma-7b"],
|
30
|
-
[/gemma.*?2b/i, "google/gemma-2b"]
|
104
|
+
[/gemma.*?2b/i, "google/gemma-2b"],
|
105
|
+
|
106
|
+
# Qwen models
|
107
|
+
[/qwen.*?3.*?72b/i, "Qwen/Qwen3-72B"],
|
108
|
+
[/qwen.*?3.*?32b/i, "Qwen/Qwen3-32B"],
|
109
|
+
[/qwen.*?3.*?14b/i, "Qwen/Qwen3-14B"],
|
110
|
+
[/qwen.*?3.*?8b/i, "Qwen/Qwen3-8B"],
|
111
|
+
[/qwen.*?3.*?4b/i, "Qwen/Qwen3-4B"],
|
112
|
+
[/qwen.*?3.*?1\.8b/i, "Qwen/Qwen3-1.8B"],
|
113
|
+
[/qwen.*?3.*?0\.5b/i, "Qwen/Qwen3-0.5B"],
|
114
|
+
[/qwen.*?2\.5/i, "Qwen/Qwen2.5-0.5B"],
|
115
|
+
[/qwen.*?2/i, "Qwen/Qwen2-1.5B"],
|
116
|
+
[/qwen/i, "Qwen/Qwen-1_8B"],
|
117
|
+
|
118
|
+
# Phi models (order matters - more specific patterns first)
|
119
|
+
[/phi.*?3\.5.*?mini/i, "microsoft/Phi-3.5-mini-instruct"],
|
120
|
+
[/phi.*?3.*?mini.*?4k/i, "microsoft/Phi-3-mini-4k-instruct"],
|
121
|
+
[/phi.*?3.*?medium/i, "microsoft/Phi-3-medium-4k-instruct"],
|
122
|
+
[/phi.*?3.*?small/i, "microsoft/Phi-3-small-8k-instruct"],
|
123
|
+
[/phi.*?3.*?mini/i, "microsoft/Phi-3-mini-4k-instruct"],
|
124
|
+
[/phi.*?3/i, "microsoft/Phi-3-mini-4k-instruct"],
|
125
|
+
[/phi-4/i, "microsoft/phi-4"],
|
126
|
+
[/phi.*?2/i, "microsoft/phi-2"],
|
127
|
+
[/phi.*?1\.5/i, "microsoft/phi-1_5"],
|
128
|
+
[/phi/i, "microsoft/phi-2"]
|
31
129
|
]
|
32
130
|
}
|
33
131
|
|
@@ -74,7 +172,14 @@ module Candle
|
|
74
172
|
|
75
173
|
def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
|
76
174
|
begin
|
77
|
-
_generate(prompt, config)
|
175
|
+
result = _generate(prompt, config)
|
176
|
+
|
177
|
+
# If there's a constraint, clean up common EOS tokens that appear after the constrained content
|
178
|
+
if config.constraint
|
179
|
+
result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
|
180
|
+
end
|
181
|
+
|
182
|
+
result
|
78
183
|
ensure
|
79
184
|
clear_cache if reset_cache
|
80
185
|
end
|
@@ -155,7 +260,8 @@ module Candle
|
|
155
260
|
repetition_penalty: repetition_penalty,
|
156
261
|
seed: seed,
|
157
262
|
stop_sequences: stop_sequences,
|
158
|
-
include_prompt: include_prompt
|
263
|
+
include_prompt: include_prompt,
|
264
|
+
constraint: defined?(@constraint) ? @constraint : nil
|
159
265
|
}.compact
|
160
266
|
|
161
267
|
self.class.new(current_config.merge(overrides))
|
data/lib/candle/version.rb
CHANGED