red-candle 1.8.0-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, RModule, Ruby};
|
|
2
|
+
use candle_transformers::models::llava::{
|
|
3
|
+
config::{LLaVAConfig, HFLLaVAConfig, HFGenerationConfig, HFPreProcessorConfig},
|
|
4
|
+
LLaVA,
|
|
5
|
+
};
|
|
6
|
+
use candle_transformers::models::llama::Cache;
|
|
7
|
+
use candle_core::{Device as CoreDevice, Tensor, DType};
|
|
8
|
+
use candle_nn::VarBuilder;
|
|
9
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
10
|
+
use tokenizers::Tokenizer;
|
|
11
|
+
use crate::ruby::{Device, Result};
|
|
12
|
+
use crate::tokenizer::TokenizerWrapper;
|
|
13
|
+
|
|
14
|
+
/// Per-channel (R, G, B) mean applied to `[0, 1]`-scaled pixels before they reach the
/// CLIP vision encoder (see `VLM::load_image`).
const CLIP_MEAN: [f32; 3] = [0.481_454_66, 0.457_827_5, 0.408_210_73];
/// Per-channel (R, G, B) standard deviation paired with [`CLIP_MEAN`] for pixel normalization.
const CLIP_STD: [f32; 3] = [0.268_629_54, 0.261_302_58, 0.275_777_11];
|
|
16
|
+
|
|
17
|
+
/// Newtype that force-implements `Send` for a value that is not `Send` on its own.
///
/// `LLaVA` contains trait objects (`dyn Module`) that are `!Send`, yet `VLM` — which is
/// handed to Ruby through `#[magnus::wrap]` — needs to be `Send`. In this file the wrapped
/// values are only stored inside `RefCell` fields of `VLM` and accessed from its methods.
///
/// `Sync` is intentionally *not* implemented: the only user keeps this type inside a
/// `RefCell`, which is `!Sync` no matter what, so a `Sync` impl could never be relied on
/// and would only widen the unsafe surface.
struct UnsafeSendSync<T>(T);

// SAFETY: sound only under the assumption that Ruby's GVL serializes every access to the
// wrapped value, so it is never touched from two threads at once.
// NOTE(review): this stops holding if a method releases the GVL while a borrow of the
// wrapped value is live — re-audit before adding GVL-free code paths.
unsafe impl<T> Send for UnsafeSendSync<T> {}
|
|
26
|
+
|
|
27
|
+
#[magnus::wrap(class = "Candle::VLM", free_immediately, size)]
|
|
28
|
+
pub struct VLM {
|
|
29
|
+
model: std::cell::RefCell<UnsafeSendSync<LLaVA>>,
|
|
30
|
+
tokenizer: TokenizerWrapper,
|
|
31
|
+
cache: std::cell::RefCell<UnsafeSendSync<Cache>>,
|
|
32
|
+
config: LLaVAConfig,
|
|
33
|
+
device: CoreDevice,
|
|
34
|
+
model_id: String,
|
|
35
|
+
image_size: usize,
|
|
36
|
+
eos_token_id: u32,
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
impl VLM {
|
|
40
|
+
pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
|
|
41
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
|
42
|
+
Self::load_model(model_id, device)
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
fn load_model(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
|
|
46
|
+
let ruby = Ruby::get().unwrap();
|
|
47
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
48
|
+
|
|
49
|
+
let result = (|| -> std::result::Result<_, Box<dyn std::error::Error + Send + Sync>> {
|
|
50
|
+
let api = Api::new()?;
|
|
51
|
+
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
|
52
|
+
|
|
53
|
+
// Download config files
|
|
54
|
+
let config_filename = repo.get("config.json")?;
|
|
55
|
+
let gen_config_filename = repo.get("generation_config.json")?;
|
|
56
|
+
let preproc_config_filename = repo.get("preprocessor_config.json")?;
|
|
57
|
+
let tokenizer_filename = repo.get("tokenizer.json")?;
|
|
58
|
+
|
|
59
|
+
// Read configs
|
|
60
|
+
let config_str = std::fs::read_to_string(&config_filename)?;
|
|
61
|
+
let gen_config_str = std::fs::read_to_string(&gen_config_filename)?;
|
|
62
|
+
let preproc_config_str = std::fs::read_to_string(&preproc_config_filename)?;
|
|
63
|
+
|
|
64
|
+
// Patch config: some models have null pad_token_id in text_config
|
|
65
|
+
// but candle's HFLLaVATextConfig requires usize. Fix by defaulting to 0.
|
|
66
|
+
let mut config_json: serde_json::Value = serde_json::from_str(&config_str)?;
|
|
67
|
+
let top_pad_id = config_json.get("pad_token_id")
|
|
68
|
+
.and_then(|v| v.as_u64())
|
|
69
|
+
.unwrap_or(0);
|
|
70
|
+
// Patch missing image_grid_pinpoints for LLaVA 1.5
|
|
71
|
+
if config_json.get("image_grid_pinpoints").map_or(true, |v| v.is_null()) {
|
|
72
|
+
config_json["image_grid_pinpoints"] = serde_json::json!([[336, 672], [672, 336], [672, 672]]);
|
|
73
|
+
}
|
|
74
|
+
if let Some(text_config) = config_json.get_mut("text_config") {
|
|
75
|
+
if text_config.get("pad_token_id").map_or(true, |v| v.is_null()) {
|
|
76
|
+
text_config["pad_token_id"] = serde_json::Value::Number(top_pad_id.into());
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
let patched_config_str = serde_json::to_string(&config_json)?;
|
|
80
|
+
let hf_config: HFLLaVAConfig = serde_json::from_str(&patched_config_str)?;
|
|
81
|
+
let gen_config: HFGenerationConfig = serde_json::from_str(&gen_config_str)?;
|
|
82
|
+
let preproc_config: HFPreProcessorConfig = serde_json::from_str(&preproc_config_str)?;
|
|
83
|
+
|
|
84
|
+
let image_size = hf_config.vision_config.image_size;
|
|
85
|
+
let eos_token_id = gen_config.eos_token_id as u32;
|
|
86
|
+
|
|
87
|
+
let clip_vision_config = hf_config.to_clip_vision_config();
|
|
88
|
+
let config = hf_config.to_llava_config(&gen_config, &preproc_config);
|
|
89
|
+
|
|
90
|
+
// Load tokenizer
|
|
91
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
|
92
|
+
|
|
93
|
+
// Download weight files (sharded)
|
|
94
|
+
let weight_files = Self::download_weights(&repo)?;
|
|
95
|
+
|
|
96
|
+
// Load model weights
|
|
97
|
+
let vb = unsafe {
|
|
98
|
+
VarBuilder::from_mmaped_safetensors(&weight_files, DType::F32, &device)?
|
|
99
|
+
};
|
|
100
|
+
|
|
101
|
+
// Load LLaVA model with CLIP vision config
|
|
102
|
+
let model = LLaVA::load(vb, &config, Some(clip_vision_config))?;
|
|
103
|
+
|
|
104
|
+
// Create KV cache for the Llama LLM
|
|
105
|
+
let llama_config = config.to_llama_config();
|
|
106
|
+
let cache = Cache::new(true, DType::F32, &llama_config, &device)?;
|
|
107
|
+
|
|
108
|
+
Ok((model, TokenizerWrapper::new(tokenizer), cache, config, image_size, eos_token_id))
|
|
109
|
+
})();
|
|
110
|
+
|
|
111
|
+
match result {
|
|
112
|
+
Ok((model, tokenizer, cache, config, image_size, eos_token_id)) => {
|
|
113
|
+
Ok(Self {
|
|
114
|
+
model: std::cell::RefCell::new(UnsafeSendSync(model)),
|
|
115
|
+
tokenizer,
|
|
116
|
+
cache: std::cell::RefCell::new(UnsafeSendSync(cache)),
|
|
117
|
+
config,
|
|
118
|
+
device,
|
|
119
|
+
model_id,
|
|
120
|
+
image_size,
|
|
121
|
+
eos_token_id,
|
|
122
|
+
})
|
|
123
|
+
}
|
|
124
|
+
Err(e) => Err(Error::new(runtime_error, format!("Failed to load VLM: {}", e))),
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
fn download_weights(
|
|
129
|
+
repo: &hf_hub::api::sync::ApiRepo,
|
|
130
|
+
) -> std::result::Result<Vec<std::path::PathBuf>, Box<dyn std::error::Error + Send + Sync>> {
|
|
131
|
+
// Try single file first
|
|
132
|
+
if let Ok(path) = repo.get("model.safetensors") {
|
|
133
|
+
return Ok(vec![path]);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
// Try to get the index file for sharded weights
|
|
137
|
+
let index_path = repo.get("model.safetensors.index.json")?;
|
|
138
|
+
let index_str = std::fs::read_to_string(&index_path)?;
|
|
139
|
+
let index: serde_json::Value = serde_json::from_str(&index_str)?;
|
|
140
|
+
|
|
141
|
+
let weight_map = index["weight_map"].as_object()
|
|
142
|
+
.ok_or("Missing weight_map in index")?;
|
|
143
|
+
|
|
144
|
+
let mut filenames: Vec<String> = weight_map.values()
|
|
145
|
+
.filter_map(|v| v.as_str().map(String::from))
|
|
146
|
+
.collect();
|
|
147
|
+
filenames.sort();
|
|
148
|
+
filenames.dedup();
|
|
149
|
+
|
|
150
|
+
let mut paths = Vec::new();
|
|
151
|
+
for filename in &filenames {
|
|
152
|
+
let path = repo.get(filename)?;
|
|
153
|
+
paths.push(path);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
Ok(paths)
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
/// Load and preprocess an image from a file path into a CLIP-ready tensor
|
|
160
|
+
fn load_image(&self, image_path: &str) -> std::result::Result<Tensor, Error> {
|
|
161
|
+
let ruby = Ruby::get().unwrap();
|
|
162
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
163
|
+
|
|
164
|
+
let img = image::open(image_path)
|
|
165
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to open image: {}", e)))?;
|
|
166
|
+
|
|
167
|
+
// Resize to expected size
|
|
168
|
+
let img = img.resize_exact(
|
|
169
|
+
self.image_size as u32,
|
|
170
|
+
self.image_size as u32,
|
|
171
|
+
image::imageops::FilterType::Triangle,
|
|
172
|
+
);
|
|
173
|
+
|
|
174
|
+
let img = img.to_rgb8();
|
|
175
|
+
let (width, height) = img.dimensions();
|
|
176
|
+
let h = height as usize;
|
|
177
|
+
let w = width as usize;
|
|
178
|
+
|
|
179
|
+
// Convert to CHW format with CLIP normalization
|
|
180
|
+
let mut chw = vec![0f32; 3 * h * w];
|
|
181
|
+
for y in 0..h {
|
|
182
|
+
for x in 0..w {
|
|
183
|
+
let p = img.get_pixel(x as u32, y as u32);
|
|
184
|
+
chw[0 * h * w + y * w + x] = (p[0] as f32 / 255.0 - CLIP_MEAN[0]) / CLIP_STD[0];
|
|
185
|
+
chw[1 * h * w + y * w + x] = (p[1] as f32 / 255.0 - CLIP_MEAN[1]) / CLIP_STD[1];
|
|
186
|
+
chw[2 * h * w + y * w + x] = (p[2] as f32 / 255.0 - CLIP_MEAN[2]) / CLIP_STD[2];
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
Tensor::from_vec(chw, (1, 3, h, w), &self.device)
|
|
191
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create image tensor: {}", e)))
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
/// Describe an image
|
|
195
|
+
pub fn describe(&self, image_path: String, max_length: Option<usize>) -> std::result::Result<String, Error> {
|
|
196
|
+
self.ask(image_path, "Describe this image in detail.".to_string(), max_length)
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
/// Ask a question about an image
|
|
200
|
+
pub fn ask(&self, image_path: String, question: String, max_length: Option<usize>) -> std::result::Result<String, Error> {
|
|
201
|
+
let ruby = Ruby::get().unwrap();
|
|
202
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
203
|
+
let max_length = max_length.unwrap_or(256);
|
|
204
|
+
|
|
205
|
+
// Load and preprocess image
|
|
206
|
+
let image_tensor = self.load_image(&image_path)?;
|
|
207
|
+
|
|
208
|
+
// Build prompt with image token placeholder
|
|
209
|
+
// LLaVA 1.5 HF format: USER: <image>\n{question}\nASSISTANT:
|
|
210
|
+
let prompt = format!("USER: <image>\n{}\nASSISTANT:", question);
|
|
211
|
+
|
|
212
|
+
// Tokenize
|
|
213
|
+
let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
|
|
214
|
+
.map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
|
|
215
|
+
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
216
|
+
|
|
217
|
+
// LLaVA expects I64 input IDs
|
|
218
|
+
let input_ids_i64: Vec<i64> = input_ids.iter().map(|&id| id as i64).collect();
|
|
219
|
+
let input_tensor = Tensor::new(&input_ids_i64[..], &self.device)
|
|
220
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create input tensor: {}", e)))?
|
|
221
|
+
.unsqueeze(0)
|
|
222
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
|
|
223
|
+
|
|
224
|
+
let mut model_ref = self.model.borrow_mut();
|
|
225
|
+
let model = &mut model_ref.0;
|
|
226
|
+
let mut cache_ref = self.cache.borrow_mut();
|
|
227
|
+
let cache = &mut cache_ref.0;
|
|
228
|
+
|
|
229
|
+
// Prepare multimodal input: merge image features with text embeddings
|
|
230
|
+
let image_size = (self.image_size as u32, self.image_size as u32);
|
|
231
|
+
let input_embeds = model.prepare_inputs_labels_for_multimodal(
|
|
232
|
+
&input_tensor,
|
|
233
|
+
&[image_tensor],
|
|
234
|
+
&[image_size],
|
|
235
|
+
).map_err(|e| Error::new(runtime_error, format!("Failed to prepare multimodal input: {}", e)))?;
|
|
236
|
+
|
|
237
|
+
// Reset KV cache for fresh generation
|
|
238
|
+
let llama_config = self.config.to_llama_config();
|
|
239
|
+
*cache = Cache::new(true, DType::F32, &llama_config, &self.device)
|
|
240
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to reset cache: {}", e)))?;
|
|
241
|
+
|
|
242
|
+
// Generate tokens autoregressively
|
|
243
|
+
let mut generated_tokens: Vec<u32> = Vec::new();
|
|
244
|
+
let mut current_embeds = input_embeds;
|
|
245
|
+
let mut pos = 0usize;
|
|
246
|
+
|
|
247
|
+
for _i in 0..max_length {
|
|
248
|
+
let logits = model.forward(¤t_embeds, pos, cache)
|
|
249
|
+
.map_err(|e| Error::new(runtime_error, format!("Forward pass failed at pos {}: {}", pos, e)))?;
|
|
250
|
+
|
|
251
|
+
// Advance position by the number of tokens we just processed
|
|
252
|
+
let step_len = current_embeds.dim(1)
|
|
253
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to get step len: {}", e)))?;
|
|
254
|
+
pos += step_len;
|
|
255
|
+
|
|
256
|
+
// Get logits for last position
|
|
257
|
+
let logits = logits.flatten_all()
|
|
258
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?;
|
|
259
|
+
|
|
260
|
+
// Handle multi-dim logits (take last token if needed)
|
|
261
|
+
let vocab_size = self.config.vocab_size;
|
|
262
|
+
let logits = if logits.elem_count() > vocab_size {
|
|
263
|
+
let n_tokens = logits.elem_count() / vocab_size;
|
|
264
|
+
logits.reshape((n_tokens, vocab_size))
|
|
265
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to reshape logits: {}", e)))?
|
|
266
|
+
.narrow(0, n_tokens - 1, 1)
|
|
267
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to narrow logits: {}", e)))?
|
|
268
|
+
.squeeze(0)
|
|
269
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze logits: {}", e)))?
|
|
270
|
+
} else {
|
|
271
|
+
logits
|
|
272
|
+
};
|
|
273
|
+
|
|
274
|
+
let logits = logits.to_dtype(DType::F32)
|
|
275
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
|
|
276
|
+
|
|
277
|
+
// Greedy decoding
|
|
278
|
+
let next_token = logits.argmax(0)
|
|
279
|
+
.map_err(|e| Error::new(runtime_error, format!("Argmax failed: {}", e)))?
|
|
280
|
+
.to_scalar::<u32>()
|
|
281
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to get token: {}", e)))?;
|
|
282
|
+
|
|
283
|
+
// Check for EOS
|
|
284
|
+
if next_token == self.eos_token_id {
|
|
285
|
+
break;
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
generated_tokens.push(next_token);
|
|
289
|
+
|
|
290
|
+
// For subsequent tokens, embed directly through Llama's embedding layer
|
|
291
|
+
let next_input = Tensor::new(&[next_token as i64], &self.device)
|
|
292
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create next input: {}", e)))?
|
|
293
|
+
.unsqueeze(0)
|
|
294
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze next: {}", e)))?;
|
|
295
|
+
|
|
296
|
+
current_embeds = model.llama.embed(&next_input)
|
|
297
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to embed next token: {}", e)))?;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
// Decode generated tokens
|
|
301
|
+
let text = self.tokenizer.inner().decode(&generated_tokens, true)
|
|
302
|
+
.map_err(|e| Error::new(runtime_error, format!("Decoding failed: {}", e)))?;
|
|
303
|
+
|
|
304
|
+
Ok(text.trim().to_string())
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
pub fn model_id(&self) -> String {
|
|
308
|
+
self.model_id.clone()
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
pub fn device(&self) -> Device {
|
|
312
|
+
Device::from_device(&self.device)
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
/// Returns the model's tokenizer wrapped for Ruby (a clone of the internal `TokenizerWrapper`).
pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
    let wrapper = self.tokenizer.clone();
    Ok(crate::ruby::tokenizer::Tokenizer(wrapper))
}
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
|
321
|
+
let ruby = Ruby::get().unwrap();
|
|
322
|
+
let c_vlm = rb_candle.define_class("VLM", ruby.class_object())?;
|
|
323
|
+
c_vlm.define_singleton_method("_create", function!(VLM::new, 2))?;
|
|
324
|
+
c_vlm.define_method("_describe", method!(VLM::describe, 2))?;
|
|
325
|
+
c_vlm.define_method("_ask", method!(VLM::ask, 3))?;
|
|
326
|
+
c_vlm.define_method("model_id", method!(VLM::model_id, 0))?;
|
|
327
|
+
c_vlm.define_method("device", method!(VLM::device, 0))?;
|
|
328
|
+
c_vlm.define_method("tokenizer", method!(VLM::tokenizer, 0))?;
|
|
329
|
+
Ok(())
|
|
330
|
+
}
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
#[cfg(test)]
|
|
2
|
+
mod integration_tests {
|
|
3
|
+
use super::super::*;
|
|
4
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
5
|
+
use std::sync::Arc;
|
|
6
|
+
|
|
7
|
+
#[tokio::test]
|
|
8
|
+
async fn test_schema_processor_with_vocabulary() {
|
|
9
|
+
// This test requires a tokenizer to create a vocabulary
|
|
10
|
+
let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
|
|
11
|
+
|
|
12
|
+
if let Ok(tokenizer) = tokenizer_result {
|
|
13
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
|
14
|
+
|
|
15
|
+
// Create vocabulary from tokenizer
|
|
16
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
|
17
|
+
.expect("Should create vocabulary");
|
|
18
|
+
|
|
19
|
+
// Create schema processor
|
|
20
|
+
let processor = SchemaProcessor::new();
|
|
21
|
+
|
|
22
|
+
// Test with a simple JSON schema
|
|
23
|
+
let schema = r#"{
|
|
24
|
+
"type": "object",
|
|
25
|
+
"properties": {
|
|
26
|
+
"name": {"type": "string"},
|
|
27
|
+
"age": {"type": "integer"}
|
|
28
|
+
},
|
|
29
|
+
"required": ["name", "age"]
|
|
30
|
+
}"#;
|
|
31
|
+
|
|
32
|
+
// Process schema into Index
|
|
33
|
+
let index_result = processor.process_schema(schema, &vocabulary);
|
|
34
|
+
assert!(index_result.is_ok(), "Should process schema successfully");
|
|
35
|
+
|
|
36
|
+
// Test caching - second call should use cache
|
|
37
|
+
let index2_result = processor.process_schema(schema, &vocabulary);
|
|
38
|
+
assert!(index2_result.is_ok(), "Should retrieve from cache");
|
|
39
|
+
|
|
40
|
+
// Both should be the same Arc
|
|
41
|
+
let index1 = index_result.unwrap();
|
|
42
|
+
let index2 = index2_result.unwrap();
|
|
43
|
+
assert!(Arc::ptr_eq(&index1, &index2), "Should return cached Index");
|
|
44
|
+
|
|
45
|
+
// Check cache stats
|
|
46
|
+
let (size, _) = processor.cache_stats();
|
|
47
|
+
assert_eq!(size, 1, "Cache should have one entry");
|
|
48
|
+
} else {
|
|
49
|
+
eprintln!("Skipping integration test - couldn't load tokenizer");
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
#[tokio::test]
|
|
54
|
+
async fn test_regex_processing() {
|
|
55
|
+
let tokenizer_result = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await;
|
|
56
|
+
|
|
57
|
+
if let Ok(tokenizer) = tokenizer_result {
|
|
58
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
|
59
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
|
60
|
+
.expect("Should create vocabulary");
|
|
61
|
+
|
|
62
|
+
let processor = SchemaProcessor::new();
|
|
63
|
+
|
|
64
|
+
// Test with a simple regex pattern
|
|
65
|
+
let email_regex = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";
|
|
66
|
+
|
|
67
|
+
let index_result = processor.process_regex(email_regex, &vocabulary);
|
|
68
|
+
assert!(index_result.is_ok(), "Should process regex successfully");
|
|
69
|
+
|
|
70
|
+
// Test different regex
|
|
71
|
+
let phone_regex = r"\d{3}-\d{3}-\d{4}";
|
|
72
|
+
let phone_index_result = processor.process_regex(phone_regex, &vocabulary);
|
|
73
|
+
assert!(phone_index_result.is_ok(), "Should process phone regex");
|
|
74
|
+
|
|
75
|
+
// Cache should have both
|
|
76
|
+
let (size, _) = processor.cache_stats();
|
|
77
|
+
assert_eq!(size, 2, "Cache should have two entries");
|
|
78
|
+
|
|
79
|
+
// Clear cache
|
|
80
|
+
processor.clear_cache();
|
|
81
|
+
let (size, _) = processor.cache_stats();
|
|
82
|
+
assert_eq!(size, 0, "Cache should be empty after clear");
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
#[test]
fn test_various_json_schemas() {
    // The schemas are only checked for JSON validity; compiling them would need a
    // (mock) vocabulary, which this test does not set up.
    let _processor = SchemaProcessor::new();

    let array_schema = serde_json::json!({
        "type": "array",
        "items": {"type": "string"}
    });

    let nested_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "id": {"type": "integer"},
                    "email": {"type": "string", "format": "email"}
                }
            }
        }
    });

    let enum_schema = serde_json::json!({
        "type": "string",
        "enum": ["red", "green", "blue"]
    });

    let serialize = |schema: &serde_json::Value| serde_json::to_string(schema).unwrap();

    assert!(!serialize(&array_schema).is_empty(), "Should serialize array schema");
    assert!(serialize(&nested_schema).contains("properties"), "Should have nested properties");
    assert!(serialize(&enum_schema).contains("enum"), "Should have enum values");
}
|
|
130
|
+
}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
//! Structured generation support using Outlines.
//!
//! This module provides functionality to constrain language model generation
//! to follow specific patterns, such as JSON schemas or regular expressions.
|
|
5
|
+
|
|
6
|
+
pub mod vocabulary_adapter;
|
|
7
|
+
pub mod schema_processor;
|
|
8
|
+
|
|
9
|
+
pub use vocabulary_adapter::VocabularyAdapter;
|
|
10
|
+
pub use schema_processor::SchemaProcessor;
|
|
11
|
+
|
|
12
|
+
// Re-export commonly used types from outlines-core
|
|
13
|
+
pub use outlines_core::prelude::Index;
|
|
14
|
+
pub use outlines_core::vocabulary::Vocabulary;
|
|
15
|
+
|
|
16
|
+
#[cfg(test)]
|
|
17
|
+
mod vocabulary_adapter_simple_test;
|
|
18
|
+
|
|
19
|
+
#[cfg(test)]
|
|
20
|
+
mod integration_test;
|
|
21
|
+
|
|
22
|
+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_module_imports() {
        // Compile-time check that every public export of this module resolves.
        let _ = VocabularyAdapter;
        let _ = std::any::type_name::<SchemaProcessor>();
        let _ = std::any::type_name::<Index>();
        let _ = std::any::type_name::<Vocabulary>();
    }
}
|