red-candle 1.0.2 → 1.1.1

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
@@ -0,0 +1,215 @@
+ use std::collections::HashMap;
+ use std::sync::{Arc, Mutex};
+ use candle_core::Result as CandleResult;
+ use candle_core::Error as CandleError;
+ use outlines_core::prelude::Index;
+ use outlines_core::vocabulary::Vocabulary;
+ use serde_json::Value as JsonValue;
+ use outlines_core::json_schema;
+
+ /// Processes JSON schemas into compiled Index objects for structured generation
+ pub struct SchemaProcessor {
+     /// Cache of compiled Index objects keyed by schema hash
+     cache: Arc<Mutex<HashMap<u64, Arc<Index>>>>,
+ }
+
+ impl SchemaProcessor {
+     /// Create a new schema processor with an empty cache
+     pub fn new() -> Self {
+         Self {
+             cache: Arc::new(Mutex::new(HashMap::new())),
+         }
+     }
+
+     /// Process a JSON schema into a compiled Index
+     ///
+     /// # Arguments
+     /// * `schema` - JSON schema as a string
+     /// * `vocabulary` - The tokenizer's vocabulary
+     ///
+     /// # Returns
+     /// A compiled Index ready for constrained generation
+     pub fn process_schema(
+         &self,
+         schema: &str,
+         vocabulary: &Vocabulary,
+     ) -> CandleResult<Arc<Index>> {
+         // Calculate hash of the schema for caching
+         let schema_hash = self.calculate_hash(schema);
+
+         // Check cache first
+         if let Ok(cache) = self.cache.lock() {
+             if let Some(cached_index) = cache.get(&schema_hash) {
+                 return Ok(Arc::clone(cached_index));
+             }
+         }
+
+         // Parse the JSON schema
+         let schema_value: JsonValue = serde_json::from_str(schema)
+             .map_err(|e| CandleError::Msg(format!("Invalid JSON schema: {}", e)))?;
+
+         // Convert schema to regex using Outlines
+         let regex = self.schema_to_regex(&schema_value)?;
+
+         // Compile regex into Index
+         let index = self.compile_regex(&regex, vocabulary)?;
+         let index_arc = Arc::new(index);
+
+         // Cache the compiled Index
+         if let Ok(mut cache) = self.cache.lock() {
+             cache.insert(schema_hash, Arc::clone(&index_arc));
+         }
+
+         Ok(index_arc)
+     }
+
+     /// Process a regex pattern directly into an Index
+     ///
+     /// # Arguments
+     /// * `regex` - Regular expression pattern
+     /// * `vocabulary` - The tokenizer's vocabulary
+     ///
+     /// # Returns
+     /// A compiled Index for the regex pattern
+     pub fn process_regex(
+         &self,
+         regex: &str,
+         vocabulary: &Vocabulary,
+     ) -> CandleResult<Arc<Index>> {
+         // Calculate hash for caching
+         let regex_hash = self.calculate_hash(regex);
+
+         // Check cache
+         if let Ok(cache) = self.cache.lock() {
+             if let Some(cached_index) = cache.get(&regex_hash) {
+                 return Ok(Arc::clone(cached_index));
+             }
+         }
+
+         // Compile the regex
+         let index = self.compile_regex(regex, vocabulary)?;
+         let index_arc = Arc::new(index);
+
+         // Cache it
+         if let Ok(mut cache) = self.cache.lock() {
+             cache.insert(regex_hash, Arc::clone(&index_arc));
+         }
+
+         Ok(index_arc)
+     }
+
+     /// Convert a JSON schema to a regex pattern
+     fn schema_to_regex(&self, schema: &JsonValue) -> CandleResult<String> {
+         // Use Outlines' built-in JSON schema to regex conversion
+         json_schema::regex_from_value(schema, None)
+             .map_err(|e| CandleError::Msg(format!("Failed to convert schema to regex: {:?}", e)))
+     }
+
+     /// Compile a regex pattern into an Index
+     fn compile_regex(&self, regex: &str, vocabulary: &Vocabulary) -> CandleResult<Index> {
+         // Use Outlines to build the Index from regex
+         Index::new(regex, vocabulary)
+             .map_err(|e| CandleError::Msg(format!("Failed to build index from regex: {:?}", e)))
+     }
+
+     /// Calculate a hash for caching
+     fn calculate_hash(&self, input: &str) -> u64 {
+         use std::collections::hash_map::DefaultHasher;
+         use std::hash::{Hash, Hasher};
+
+         let mut hasher = DefaultHasher::new();
+         input.hash(&mut hasher);
+         hasher.finish()
+     }
+
+     /// Clear the cache
+     pub fn clear_cache(&self) {
+         if let Ok(mut cache) = self.cache.lock() {
+             cache.clear();
+         }
+     }
+
+     /// Get cache statistics
+     pub fn cache_stats(&self) -> (usize, usize) {
+         if let Ok(cache) = self.cache.lock() {
+             let size = cache.len();
+             let capacity = cache.capacity();
+             (size, capacity)
+         } else {
+             (0, 0)
+         }
+     }
+ }
+
+ impl Default for SchemaProcessor {
+     fn default() -> Self {
+         Self::new()
+     }
+ }
+
+ #[cfg(test)]
+ mod tests {
+     use super::*;
+
+     #[test]
+     fn test_schema_processor_creation() {
+         let processor = SchemaProcessor::new();
+         let (size, _) = processor.cache_stats();
+         assert_eq!(size, 0, "Cache should start empty");
+     }
+
+     #[test]
+     fn test_cache_operations() {
+         let processor = SchemaProcessor::new();
+
+         // Initially empty
+         let (size, _) = processor.cache_stats();
+         assert_eq!(size, 0);
+
+         // After clear (should still be empty)
+         processor.clear_cache();
+         let (size, _) = processor.cache_stats();
+         assert_eq!(size, 0);
+     }
+
+     #[test]
+     fn test_schema_to_regex_basic_types() {
+         let processor = SchemaProcessor::new();
+
+         // Test string type
+         let string_schema = serde_json::json!({
+             "type": "string"
+         });
+         let regex = processor.schema_to_regex(&string_schema).unwrap();
+         // Just verify it produces a regex, exact format depends on Outlines
+         assert!(!regex.is_empty(), "String schema should produce a regex");
+
+         // Test number type
+         let number_schema = serde_json::json!({
+             "type": "number"
+         });
+         let regex = processor.schema_to_regex(&number_schema).unwrap();
+         assert!(!regex.is_empty(), "Number schema should produce a regex");
+
+         // Test boolean type
+         let bool_schema = serde_json::json!({
+             "type": "boolean"
+         });
+         let regex = processor.schema_to_regex(&bool_schema).unwrap();
+         assert!(regex.contains("true") && regex.contains("false"), "Boolean regex should contain true/false");
+     }
+
+     #[test]
+     fn test_schema_with_pattern() {
+         let processor = SchemaProcessor::new();
+
+         let schema = serde_json::json!({
+             "type": "string",
+             "pattern": r"^\d{3}-\d{3}-\d{4}$"
+         });
+
+         let regex = processor.schema_to_regex(&schema).unwrap();
+         // Pattern should be included in the generated regex
+         assert!(regex.contains("\\d{3}"), "Should contain digit pattern");
+     }
+ }
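For orientation, the sketch below shows how this new `SchemaProcessor` is intended to be driven from Rust. It is a minimal illustration, not code from the gem: the `build_indexes` helper and its `wrapper` argument are assumptions, and the vocabulary is obtained through the `VocabularyAdapter` introduced in the next file.

    use candle_core::Result as CandleResult;

    // Hypothetical helper; `wrapper` is assumed to be a TokenizerWrapper loaded elsewhere in red-candle.
    fn build_indexes(wrapper: &TokenizerWrapper) -> CandleResult<()> {
        let processor = SchemaProcessor::new();
        let vocabulary = VocabularyAdapter::from_tokenizer(wrapper)?;

        // Compile a JSON schema; a second call with the same schema string hits the cache.
        let schema = r#"{"type": "object", "properties": {"name": {"type": "string"}}}"#;
        let _schema_index = processor.process_schema(schema, &vocabulary)?;

        // Plain regex patterns can be compiled directly as well.
        let _regex_index = processor.process_regex(r"\d{3}-\d{3}-\d{4}", &vocabulary)?;
        Ok(())
    }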
@@ -0,0 +1,152 @@
+ use crate::tokenizer::TokenizerWrapper;
+ use candle_core::Result as CandleResult;
+ use outlines_core::vocabulary::Vocabulary;
+ use std::collections::HashMap;
+
+ /// Adapter to convert between red-candle's TokenizerWrapper and Outlines' Vocabulary
+ pub struct VocabularyAdapter;
+
+ impl VocabularyAdapter {
+     /// Convert a TokenizerWrapper's vocabulary to an Outlines Vocabulary
+     ///
+     /// # Arguments
+     /// * `tokenizer` - The tokenizer to extract vocabulary from
+     ///
+     /// # Returns
+     /// An Outlines Vocabulary suitable for use with Index construction
+     pub fn from_tokenizer(tokenizer: &TokenizerWrapper) -> CandleResult<Vocabulary> {
+         // Get the vocabulary mapping from the tokenizer
+         let vocab_map: HashMap<String, u32> = tokenizer.inner().get_vocab(true);
+
+         // Try to find EOS token in vocabulary
+         let eos_token_id = vocab_map.get("</s>")
+             .or_else(|| vocab_map.get("<|endoftext|>"))
+             .or_else(|| vocab_map.get("<eos>"))
+             .or_else(|| vocab_map.get("[SEP]"))
+             .copied();
+
+         // Create a sorted list of (token_id, token_string) pairs
+         let mut token_pairs: Vec<(u32, String)> = vocab_map
+             .into_iter()
+             .map(|(token, id)| (id, token))
+             .collect();
+
+         // Sort by token ID to ensure correct indexing
+         token_pairs.sort_by_key(|(id, _)| *id);
+
+         // Find the maximum token ID to determine vocabulary size
+         let max_token_id = token_pairs
+             .last()
+             .map(|(id, _)| *id)
+             .unwrap_or(0);
+
+         // Create vocabulary items in the format expected by Outlines
+         // We need to handle potential gaps in token IDs
+         let mut vocab_items: Vec<(String, Vec<u8>)> = Vec::new();
+         let mut current_id = 0;
+
+         for (token_id, token_string) in token_pairs {
+             // Fill gaps with placeholder tokens
+             while current_id < token_id {
+                 vocab_items.push((
+                     format!("<unused_{}>", current_id),
+                     format!("<unused_{}>", current_id).into_bytes(),
+                 ));
+                 current_id += 1;
+             }
+
+             // Add the actual token
+             // Convert token string to bytes for Outlines
+             vocab_items.push((
+                 token_string.clone(),
+                 token_string.into_bytes(),
+             ));
+             current_id += 1;
+         }
+
+         // Fill any remaining gaps up to a reasonable vocabulary size
+         // This ensures we don't have issues with token IDs beyond our vocabulary
+         while current_id <= max_token_id {
+             vocab_items.push((
+                 format!("<unused_{}>", current_id),
+                 format!("<unused_{}>", current_id).into_bytes(),
+             ));
+             current_id += 1;
+         }
+
+         // Create the Outlines vocabulary
+         // The Vocabulary API expects us to build it token by token
+         let mut vocabulary = Vocabulary::new(
+             eos_token_id.unwrap_or(0) // Use EOS token ID or 0 as default
+         );
+
+         // Insert all tokens into the vocabulary
+         for (idx, (token, bytes)) in vocab_items.into_iter().enumerate() {
+             // Skip inserting the EOS token as it's already set in the vocabulary
+             if Some(idx as u32) == eos_token_id {
+                 continue;
+             }
+
+             vocabulary.try_insert(bytes, idx as u32)
+                 .map_err(|e| candle_core::Error::Msg(
+                     format!("Failed to insert token '{}': {:?}", token, e)
+                 ))?;
+         }
+
+         Ok(vocabulary)
+     }
+
+     /// Get vocabulary size from a tokenizer
+     pub fn vocab_size(tokenizer: &TokenizerWrapper) -> usize {
+         tokenizer.inner().get_vocab_size(true)
+     }
+
+     /// Extract and validate special tokens
+     pub fn get_special_tokens(tokenizer: &TokenizerWrapper) -> HashMap<String, u32> {
+         let tokenizer_inner = tokenizer.inner();
+         let mut special_tokens = HashMap::new();
+
+         // Get common special tokens if they exist
+         if let Some(_token) = tokenizer_inner.id_to_token(0) {
+             special_tokens.insert("pad_token".to_string(), 0);
+         }
+
+         // Try to find EOS token
+         let vocab = tokenizer_inner.get_vocab(true);
+         if let Some(&eos_id) = vocab.get("</s>")
+             .or_else(|| vocab.get("<|endoftext|>"))
+             .or_else(|| vocab.get("<eos>"))
+             .or_else(|| vocab.get("[SEP]")) {
+             special_tokens.insert("eos_token".to_string(), eos_id);
+         }
+
+         // Try to get BOS token if it exists
+         if let Some(bos_token) = tokenizer_inner.token_to_id("<s>") {
+             special_tokens.insert("bos_token".to_string(), bos_token);
+         } else if let Some(bos_token) = tokenizer_inner.token_to_id("<|startoftext|>") {
+             special_tokens.insert("bos_token".to_string(), bos_token);
+         }
+
+         special_tokens
+     }
+ }
+
+ #[cfg(test)]
+ mod tests {
+
+     #[test]
+     fn test_vocabulary_adapter_creation() {
+         // This test will be implemented once we have a way to create test tokenizers
+         // For now, it serves as a placeholder for the test structure
+     }
+
+     #[test]
+     fn test_special_tokens_extraction() {
+         // Test special token extraction logic
+     }
+
+     #[test]
+     fn test_vocab_size() {
+         // Test vocabulary size calculation
+     }
+ }
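Besides the conversion itself, the adapter exposes two small inspection helpers. The sketch below is again illustrative rather than taken from the gem (the `inspect_tokenizer` function and its `wrapper` value are assumptions) and shows how they might be called together.

    // Hypothetical helper; `wrapper` is assumed to be a TokenizerWrapper loaded elsewhere in red-candle.
    fn inspect_tokenizer(wrapper: &TokenizerWrapper) {
        // Number of entries in the tokenizer vocabulary, special tokens included.
        let size = VocabularyAdapter::vocab_size(wrapper);

        // Best-effort pad/eos/bos lookup using the candidate names checked above
        // ("</s>", "<|endoftext|>", "<eos>", "[SEP]", "<s>", "<|startoftext|>").
        let specials = VocabularyAdapter::get_special_tokens(wrapper);

        println!("vocab size: {}, special tokens: {:?}", size, specials);
    }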
@@ -0,0 +1,66 @@
+ #[cfg(test)]
+ mod real_tests {
+     use super::super::*;
+     use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
+
+     #[tokio::test]
+     async fn test_vocabulary_conversion_with_real_outlines() {
+         // This test requires network access to download a tokenizer
+         // It verifies that our adapter works with the real outlines-core crate
+
+         // Load a simple tokenizer
+         let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
+
+         if let Ok(tokenizer) = tokenizer_result {
+             let wrapper = TokenizerWrapper::new(tokenizer);
+
+             // Convert to Outlines vocabulary
+             let vocab_result = VocabularyAdapter::from_tokenizer(&wrapper);
+             assert!(vocab_result.is_ok(), "Vocabulary conversion should succeed");
+
+             let vocabulary = vocab_result.unwrap();
+
+             // Verify the vocabulary was created
+             // The real Vocabulary doesn't expose a size method directly,
+             // but we can verify it exists and has the correct EOS token
+             assert_eq!(vocabulary.eos_token_id(), 102); // BERT's [SEP] token
+
+             println!("✓ Successfully created Outlines Vocabulary from BERT tokenizer");
+         } else {
+             println!("⚠️ Skipping test - couldn't download tokenizer (likely offline)");
+         }
+     }
+
+     #[test]
+     fn test_vocabulary_adapter_with_mock_data() {
+         // This test doesn't require network access
+         // It uses a mock tokenizer to verify the conversion logic
+
+         use tokenizers::models::wordpiece::WordPiece;
+         use tokenizers::Tokenizer;
+         use std::collections::HashMap;
+
+         // Create a minimal vocabulary
+         let mut vocab = HashMap::new();
+         vocab.insert("[PAD]".to_string(), 0);
+         vocab.insert("[UNK]".to_string(), 1);
+         vocab.insert("[SEP]".to_string(), 2);
+         vocab.insert("hello".to_string(), 3);
+         vocab.insert("world".to_string(), 4);
+
+         let model = WordPiece::from_vocab(vocab);
+         let tokenizer = Tokenizer::new(model);
+         let wrapper = TokenizerWrapper::new(tokenizer);
+
+         // Convert to Outlines vocabulary
+         let vocab_result = VocabularyAdapter::from_tokenizer(&wrapper);
+         assert!(vocab_result.is_ok(), "Vocabulary conversion should succeed");
+
+         let vocabulary = vocab_result.unwrap();
+
+         // Verify EOS token was found
+         assert_eq!(vocabulary.eos_token_id(), 2); // [SEP] token
+
+         println!("✓ Mock vocabulary conversion successful");
+     }
+ }
@@ -0,0 +1,70 @@
+ #[cfg(test)]
+ mod simple_tests {
+     use super::super::*;
+
+     #[test]
+     fn test_vocabulary_adapter_basic() {
+         // Create a simple mock tokenizer to test the adapter
+         // This validates that the VocabularyAdapter compiles and can be called
+
+         // Note: Creating a full tokenizer in tests is complex due to the tokenizers crate API
+         // For now, we verify compilation and will rely on integration tests
+
+         // The important thing is that this code compiles, proving our integration works
+         let _adapter = VocabularyAdapter;
+
+         // Test the static methods compile
+         // These would be tested with a real tokenizer in integration tests
+
+         // Test passes if this compiles - no output needed
+     }
+
+     #[test]
+     fn test_outlines_vocabulary_api() {
+         use outlines_core::vocabulary::Vocabulary;
+
+         // Test that we can create a Vocabulary object
+         // Use token ID 2 as EOS (like BERT's [SEP] token)
+         let mut vocab = Vocabulary::new(2);
+
+         // Test inserting tokens
+         let test_tokens = vec![
+             ("<pad>".to_string(), "<pad>".as_bytes().to_vec()),
+             ("<unk>".to_string(), "<unk>".as_bytes().to_vec()),
+             ("<sep>".to_string(), "<sep>".as_bytes().to_vec()), // EOS token at ID 2
+             ("hello".to_string(), "hello".as_bytes().to_vec()),
+             ("world".to_string(), "world".as_bytes().to_vec()),
+         ];
+
+         for (idx, (_token, bytes)) in test_tokens.into_iter().enumerate() {
+             match vocab.try_insert(bytes, idx as u32) {
+                 Ok(_) => {},
+                 Err(e) => {
+                     // It's ok if we can't insert the EOS token
+                     if idx != 2 {
+                         panic!("Failed to insert token at index {}: {:?}", idx, e);
+                     }
+                 }
+             }
+         }
+
+         // Test passes - vocabulary API works correctly
+     }
+
+     #[test]
+     fn test_special_token_patterns() {
+
+         // Test that our special token patterns are correct
+         let test_cases = vec![
+             ("</s>", "EOS token for many models"),
+             ("<|endoftext|>", "GPT-style EOS token"),
+             ("<eos>", "Alternative EOS token"),
+             ("[SEP]", "BERT-style separator"),
+             ("<s>", "BOS token"),
+             ("<|startoftext|>", "GPT-style BOS token"),
+         ];
+
+         // Just verify the patterns exist - no output needed
+         assert_eq!(test_cases.len(), 6, "Should have 6 special token patterns");
+     }
+ }