red-candle 1.0.1 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +57 -4
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/build.rs +6 -5
- data/ext/candle/extconf.rb +5 -6
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +123 -0
- data/ext/candle/src/llm/generation_config.rs +5 -0
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +285 -0
- data/ext/candle/src/llm/quantized_gguf.rs +155 -4
- data/ext/candle/src/llm/qwen.rs +229 -0
- data/ext/candle/src/llm/text_generation.rs +66 -2
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +42 -4
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/build_info.rb +2 -2
- data/lib/candle/llm.rb +109 -3
- data/lib/candle/version.rb +1 -1
- data/lib/red-candle.rb +1 -0
- metadata +15 -4
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,15 +1,18 @@
|
|
1
1
|
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
|
2
2
|
use std::cell::RefCell;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
|
-
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
|
5
|
+
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, phi::Phi as RustPhi, QuantizedGGUF as RustQuantizedGGUF};
|
5
6
|
use crate::ruby::{Result, Device};
|
7
|
+
use crate::ruby::structured::StructuredConstraint;
|
6
8
|
|
7
9
|
// Use an enum to handle different model types instead of trait objects
|
8
|
-
#[derive(Debug)]
|
9
10
|
enum ModelType {
|
10
11
|
Mistral(RustMistral),
|
11
12
|
Llama(RustLlama),
|
12
13
|
Gemma(RustGemma),
|
14
|
+
Qwen(RustQwen),
|
15
|
+
Phi(RustPhi),
|
13
16
|
QuantizedGGUF(RustQuantizedGGUF),
|
14
17
|
}
|
15
18
|
|
@@ -19,6 +22,8 @@ impl ModelType {
|
|
19
22
|
ModelType::Mistral(m) => m.generate(prompt, config),
|
20
23
|
ModelType::Llama(m) => m.generate(prompt, config),
|
21
24
|
ModelType::Gemma(m) => m.generate(prompt, config),
|
25
|
+
ModelType::Qwen(m) => m.generate(prompt, config),
|
26
|
+
ModelType::Phi(m) => m.generate(prompt, config),
|
22
27
|
ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
|
23
28
|
}
|
24
29
|
}
|
@@ -33,6 +38,8 @@ impl ModelType {
|
|
33
38
|
ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
|
34
39
|
ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
|
35
40
|
ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
|
41
|
+
ModelType::Qwen(m) => m.generate_stream(prompt, config, callback),
|
42
|
+
ModelType::Phi(m) => m.generate_stream(prompt, config, callback),
|
36
43
|
ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
|
37
44
|
}
|
38
45
|
}
|
@@ -42,6 +49,8 @@ impl ModelType {
|
|
42
49
|
ModelType::Mistral(m) => m.clear_cache(),
|
43
50
|
ModelType::Llama(m) => m.clear_cache(),
|
44
51
|
ModelType::Gemma(m) => m.clear_cache(),
|
52
|
+
ModelType::Qwen(m) => m.clear_cache(),
|
53
|
+
ModelType::Phi(m) => m.clear_cache(),
|
45
54
|
ModelType::QuantizedGGUF(m) => m.clear_cache(),
|
46
55
|
}
|
47
56
|
}
|
@@ -67,6 +76,8 @@ impl ModelType {
|
|
67
76
|
},
|
68
77
|
ModelType::Llama(m) => m.apply_chat_template(messages),
|
69
78
|
ModelType::Gemma(m) => m.apply_chat_template(messages),
|
79
|
+
ModelType::Qwen(m) => m.apply_chat_template(messages),
|
80
|
+
ModelType::Phi(m) => m.apply_chat_template(messages),
|
70
81
|
ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
|
71
82
|
}
|
72
83
|
}
|
@@ -146,6 +157,13 @@ impl GenerationConfig {
|
|
146
157
|
}
|
147
158
|
}
|
148
159
|
|
160
|
+
// Handle constraint parameter
|
161
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
|
162
|
+
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
163
|
+
config.constraint = Some(Arc::clone(&constraint.index));
|
164
|
+
}
|
165
|
+
}
|
166
|
+
|
149
167
|
Ok(Self { inner: config })
|
150
168
|
}
|
151
169
|
|
@@ -191,9 +209,14 @@ impl GenerationConfig {
|
|
191
209
|
pub fn debug_tokens(&self) -> bool {
|
192
210
|
self.inner.debug_tokens
|
193
211
|
}
|
212
|
+
pub fn constraint(&self) -> Option<StructuredConstraint> {
|
213
|
+
self.inner.constraint.as_ref().map(|c| StructuredConstraint {
|
214
|
+
index: Arc::clone(c),
|
215
|
+
})
|
216
|
+
}
|
194
217
|
}
|
195
218
|
|
196
|
-
#[derive(Clone
|
219
|
+
#[derive(Clone)]
|
197
220
|
#[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
|
198
221
|
pub struct LLM {
|
199
222
|
model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
|
@@ -251,10 +274,22 @@ impl LLM {
|
|
251
274
|
})
|
252
275
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
253
276
|
ModelType::Gemma(gemma)
|
277
|
+
} else if model_lower.contains("qwen") {
|
278
|
+
let qwen = rt.block_on(async {
|
279
|
+
RustQwen::from_pretrained(&model_id, candle_device).await
|
280
|
+
})
|
281
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
282
|
+
ModelType::Qwen(qwen)
|
283
|
+
} else if model_lower.contains("phi") {
|
284
|
+
let phi = rt.block_on(async {
|
285
|
+
RustPhi::from_pretrained(&model_id, candle_device).await
|
286
|
+
})
|
287
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
288
|
+
ModelType::Phi(phi)
|
254
289
|
} else {
|
255
290
|
return Err(Error::new(
|
256
291
|
magnus::exception::runtime_error(),
|
257
|
-
format!("Unsupported model type: {}. Currently Mistral, Llama, and
|
292
|
+
format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id),
|
258
293
|
));
|
259
294
|
}
|
260
295
|
};
|
@@ -332,6 +367,8 @@ impl LLM {
|
|
332
367
|
ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
333
368
|
ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
334
369
|
ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
370
|
+
ModelType::Qwen(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
371
|
+
ModelType::Phi(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
335
372
|
ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
336
373
|
}
|
337
374
|
}
|
@@ -423,6 +460,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
423
460
|
rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
|
424
461
|
rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
|
425
462
|
rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
|
463
|
+
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
426
464
|
|
427
465
|
let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
|
428
466
|
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
data/ext/candle/src/ruby/structured.rs
CHANGED
@@ -0,0 +1,47 @@
|
|
1
|
+
use magnus::{Error, Module, RModule, function, Object};
|
2
|
+
use std::sync::Arc;
|
3
|
+
|
4
|
+
use crate::structured::{SchemaProcessor, VocabularyAdapter, Index};
|
5
|
+
use crate::ruby::{Result, tokenizer::Tokenizer};
|
6
|
+
|
7
|
+
/// Ruby wrapper for structured generation constraints
|
8
|
+
#[derive(Clone, Debug)]
|
9
|
+
#[magnus::wrap(class = "Candle::StructuredConstraint", mark, free_immediately)]
|
10
|
+
pub struct StructuredConstraint {
|
11
|
+
pub(crate) index: Arc<Index>,
|
12
|
+
}
|
13
|
+
|
14
|
+
impl StructuredConstraint {
|
15
|
+
/// Create a constraint from a JSON schema
|
16
|
+
pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
|
17
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
18
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
19
|
+
|
20
|
+
let processor = SchemaProcessor::new();
|
21
|
+
let index = processor.process_schema(&schema, &vocabulary)
|
22
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process schema: {}", e)))?;
|
23
|
+
|
24
|
+
Ok(Self { index })
|
25
|
+
}
|
26
|
+
|
27
|
+
/// Create a constraint from a regex pattern
|
28
|
+
pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
|
29
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
30
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
31
|
+
|
32
|
+
let processor = SchemaProcessor::new();
|
33
|
+
let index = processor.process_regex(&pattern, &vocabulary)
|
34
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to process regex: {}", e)))?;
|
35
|
+
|
36
|
+
Ok(Self { index })
|
37
|
+
}
|
38
|
+
}
|
39
|
+
|
40
|
+
pub fn init_structured(rb_candle: RModule) -> Result<()> {
|
41
|
+
let class = rb_candle.define_class("StructuredConstraint", magnus::class::object())?;
|
42
|
+
|
43
|
+
class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
|
44
|
+
class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
|
45
|
+
|
46
|
+
Ok(())
|
47
|
+
}
|
data/ext/candle/src/structured/integration_test.rs
CHANGED
@@ -0,0 +1,130 @@
|
|
1
|
+
#[cfg(test)]
mod integration_tests {
    //! Checks `SchemaProcessor` against a vocabulary built from a real
    //! tokenizer. The tokenizer-backed tests print a notice and return early
    //! (passing) when the tokenizer cannot be loaded.

    use super::super::*;
    use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
    use std::sync::Arc;

    /// Tokenizer loaded via `TokenizerLoader::from_hf_hub` for these tests.
    const TOKENIZER_ID: &str = "bert-base-uncased";

    #[tokio::test]
    async fn test_schema_processor_with_vocabulary() {
        // A vocabulary can only be built from a real tokenizer.
        let tokenizer = match TokenizerLoader::from_hf_hub(TOKENIZER_ID, None).await {
            Ok(tokenizer) => tokenizer,
            Err(_) => {
                eprintln!("Skipping integration test - couldn't load tokenizer");
                return;
            }
        };

        let wrapper = TokenizerWrapper::new(tokenizer);
        let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
            .expect("Should create vocabulary");
        let processor = SchemaProcessor::new();

        // Object with two required fields of different primitive types.
        let schema = r#"{
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        }"#;

        // First call compiles; the second must be served from the cache.
        let first = processor.process_schema(schema, &vocabulary);
        assert!(first.is_ok(), "Should process schema successfully");

        let second = processor.process_schema(schema, &vocabulary);
        assert!(second.is_ok(), "Should retrieve from cache");

        // A cache hit hands back the very same allocation.
        let (first, second) = (first.unwrap(), second.unwrap());
        assert!(Arc::ptr_eq(&first, &second), "Should return cached Index");

        let (entries, _capacity) = processor.cache_stats();
        assert_eq!(entries, 1, "Cache should have one entry");
    }

    #[tokio::test]
    async fn test_regex_processing() {
        let tokenizer = match TokenizerLoader::from_hf_hub(TOKENIZER_ID, None).await {
            Ok(tokenizer) => tokenizer,
            // No skip notice here: the test simply does nothing without a tokenizer.
            Err(_) => return,
        };

        let wrapper = TokenizerWrapper::new(tokenizer);
        let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
            .expect("Should create vocabulary");
        let processor = SchemaProcessor::new();

        // Two distinct patterns should produce two cache entries.
        let email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";
        let email_index = processor.process_regex(email_pattern, &vocabulary);
        assert!(email_index.is_ok(), "Should process regex successfully");

        let phone_pattern = r"\d{3}-\d{3}-\d{4}";
        let phone_index = processor.process_regex(phone_pattern, &vocabulary);
        assert!(phone_index.is_ok(), "Should process phone regex");

        let (entries, _capacity) = processor.cache_stats();
        assert_eq!(entries, 2, "Cache should have two entries");

        // Clearing drops every entry.
        processor.clear_cache();
        let (entries, _capacity) = processor.cache_stats();
        assert_eq!(entries, 0, "Cache should be empty after clear");
    }

    #[test]
    fn test_various_json_schemas() {
        let _processor = SchemaProcessor::new();

        // Compiling these would need a vocabulary (and `schema_to_regex` is
        // private), so only check that each schema serializes as expected.
        let serialize = |value: &serde_json::Value| serde_json::to_string(value).unwrap();

        // Array of strings
        let array_schema = serde_json::json!({
            "type": "array",
            "items": {"type": "string"}
        });
        let serialized = serialize(&array_schema);
        assert!(!serialized.is_empty(), "Should serialize array schema");

        // Object nested inside an object
        let nested_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "integer"},
                        "email": {"type": "string", "format": "email"}
                    }
                }
            }
        });
        let serialized = serialize(&nested_schema);
        assert!(serialized.contains("properties"), "Should have nested properties");

        // String restricted to an enumeration
        let enum_schema = serde_json::json!({
            "type": "string",
            "enum": ["red", "green", "blue"]
        });
        let serialized = serialize(&enum_schema);
        assert!(serialized.contains("enum"), "Should have enum values");
    }
}
|
data/ext/candle/src/structured/mod.rs
CHANGED
@@ -0,0 +1,31 @@
|
|
1
|
+
/// Structured generation support using Outlines
|
2
|
+
///
|
3
|
+
/// This module provides functionality to constrain language model generation
|
4
|
+
/// to follow specific patterns, such as JSON schemas or regular expressions.
|
5
|
+
|
6
|
+
pub mod vocabulary_adapter;
|
7
|
+
pub mod schema_processor;
|
8
|
+
|
9
|
+
pub use vocabulary_adapter::VocabularyAdapter;
|
10
|
+
pub use schema_processor::SchemaProcessor;
|
11
|
+
|
12
|
+
// Re-export commonly used types from outlines-core
|
13
|
+
pub use outlines_core::prelude::Index;
|
14
|
+
pub use outlines_core::vocabulary::Vocabulary;
|
15
|
+
|
16
|
+
// Test-only submodules (files alongside this one).
#[cfg(test)]
mod vocabulary_adapter_simple_test;

#[cfg(test)]
mod integration_test;

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the public re-exports of this module resolve.
    #[test]
    fn test_module_imports() {
        // Naming `VocabularyAdapter` as a value compiles only because the
        // re-export is visible (it is presumably a unit struct).
        let _ = VocabularyAdapter;
    }
}
|
data/ext/candle/src/structured/schema_processor.rs
CHANGED
@@ -0,0 +1,215 @@
|
|
1
|
+
use std::collections::HashMap;
|
2
|
+
use std::sync::{Arc, Mutex};
|
3
|
+
use candle_core::Result as CandleResult;
|
4
|
+
use candle_core::Error as CandleError;
|
5
|
+
use outlines_core::prelude::Index;
|
6
|
+
use outlines_core::vocabulary::Vocabulary;
|
7
|
+
use serde_json::Value as JsonValue;
|
8
|
+
use outlines_core::json_schema;
|
9
|
+
|
10
|
+
/// Processes JSON schemas into compiled Index objects for structured generation
|
11
|
+
pub struct SchemaProcessor {
|
12
|
+
/// Cache of compiled Index objects keyed by schema hash
|
13
|
+
cache: Arc<Mutex<HashMap<u64, Arc<Index>>>>,
|
14
|
+
}
|
15
|
+
|
16
|
+
impl SchemaProcessor {
|
17
|
+
/// Create a new schema processor with an empty cache
|
18
|
+
pub fn new() -> Self {
|
19
|
+
Self {
|
20
|
+
cache: Arc::new(Mutex::new(HashMap::new())),
|
21
|
+
}
|
22
|
+
}
|
23
|
+
|
24
|
+
/// Process a JSON schema into a compiled Index
|
25
|
+
///
|
26
|
+
/// # Arguments
|
27
|
+
/// * `schema` - JSON schema as a string
|
28
|
+
/// * `vocabulary` - The tokenizer's vocabulary
|
29
|
+
///
|
30
|
+
/// # Returns
|
31
|
+
/// A compiled Index ready for constrained generation
|
32
|
+
pub fn process_schema(
|
33
|
+
&self,
|
34
|
+
schema: &str,
|
35
|
+
vocabulary: &Vocabulary,
|
36
|
+
) -> CandleResult<Arc<Index>> {
|
37
|
+
// Calculate hash of the schema for caching
|
38
|
+
let schema_hash = self.calculate_hash(schema);
|
39
|
+
|
40
|
+
// Check cache first
|
41
|
+
if let Ok(cache) = self.cache.lock() {
|
42
|
+
if let Some(cached_index) = cache.get(&schema_hash) {
|
43
|
+
return Ok(Arc::clone(cached_index));
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
// Parse the JSON schema
|
48
|
+
let schema_value: JsonValue = serde_json::from_str(schema)
|
49
|
+
.map_err(|e| CandleError::Msg(format!("Invalid JSON schema: {}", e)))?;
|
50
|
+
|
51
|
+
// Convert schema to regex using Outlines
|
52
|
+
let regex = self.schema_to_regex(&schema_value)?;
|
53
|
+
|
54
|
+
// Compile regex into Index
|
55
|
+
let index = self.compile_regex(®ex, vocabulary)?;
|
56
|
+
let index_arc = Arc::new(index);
|
57
|
+
|
58
|
+
// Cache the compiled Index
|
59
|
+
if let Ok(mut cache) = self.cache.lock() {
|
60
|
+
cache.insert(schema_hash, Arc::clone(&index_arc));
|
61
|
+
}
|
62
|
+
|
63
|
+
Ok(index_arc)
|
64
|
+
}
|
65
|
+
|
66
|
+
/// Process a regex pattern directly into an Index
|
67
|
+
///
|
68
|
+
/// # Arguments
|
69
|
+
/// * `regex` - Regular expression pattern
|
70
|
+
/// * `vocabulary` - The tokenizer's vocabulary
|
71
|
+
///
|
72
|
+
/// # Returns
|
73
|
+
/// A compiled Index for the regex pattern
|
74
|
+
pub fn process_regex(
|
75
|
+
&self,
|
76
|
+
regex: &str,
|
77
|
+
vocabulary: &Vocabulary,
|
78
|
+
) -> CandleResult<Arc<Index>> {
|
79
|
+
// Calculate hash for caching
|
80
|
+
let regex_hash = self.calculate_hash(regex);
|
81
|
+
|
82
|
+
// Check cache
|
83
|
+
if let Ok(cache) = self.cache.lock() {
|
84
|
+
if let Some(cached_index) = cache.get(®ex_hash) {
|
85
|
+
return Ok(Arc::clone(cached_index));
|
86
|
+
}
|
87
|
+
}
|
88
|
+
|
89
|
+
// Compile the regex
|
90
|
+
let index = self.compile_regex(regex, vocabulary)?;
|
91
|
+
let index_arc = Arc::new(index);
|
92
|
+
|
93
|
+
// Cache it
|
94
|
+
if let Ok(mut cache) = self.cache.lock() {
|
95
|
+
cache.insert(regex_hash, Arc::clone(&index_arc));
|
96
|
+
}
|
97
|
+
|
98
|
+
Ok(index_arc)
|
99
|
+
}
|
100
|
+
|
101
|
+
/// Convert a JSON schema to a regex pattern
|
102
|
+
fn schema_to_regex(&self, schema: &JsonValue) -> CandleResult<String> {
|
103
|
+
// Use Outlines' built-in JSON schema to regex conversion
|
104
|
+
json_schema::regex_from_value(schema, None)
|
105
|
+
.map_err(|e| CandleError::Msg(format!("Failed to convert schema to regex: {:?}", e)))
|
106
|
+
}
|
107
|
+
|
108
|
+
/// Compile a regex pattern into an Index
|
109
|
+
fn compile_regex(&self, regex: &str, vocabulary: &Vocabulary) -> CandleResult<Index> {
|
110
|
+
// Use Outlines to build the Index from regex
|
111
|
+
Index::new(regex, vocabulary)
|
112
|
+
.map_err(|e| CandleError::Msg(format!("Failed to build index from regex: {:?}", e)))
|
113
|
+
}
|
114
|
+
|
115
|
+
/// Calculate a hash for caching
|
116
|
+
fn calculate_hash(&self, input: &str) -> u64 {
|
117
|
+
use std::collections::hash_map::DefaultHasher;
|
118
|
+
use std::hash::{Hash, Hasher};
|
119
|
+
|
120
|
+
let mut hasher = DefaultHasher::new();
|
121
|
+
input.hash(&mut hasher);
|
122
|
+
hasher.finish()
|
123
|
+
}
|
124
|
+
|
125
|
+
/// Clear the cache
|
126
|
+
pub fn clear_cache(&self) {
|
127
|
+
if let Ok(mut cache) = self.cache.lock() {
|
128
|
+
cache.clear();
|
129
|
+
}
|
130
|
+
}
|
131
|
+
|
132
|
+
/// Get cache statistics
|
133
|
+
pub fn cache_stats(&self) -> (usize, usize) {
|
134
|
+
if let Ok(cache) = self.cache.lock() {
|
135
|
+
let size = cache.len();
|
136
|
+
let capacity = cache.capacity();
|
137
|
+
(size, capacity)
|
138
|
+
} else {
|
139
|
+
(0, 0)
|
140
|
+
}
|
141
|
+
}
|
142
|
+
}
|
143
|
+
|
144
|
+
impl Default for SchemaProcessor {
|
145
|
+
fn default() -> Self {
|
146
|
+
Self::new()
|
147
|
+
}
|
148
|
+
}
|
149
|
+
|
150
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of cached entries, i.e. the first element of `cache_stats()`.
    fn cached_entries(processor: &SchemaProcessor) -> usize {
        processor.cache_stats().0
    }

    #[test]
    fn test_schema_processor_creation() {
        let processor = SchemaProcessor::new();
        assert_eq!(cached_entries(&processor), 0, "Cache should start empty");
    }

    #[test]
    fn test_cache_operations() {
        let processor = SchemaProcessor::new();

        // Nothing cached yet.
        assert_eq!(cached_entries(&processor), 0);

        // Clearing an empty cache leaves it empty.
        processor.clear_cache();
        assert_eq!(cached_entries(&processor), 0);
    }

    #[test]
    fn test_schema_to_regex_basic_types() {
        let processor = SchemaProcessor::new();
        let to_regex = |schema: serde_json::Value| processor.schema_to_regex(&schema).unwrap();

        // The exact regex text is up to outlines-core; only check that
        // something was produced.
        let string_regex = to_regex(serde_json::json!({ "type": "string" }));
        assert!(!string_regex.is_empty(), "String schema should produce a regex");

        let number_regex = to_regex(serde_json::json!({ "type": "number" }));
        assert!(!number_regex.is_empty(), "Number schema should produce a regex");

        // Booleans have only two literal spellings, so check for both.
        let bool_regex = to_regex(serde_json::json!({ "type": "boolean" }));
        assert!(
            bool_regex.contains("true") && bool_regex.contains("false"),
            "Boolean regex should contain true/false"
        );
    }

    #[test]
    fn test_schema_with_pattern() {
        let processor = SchemaProcessor::new();

        let phone_schema = serde_json::json!({
            "type": "string",
            "pattern": r"^\d{3}-\d{3}-\d{4}$"
        });

        // A user-supplied `pattern` should appear inside the generated regex.
        let generated = processor.schema_to_regex(&phone_schema).unwrap();
        assert!(generated.contains("\\d{3}"), "Should contain digit pattern");
    }
}
|