red-candle 1.8.0-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,330 @@
1
+ use magnus::{function, method, prelude::*, Error, RModule, Ruby};
2
+ use candle_transformers::models::llava::{
3
+ config::{LLaVAConfig, HFLLaVAConfig, HFGenerationConfig, HFPreProcessorConfig},
4
+ LLaVA,
5
+ };
6
+ use candle_transformers::models::llama::Cache;
7
+ use candle_core::{Device as CoreDevice, Tensor, DType};
8
+ use candle_nn::VarBuilder;
9
+ use hf_hub::{api::sync::Api, Repo, RepoType};
10
+ use tokenizers::Tokenizer;
11
+ use crate::ruby::{Device, Result};
12
+ use crate::tokenizer::TokenizerWrapper;
13
+
14
/// Per-channel mean subtracted from each pixel (after scaling to `[0, 1]`)
/// when `load_image` normalizes an RGB image; indexed by channel 0, 1, 2.
const CLIP_MEAN: [f32; 3] = [0.481_454_66, 0.457_827_5, 0.408_210_73];

/// Per-channel standard deviation each mean-centered pixel is divided by in
/// `load_image`; indexed by channel 0, 1, 2.
const CLIP_STD: [f32; 3] = [0.268_629_54, 0.261_302_58, 0.275_777_11];
16
+
17
+ /// Vision-Language Model wrapping LLaVA for image understanding.
18
+ /// Uses CLIP vision encoder + MM projector + Llama LLM.
19
+ ///
20
+ /// Note: LLaVA contains trait objects (dyn Module) that are !Send,
21
+ /// so we wrap it in an UnsafeCell. This is safe because Ruby's GVL
22
+ /// ensures single-threaded access to the model.
23
+ struct UnsafeSendSync<T>(T);
24
+ unsafe impl<T> Send for UnsafeSendSync<T> {}
25
+ unsafe impl<T> Sync for UnsafeSendSync<T> {}
26
+
27
+ #[magnus::wrap(class = "Candle::VLM", free_immediately, size)]
28
+ pub struct VLM {
29
+ model: std::cell::RefCell<UnsafeSendSync<LLaVA>>,
30
+ tokenizer: TokenizerWrapper,
31
+ cache: std::cell::RefCell<UnsafeSendSync<Cache>>,
32
+ config: LLaVAConfig,
33
+ device: CoreDevice,
34
+ model_id: String,
35
+ image_size: usize,
36
+ eos_token_id: u32,
37
+ }
38
+
39
+ impl VLM {
40
+ pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
41
+ let device = device.unwrap_or(Device::best()).as_device()?;
42
+ Self::load_model(model_id, device)
43
+ }
44
+
45
+ fn load_model(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
46
+ let ruby = Ruby::get().unwrap();
47
+ let runtime_error = ruby.exception_runtime_error();
48
+
49
+ let result = (|| -> std::result::Result<_, Box<dyn std::error::Error + Send + Sync>> {
50
+ let api = Api::new()?;
51
+ let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
52
+
53
+ // Download config files
54
+ let config_filename = repo.get("config.json")?;
55
+ let gen_config_filename = repo.get("generation_config.json")?;
56
+ let preproc_config_filename = repo.get("preprocessor_config.json")?;
57
+ let tokenizer_filename = repo.get("tokenizer.json")?;
58
+
59
+ // Read configs
60
+ let config_str = std::fs::read_to_string(&config_filename)?;
61
+ let gen_config_str = std::fs::read_to_string(&gen_config_filename)?;
62
+ let preproc_config_str = std::fs::read_to_string(&preproc_config_filename)?;
63
+
64
+ // Patch config: some models have null pad_token_id in text_config
65
+ // but candle's HFLLaVATextConfig requires usize. Fix by defaulting to 0.
66
+ let mut config_json: serde_json::Value = serde_json::from_str(&config_str)?;
67
+ let top_pad_id = config_json.get("pad_token_id")
68
+ .and_then(|v| v.as_u64())
69
+ .unwrap_or(0);
70
+ // Patch missing image_grid_pinpoints for LLaVA 1.5
71
+ if config_json.get("image_grid_pinpoints").map_or(true, |v| v.is_null()) {
72
+ config_json["image_grid_pinpoints"] = serde_json::json!([[336, 672], [672, 336], [672, 672]]);
73
+ }
74
+ if let Some(text_config) = config_json.get_mut("text_config") {
75
+ if text_config.get("pad_token_id").map_or(true, |v| v.is_null()) {
76
+ text_config["pad_token_id"] = serde_json::Value::Number(top_pad_id.into());
77
+ }
78
+ }
79
+ let patched_config_str = serde_json::to_string(&config_json)?;
80
+ let hf_config: HFLLaVAConfig = serde_json::from_str(&patched_config_str)?;
81
+ let gen_config: HFGenerationConfig = serde_json::from_str(&gen_config_str)?;
82
+ let preproc_config: HFPreProcessorConfig = serde_json::from_str(&preproc_config_str)?;
83
+
84
+ let image_size = hf_config.vision_config.image_size;
85
+ let eos_token_id = gen_config.eos_token_id as u32;
86
+
87
+ let clip_vision_config = hf_config.to_clip_vision_config();
88
+ let config = hf_config.to_llava_config(&gen_config, &preproc_config);
89
+
90
+ // Load tokenizer
91
+ let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
92
+
93
+ // Download weight files (sharded)
94
+ let weight_files = Self::download_weights(&repo)?;
95
+
96
+ // Load model weights
97
+ let vb = unsafe {
98
+ VarBuilder::from_mmaped_safetensors(&weight_files, DType::F32, &device)?
99
+ };
100
+
101
+ // Load LLaVA model with CLIP vision config
102
+ let model = LLaVA::load(vb, &config, Some(clip_vision_config))?;
103
+
104
+ // Create KV cache for the Llama LLM
105
+ let llama_config = config.to_llama_config();
106
+ let cache = Cache::new(true, DType::F32, &llama_config, &device)?;
107
+
108
+ Ok((model, TokenizerWrapper::new(tokenizer), cache, config, image_size, eos_token_id))
109
+ })();
110
+
111
+ match result {
112
+ Ok((model, tokenizer, cache, config, image_size, eos_token_id)) => {
113
+ Ok(Self {
114
+ model: std::cell::RefCell::new(UnsafeSendSync(model)),
115
+ tokenizer,
116
+ cache: std::cell::RefCell::new(UnsafeSendSync(cache)),
117
+ config,
118
+ device,
119
+ model_id,
120
+ image_size,
121
+ eos_token_id,
122
+ })
123
+ }
124
+ Err(e) => Err(Error::new(runtime_error, format!("Failed to load VLM: {}", e))),
125
+ }
126
+ }
127
+
128
+ fn download_weights(
129
+ repo: &hf_hub::api::sync::ApiRepo,
130
+ ) -> std::result::Result<Vec<std::path::PathBuf>, Box<dyn std::error::Error + Send + Sync>> {
131
+ // Try single file first
132
+ if let Ok(path) = repo.get("model.safetensors") {
133
+ return Ok(vec![path]);
134
+ }
135
+
136
+ // Try to get the index file for sharded weights
137
+ let index_path = repo.get("model.safetensors.index.json")?;
138
+ let index_str = std::fs::read_to_string(&index_path)?;
139
+ let index: serde_json::Value = serde_json::from_str(&index_str)?;
140
+
141
+ let weight_map = index["weight_map"].as_object()
142
+ .ok_or("Missing weight_map in index")?;
143
+
144
+ let mut filenames: Vec<String> = weight_map.values()
145
+ .filter_map(|v| v.as_str().map(String::from))
146
+ .collect();
147
+ filenames.sort();
148
+ filenames.dedup();
149
+
150
+ let mut paths = Vec::new();
151
+ for filename in &filenames {
152
+ let path = repo.get(filename)?;
153
+ paths.push(path);
154
+ }
155
+
156
+ Ok(paths)
157
+ }
158
+
159
+ /// Load and preprocess an image from a file path into a CLIP-ready tensor
160
+ fn load_image(&self, image_path: &str) -> std::result::Result<Tensor, Error> {
161
+ let ruby = Ruby::get().unwrap();
162
+ let runtime_error = ruby.exception_runtime_error();
163
+
164
+ let img = image::open(image_path)
165
+ .map_err(|e| Error::new(runtime_error, format!("Failed to open image: {}", e)))?;
166
+
167
+ // Resize to expected size
168
+ let img = img.resize_exact(
169
+ self.image_size as u32,
170
+ self.image_size as u32,
171
+ image::imageops::FilterType::Triangle,
172
+ );
173
+
174
+ let img = img.to_rgb8();
175
+ let (width, height) = img.dimensions();
176
+ let h = height as usize;
177
+ let w = width as usize;
178
+
179
+ // Convert to CHW format with CLIP normalization
180
+ let mut chw = vec![0f32; 3 * h * w];
181
+ for y in 0..h {
182
+ for x in 0..w {
183
+ let p = img.get_pixel(x as u32, y as u32);
184
+ chw[0 * h * w + y * w + x] = (p[0] as f32 / 255.0 - CLIP_MEAN[0]) / CLIP_STD[0];
185
+ chw[1 * h * w + y * w + x] = (p[1] as f32 / 255.0 - CLIP_MEAN[1]) / CLIP_STD[1];
186
+ chw[2 * h * w + y * w + x] = (p[2] as f32 / 255.0 - CLIP_MEAN[2]) / CLIP_STD[2];
187
+ }
188
+ }
189
+
190
+ Tensor::from_vec(chw, (1, 3, h, w), &self.device)
191
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create image tensor: {}", e)))
192
+ }
193
+
194
+ /// Describe an image
195
+ pub fn describe(&self, image_path: String, max_length: Option<usize>) -> std::result::Result<String, Error> {
196
+ self.ask(image_path, "Describe this image in detail.".to_string(), max_length)
197
+ }
198
+
199
+ /// Ask a question about an image
200
+ pub fn ask(&self, image_path: String, question: String, max_length: Option<usize>) -> std::result::Result<String, Error> {
201
+ let ruby = Ruby::get().unwrap();
202
+ let runtime_error = ruby.exception_runtime_error();
203
+ let max_length = max_length.unwrap_or(256);
204
+
205
+ // Load and preprocess image
206
+ let image_tensor = self.load_image(&image_path)?;
207
+
208
+ // Build prompt with image token placeholder
209
+ // LLaVA 1.5 HF format: USER: <image>\n{question}\nASSISTANT:
210
+ let prompt = format!("USER: <image>\n{}\nASSISTANT:", question);
211
+
212
+ // Tokenize
213
+ let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
214
+ .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
215
+ let input_ids: Vec<u32> = encoding.get_ids().to_vec();
216
+
217
+ // LLaVA expects I64 input IDs
218
+ let input_ids_i64: Vec<i64> = input_ids.iter().map(|&id| id as i64).collect();
219
+ let input_tensor = Tensor::new(&input_ids_i64[..], &self.device)
220
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create input tensor: {}", e)))?
221
+ .unsqueeze(0)
222
+ .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
223
+
224
+ let mut model_ref = self.model.borrow_mut();
225
+ let model = &mut model_ref.0;
226
+ let mut cache_ref = self.cache.borrow_mut();
227
+ let cache = &mut cache_ref.0;
228
+
229
+ // Prepare multimodal input: merge image features with text embeddings
230
+ let image_size = (self.image_size as u32, self.image_size as u32);
231
+ let input_embeds = model.prepare_inputs_labels_for_multimodal(
232
+ &input_tensor,
233
+ &[image_tensor],
234
+ &[image_size],
235
+ ).map_err(|e| Error::new(runtime_error, format!("Failed to prepare multimodal input: {}", e)))?;
236
+
237
+ // Reset KV cache for fresh generation
238
+ let llama_config = self.config.to_llama_config();
239
+ *cache = Cache::new(true, DType::F32, &llama_config, &self.device)
240
+ .map_err(|e| Error::new(runtime_error, format!("Failed to reset cache: {}", e)))?;
241
+
242
+ // Generate tokens autoregressively
243
+ let mut generated_tokens: Vec<u32> = Vec::new();
244
+ let mut current_embeds = input_embeds;
245
+ let mut pos = 0usize;
246
+
247
+ for _i in 0..max_length {
248
+ let logits = model.forward(&current_embeds, pos, cache)
249
+ .map_err(|e| Error::new(runtime_error, format!("Forward pass failed at pos {}: {}", pos, e)))?;
250
+
251
+ // Advance position by the number of tokens we just processed
252
+ let step_len = current_embeds.dim(1)
253
+ .map_err(|e| Error::new(runtime_error, format!("Failed to get step len: {}", e)))?;
254
+ pos += step_len;
255
+
256
+ // Get logits for last position
257
+ let logits = logits.flatten_all()
258
+ .map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?;
259
+
260
+ // Handle multi-dim logits (take last token if needed)
261
+ let vocab_size = self.config.vocab_size;
262
+ let logits = if logits.elem_count() > vocab_size {
263
+ let n_tokens = logits.elem_count() / vocab_size;
264
+ logits.reshape((n_tokens, vocab_size))
265
+ .map_err(|e| Error::new(runtime_error, format!("Failed to reshape logits: {}", e)))?
266
+ .narrow(0, n_tokens - 1, 1)
267
+ .map_err(|e| Error::new(runtime_error, format!("Failed to narrow logits: {}", e)))?
268
+ .squeeze(0)
269
+ .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze logits: {}", e)))?
270
+ } else {
271
+ logits
272
+ };
273
+
274
+ let logits = logits.to_dtype(DType::F32)
275
+ .map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
276
+
277
+ // Greedy decoding
278
+ let next_token = logits.argmax(0)
279
+ .map_err(|e| Error::new(runtime_error, format!("Argmax failed: {}", e)))?
280
+ .to_scalar::<u32>()
281
+ .map_err(|e| Error::new(runtime_error, format!("Failed to get token: {}", e)))?;
282
+
283
+ // Check for EOS
284
+ if next_token == self.eos_token_id {
285
+ break;
286
+ }
287
+
288
+ generated_tokens.push(next_token);
289
+
290
+ // For subsequent tokens, embed directly through Llama's embedding layer
291
+ let next_input = Tensor::new(&[next_token as i64], &self.device)
292
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create next input: {}", e)))?
293
+ .unsqueeze(0)
294
+ .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze next: {}", e)))?;
295
+
296
+ current_embeds = model.llama.embed(&next_input)
297
+ .map_err(|e| Error::new(runtime_error, format!("Failed to embed next token: {}", e)))?;
298
+ }
299
+
300
+ // Decode generated tokens
301
+ let text = self.tokenizer.inner().decode(&generated_tokens, true)
302
+ .map_err(|e| Error::new(runtime_error, format!("Decoding failed: {}", e)))?;
303
+
304
+ Ok(text.trim().to_string())
305
+ }
306
+
307
+ pub fn model_id(&self) -> String {
308
+ self.model_id.clone()
309
+ }
310
+
311
+ pub fn device(&self) -> Device {
312
+ Device::from_device(&self.device)
313
+ }
314
+
315
+ pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
316
+ Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
317
+ }
318
+ }
319
+
320
+ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
321
+ let ruby = Ruby::get().unwrap();
322
+ let c_vlm = rb_candle.define_class("VLM", ruby.class_object())?;
323
+ c_vlm.define_singleton_method("_create", function!(VLM::new, 2))?;
324
+ c_vlm.define_method("_describe", method!(VLM::describe, 2))?;
325
+ c_vlm.define_method("_ask", method!(VLM::ask, 3))?;
326
+ c_vlm.define_method("model_id", method!(VLM::model_id, 0))?;
327
+ c_vlm.define_method("device", method!(VLM::device, 0))?;
328
+ c_vlm.define_method("tokenizer", method!(VLM::tokenizer, 0))?;
329
+ Ok(())
330
+ }
@@ -0,0 +1,130 @@
1
+ #[cfg(test)]
2
+ mod integration_tests {
3
+ use super::super::*;
4
+ use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
5
+ use std::sync::Arc;
6
+
7
+ #[tokio::test]
8
+ async fn test_schema_processor_with_vocabulary() {
9
+ // This test requires a tokenizer to create a vocabulary
10
+ let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
11
+
12
+ if let Ok(tokenizer) = tokenizer_result {
13
+ let wrapper = TokenizerWrapper::new(tokenizer);
14
+
15
+ // Create vocabulary from tokenizer
16
+ let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
17
+ .expect("Should create vocabulary");
18
+
19
+ // Create schema processor
20
+ let processor = SchemaProcessor::new();
21
+
22
+ // Test with a simple JSON schema
23
+ let schema = r#"{
24
+ "type": "object",
25
+ "properties": {
26
+ "name": {"type": "string"},
27
+ "age": {"type": "integer"}
28
+ },
29
+ "required": ["name", "age"]
30
+ }"#;
31
+
32
+ // Process schema into Index
33
+ let index_result = processor.process_schema(schema, &vocabulary);
34
+ assert!(index_result.is_ok(), "Should process schema successfully");
35
+
36
+ // Test caching - second call should use cache
37
+ let index2_result = processor.process_schema(schema, &vocabulary);
38
+ assert!(index2_result.is_ok(), "Should retrieve from cache");
39
+
40
+ // Both should be the same Arc
41
+ let index1 = index_result.unwrap();
42
+ let index2 = index2_result.unwrap();
43
+ assert!(Arc::ptr_eq(&index1, &index2), "Should return cached Index");
44
+
45
+ // Check cache stats
46
+ let (size, _) = processor.cache_stats();
47
+ assert_eq!(size, 1, "Cache should have one entry");
48
+ } else {
49
+ eprintln!("Skipping integration test - couldn't load tokenizer");
50
+ }
51
+ }
52
+
53
+ #[tokio::test]
54
+ async fn test_regex_processing() {
55
+ let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
56
+
57
+ if let Ok(tokenizer) = tokenizer_result {
58
+ let wrapper = TokenizerWrapper::new(tokenizer);
59
+ let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
60
+ .expect("Should create vocabulary");
61
+
62
+ let processor = SchemaProcessor::new();
63
+
64
+ // Test with a simple regex pattern
65
+ let email_regex = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";
66
+
67
+ let index_result = processor.process_regex(email_regex, &vocabulary);
68
+ assert!(index_result.is_ok(), "Should process regex successfully");
69
+
70
+ // Test different regex
71
+ let phone_regex = r"\d{3}-\d{3}-\d{4}";
72
+ let phone_index_result = processor.process_regex(phone_regex, &vocabulary);
73
+ assert!(phone_index_result.is_ok(), "Should process phone regex");
74
+
75
+ // Cache should have both
76
+ let (size, _) = processor.cache_stats();
77
+ assert_eq!(size, 2, "Cache should have two entries");
78
+
79
+ // Clear cache
80
+ processor.clear_cache();
81
+ let (size, _) = processor.cache_stats();
82
+ assert_eq!(size, 0, "Cache should be empty after clear");
83
+ }
84
+ }
85
+
86
/// Builds several JSON schema shapes and checks that each serializes to a
/// JSON string with the expected content.
///
/// No schema is actually run through the processor here: that would need a
/// vocabulary, which this test does not construct.
#[test]
fn test_various_json_schemas() {
    let _processor = SchemaProcessor::new();

    // Array of strings
    let array_schema = serde_json::json!({
        "type": "array",
        "items": {"type": "string"}
    });
    let array_json = serde_json::to_string(&array_schema).unwrap();
    assert!(!array_json.is_empty(), "Should serialize array schema");

    // Object nested inside an object
    let nested_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "id": {"type": "integer"},
                    "email": {"type": "string", "format": "email"}
                }
            }
        }
    });
    let nested_json = serde_json::to_string(&nested_schema).unwrap();
    assert!(nested_json.contains("properties"), "Should have nested properties");

    // String restricted to an enum
    let enum_schema = serde_json::json!({
        "type": "string",
        "enum": ["red", "green", "blue"]
    });
    let enum_json = serde_json::to_string(&enum_schema).unwrap();
    assert!(enum_json.contains("enum"), "Should have enum values");
}
130
+ }
@@ -0,0 +1,31 @@
1
+ /// Structured generation support using Outlines
2
+ ///
3
+ /// This module provides functionality to constrain language model generation
4
+ /// to follow specific patterns, such as JSON schemas or regular expressions.
5
+
6
+ pub mod vocabulary_adapter;
7
+ pub mod schema_processor;
8
+
9
+ pub use vocabulary_adapter::VocabularyAdapter;
10
+ pub use schema_processor::SchemaProcessor;
11
+
12
+ // Re-export commonly used types from outlines-core
13
+ pub use outlines_core::prelude::Index;
14
+ pub use outlines_core::vocabulary::Vocabulary;
15
+
16
+ #[cfg(test)]
17
+ mod vocabulary_adapter_simple_test;
18
+
19
+ #[cfg(test)]
20
+ mod integration_test;
21
+
22
#[cfg(test)]
mod tests {
    use super::VocabularyAdapter;

    /// Compile-time smoke test: fails to build if the `VocabularyAdapter`
    /// re-export disappears from this module.
    #[test]
    fn test_module_imports() {
        let _ = VocabularyAdapter;
    }
}