red-candle 1.8.0.pre2-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5193 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +33 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,488 @@
1
+ use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
2
+ use candle_transformers::models::bert::{BertModel, Config as BertConfig};
3
+ use candle_transformers::models::xlm_roberta::{
4
+ XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
5
+ };
6
+ use candle_transformers::models::debertav2::{
7
+ DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
8
+ };
9
+ use candle_transformers::models::modernbert::{
10
+ ModernBert, Config as ModernBertConfig,
11
+ };
12
+ use candle_transformers::models::qwen3::{
13
+ ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
14
+ };
15
+ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
16
+ use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
17
+ use hf_hub::{api::sync::Api, Repo, RepoType};
18
+ use tokenizers::{EncodeInput, Tokenizer};
19
+ use std::cell::RefCell;
20
+ use crate::ruby::{Device, Result};
21
+ use crate::gvl;
22
+ use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
23
+
24
+ enum RerankerModel {
25
+ Bert {
26
+ model: BertModel,
27
+ pooler: Linear,
28
+ classifier: Linear,
29
+ },
30
+ XLMRoberta {
31
+ model: XLMRobertaForSequenceClassification,
32
+ pad_token_id: u32,
33
+ },
34
+ DeBERTa {
35
+ model: DebertaV2Model,
36
+ pooler: DebertaV2ContextPooler,
37
+ classifier: Linear,
38
+ pad_token_id: u32,
39
+ },
40
+ ModernBert {
41
+ model: ModernBert,
42
+ head_dense: Linear,
43
+ head_norm: candle_nn::LayerNorm,
44
+ classifier: Linear,
45
+ pad_token_id: u32,
46
+ },
47
+ Qwen3 {
48
+ model: RefCell<Qwen3Model>,
49
+ yes_token_id: u32,
50
+ no_token_id: u32,
51
+ },
52
+ }
53
+
54
+ #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
55
+ pub struct Reranker {
56
+ model: RerankerModel,
57
+ tokenizer: TokenizerWrapper,
58
+ device: CoreDevice,
59
+ model_id: String,
60
+ }
61
+
62
+ impl Reranker {
63
+ pub fn new(model_id: String, device: Option<Device>, max_length: Option<usize>) -> Result<Self> {
64
+ let device = device.unwrap_or(Device::best()).as_device()?;
65
+ let max_length = max_length.unwrap_or(512); // Default to 512
66
+ Self::new_with_core_device(model_id, device, max_length)
67
+ }
68
+
69
+ fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
70
+ let ruby = Ruby::get().unwrap();
71
+ let runtime_error = ruby.exception_runtime_error();
72
+
73
+ let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
74
+ let api = Api::new()?;
75
+ let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
76
+
77
+ // Download model files
78
+ let config_filename = repo.get("config.json")?;
79
+ let tokenizer_filename = repo.get("tokenizer.json")?;
80
+ let weights_filename = repo.get("model.safetensors")?;
81
+
82
+ // Read raw config to detect model type
83
+ let config_str = std::fs::read_to_string(&config_filename)?;
84
+ let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
85
+ let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
86
+
87
+ // Setup tokenizer with padding AND truncation
88
+ let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
89
+ let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
90
+ let tokenizer = TokenizerLoader::with_truncation(tokenizer, max_length);
91
+
92
+ // Load model weights
93
+ let vb = unsafe {
94
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
95
+ };
96
+
97
+ let model = match model_type {
98
+ "xlm-roberta" => {
99
+ let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
100
+ let pad_token_id = config.pad_token_id;
101
+ let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
102
+ RerankerModel::XLMRoberta { model, pad_token_id }
103
+ }
104
+ "deberta-v2" => {
105
+ let config: DebertaV2Config = serde_json::from_str(&config_str)?;
106
+ let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
107
+ let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
108
+ let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
109
+ let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
110
+ let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
111
+ let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
112
+ RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
113
+ }
114
+ "qwen3" => {
115
+ let config: Qwen3Config = serde_json::from_str(&config_str)?;
116
+ let model = Qwen3Model::new(&config, vb)?;
117
+
118
+ // Look up "yes" and "no" token IDs from the tokenizer
119
+ let yes_token_id: u32 = tokenizer
120
+ .encode("yes", false)
121
+ .ok()
122
+ .and_then(|enc| enc.get_ids().first().copied())
123
+ .unwrap_or(9693);
124
+ let no_token_id: u32 = tokenizer
125
+ .encode("no", false)
126
+ .ok()
127
+ .and_then(|enc| enc.get_ids().first().copied())
128
+ .unwrap_or(2152);
129
+
130
+ RerankerModel::Qwen3 {
131
+ model: RefCell::new(model),
132
+ yes_token_id,
133
+ no_token_id,
134
+ }
135
+ }
136
+ "modernbert" => {
137
+ let config: ModernBertConfig = serde_json::from_str(&config_str)?;
138
+ let pad_token_id = config.pad_token_id;
139
+ let model = ModernBert::load(vb.clone(), &config)?;
140
+ // ModernBertHead::load is private, so load the head layers manually
141
+ let head_vb = vb.pp("head");
142
+ let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
143
+ let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
144
+ let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
145
+ RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
146
+ }
147
+ _ => {
148
+ let config: BertConfig = serde_json::from_str(&config_str)?;
149
+ let model = BertModel::load(vb.pp("bert"), &config)?;
150
+ let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
151
+ let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
152
+ RerankerModel::Bert { model, pooler, classifier }
153
+ }
154
+ };
155
+
156
+ Ok((model, TokenizerWrapper::new(tokenizer)))
157
+ })();
158
+
159
+ match result {
160
+ Ok((model, tokenizer)) => {
161
+ Ok(Self { model, tokenizer, device, model_id })
162
+ }
163
+ Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
164
+ }
165
+ }
166
+
167
+ /// Extract CLS embeddings from the model output, handling Metal device workarounds
168
+ fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, String> {
169
+ let cls_embeddings = if self.device.is_metal() {
170
+ let (batch_size, seq_len, hidden_size) = embeddings.dims3()
171
+ .map_err(|e| format!("Failed to get dims: {}", e))?;
172
+ let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
173
+ .map_err(|e| format!("Failed to reshape: {}", e))?;
174
+ let mut cls_vecs = Vec::new();
175
+ for i in 0..batch_size {
176
+ let start_idx = i * seq_len;
177
+ let cls_vec = reshaped.narrow(0, start_idx, 1)
178
+ .map_err(|e| format!("Failed to extract CLS: {}", e))?;
179
+ cls_vecs.push(cls_vec);
180
+ }
181
+ Tensor::cat(&cls_vecs, 0)
182
+ .map_err(|e| format!("Failed to cat CLS tokens: {}", e))?
183
+ } else {
184
+ embeddings.i((.., 0))
185
+ .map_err(|e| format!("Failed to extract CLS token: {}", e))?
186
+ };
187
+ cls_embeddings.contiguous()
188
+ .map_err(|e| format!("Failed to make CLS embeddings contiguous: {}", e))
189
+ }
190
+
191
+ pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
192
+ let ruby = Ruby::get().unwrap();
193
+ let runtime_error = ruby.exception_runtime_error();
194
+
195
+ // Create query-document pair for cross-encoder
196
+ let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
197
+
198
+ // Tokenize using the inner tokenizer for detailed info
199
+ let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
200
+ .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
201
+
202
+ // Get token information
203
+ let token_ids = encoding.get_ids().to_vec();
204
+ let token_type_ids = encoding.get_type_ids().to_vec();
205
+ let attention_mask = encoding.get_attention_mask().to_vec();
206
+ let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
207
+
208
+ // Create result hash
209
+ let result = ruby.hash_new();
210
+ result.aset("token_ids", ruby.ary_from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
211
+ result.aset("token_type_ids", ruby.ary_from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
212
+ result.aset("attention_mask", ruby.ary_from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
213
+ result.aset("tokens", ruby.ary_from_vec(tokens))?;
214
+
215
+ Ok(result)
216
+ }
217
+
218
+ pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
219
+ let ruby = Ruby::get().unwrap();
220
+ let runtime_error = ruby.exception_runtime_error();
221
+ let documents: Vec<String> = documents.to_vec()?;
222
+
223
+ // Release the GVL for the entire compute portion (tokenization + inference + scoring).
224
+ // None of this calls Ruby API.
225
+ let ranked_docs = gvl::without_gvl(|| -> std::result::Result<Vec<(String, f32, usize)>, String> {
226
+ self.compute_rerank(&query, &documents, &pooling_method, apply_sigmoid)
227
+ });
228
+
229
+ let ranked_docs = ranked_docs
230
+ .map_err(|e| Error::new(runtime_error, e))?;
231
+
232
+ // Build result array (requires GVL for Ruby object creation)
233
+ let result_array = ruby.ary_new();
234
+ for (doc, score, doc_id) in ranked_docs {
235
+ let tuple = ruby.ary_new();
236
+ tuple.push(doc)?;
237
+ tuple.push(ruby.float_from_f64(score as f64))?;
238
+ tuple.push(doc_id)?;
239
+ result_array.push(tuple)?;
240
+ }
241
+ Ok(result_array)
242
+ }
243
+
244
+ /// Pure compute portion of reranking — no Ruby API calls.
245
+ /// Returns ranked (document, score, original_index) tuples.
246
+ fn compute_rerank(&self, query: &str, documents: &[String], pooling_method: &str, apply_sigmoid: bool) -> std::result::Result<Vec<(String, f32, usize)>, String> {
247
+ // Create query-document pairs for cross-encoder
248
+ let query_and_docs: Vec<EncodeInput> = documents
249
+ .iter()
250
+ .map(|d| (query.to_string(), d.clone()).into())
251
+ .collect();
252
+
253
+ // Tokenize batch
254
+ let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
255
+ .map_err(|e| format!("Tokenization failed: {}", e))?;
256
+
257
+ let token_ids_vec = encodings
258
+ .iter()
259
+ .map(|e| e.get_ids().to_vec())
260
+ .collect::<Vec<_>>();
261
+
262
+ let token_type_ids_vec = encodings
263
+ .iter()
264
+ .map(|e| e.get_type_ids().to_vec())
265
+ .collect::<Vec<_>>();
266
+
267
+ let token_ids = Tensor::new(token_ids_vec, &self.device)
268
+ .map_err(|e| format!("Failed to create tensor: {}", e))?;
269
+ let token_type_ids = Tensor::new(token_type_ids_vec, &self.device)
270
+ .map_err(|e| format!("Failed to create token type ids tensor: {}", e))?;
271
+
272
+ // Compute scores based on model type
273
+ let scores = match &self.model {
274
+ RerankerModel::Bert { model, pooler, classifier } => {
275
+ let attention_mask = token_ids.ne(0u32)
276
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
277
+
278
+ // Forward pass through BERT
279
+ let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
280
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
281
+
282
+ // Apply pooling based on the specified method
283
+ let pooled_embeddings = match pooling_method {
284
+ "pooler" => {
285
+ let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
286
+ let pooled = pooler.forward(&cls_embeddings)
287
+ .map_err(|e| format!("Pooler forward failed: {}", e))?;
288
+ pooled.tanh()
289
+ .map_err(|e| format!("Tanh activation failed: {}", e))?
290
+ },
291
+ "cls" => {
292
+ self.extract_cls_embeddings(&embeddings)?
293
+ },
294
+ "mean" => {
295
+ let (_batch, seq_len, _hidden) = embeddings.dims3()
296
+ .map_err(|e| format!("Failed to get tensor dimensions: {}", e))?;
297
+ let sum = embeddings.sum(1)
298
+ .map_err(|e| format!("Failed to sum embeddings: {}", e))?;
299
+ (sum / (seq_len as f64))
300
+ .map_err(|e| format!("Failed to compute mean: {}", e))?
301
+ },
302
+ _ => return Err(
303
+ format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method))
304
+ };
305
+
306
+ let pooled_embeddings = pooled_embeddings.contiguous()
307
+ .map_err(|e| format!("Failed to make pooled_embeddings contiguous: {}", e))?;
308
+ let logits = classifier.forward(&pooled_embeddings)
309
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
310
+ logits.squeeze(1)
311
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
312
+ }
313
+ RerankerModel::XLMRoberta { model, pad_token_id } => {
314
+ let attention_mask = token_ids.ne(*pad_token_id)
315
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
316
+
317
+ // XLMRobertaForSequenceClassification returns logits directly
318
+ let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
319
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
320
+ logits.squeeze(1)
321
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
322
+ }
323
+ RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
324
+ let attention_mask = token_ids.ne(*pad_token_id)
325
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
326
+
327
+ // Forward through DeBERTa encoder
328
+ let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
329
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
330
+
331
+ // Pool and classify
332
+ let pooled = pooler.forward(&encoder_output)
333
+ .map_err(|e| format!("Pooler forward failed: {}", e))?;
334
+ let logits = classifier.forward(&pooled)
335
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
336
+ logits.squeeze(1)
337
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
338
+ }
339
+ RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
340
+ let attention_mask = token_ids.ne(*pad_token_id)
341
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
342
+ let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
343
+ .map_err(|e| format!("Failed to convert attention mask: {}", e))?;
344
+
345
+ // Forward through ModernBERT encoder
346
+ let encoder_output = model.forward(&token_ids, &attention_mask_f32)
347
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
348
+
349
+ // CLS pooling, then head (dense + GELU + norm) + classifier
350
+ let cls = encoder_output.i((.., 0, ..))
351
+ .map_err(|e| format!("Failed to extract CLS: {}", e))?
352
+ .contiguous()
353
+ .map_err(|e| format!("Failed to make contiguous: {}", e))?;
354
+ let hidden = head_dense.forward(&cls)
355
+ .map_err(|e| format!("Head dense failed: {}", e))?;
356
+ let hidden = hidden.gelu_erf()
357
+ .map_err(|e| format!("GELU activation failed: {}", e))?;
358
+ let hidden = head_norm.forward(&hidden)
359
+ .map_err(|e| format!("Head norm failed: {}", e))?;
360
+ let logits = classifier.forward(&hidden)
361
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
362
+ logits.squeeze(1)
363
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
364
+ }
365
+ RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
366
+ // Qwen3 reranker: decoder-based yes/no scoring
367
+ // Process each document individually (causal LM, not batch encoder)
368
+ let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
369
+ let mut model = model.borrow_mut();
370
+
371
+ for doc in documents.iter() {
372
+ // Build the Qwen3 reranker prompt
373
+ let prompt = format!(
374
+ "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
375
+ query, doc
376
+ );
377
+
378
+ // Tokenize the prompt
379
+ let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
380
+ .map_err(|e| format!("Tokenization failed: {}", e))?;
381
+ let input_ids: Vec<u32> = encoding.get_ids().to_vec();
382
+
383
+ // Clear KV cache for each document
384
+ model.clear_kv_cache();
385
+
386
+ // Forward pass — get logits for the last token position
387
+ let input_tensor = Tensor::new(&input_ids[..], &self.device)
388
+ .map_err(|e| format!("Failed to create tensor: {}", e))?
389
+ .unsqueeze(0)
390
+ .map_err(|e| format!("Failed to unsqueeze: {}", e))?;
391
+
392
+ let logits = model.forward(&input_tensor, 0)
393
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
394
+
395
+ // logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
396
+ let logits = logits.flatten_all()
397
+ .map_err(|e| format!("Failed to flatten: {}", e))?
398
+ .to_dtype(DType::F32)
399
+ .map_err(|e| format!("Failed to convert dtype: {}", e))?;
400
+
401
+ // Extract yes/no logits and compute score
402
+ let yes_logit: f32 = logits.i(*yes_token_id as usize)
403
+ .map_err(|e| format!("Failed to get yes logit: {}", e))?
404
+ .to_scalar()
405
+ .map_err(|e| format!("Failed to convert yes logit: {}", e))?;
406
+ let no_logit: f32 = logits.i(*no_token_id as usize)
407
+ .map_err(|e| format!("Failed to get no logit: {}", e))?
408
+ .to_scalar()
409
+ .map_err(|e| format!("Failed to convert no logit: {}", e))?;
410
+
411
+ // softmax over [yes, no] → P(yes)
412
+ let max_logit = yes_logit.max(no_logit);
413
+ let yes_exp = (yes_logit - max_logit).exp();
414
+ let no_exp = (no_logit - max_logit).exp();
415
+ let score = yes_exp / (yes_exp + no_exp);
416
+
417
+ scores_vec.push(score);
418
+ }
419
+
420
+ // Build scores tensor for uniform handling below
421
+ Tensor::new(scores_vec.as_slice(), &self.device)
422
+ .map_err(|e| format!("Failed to create scores tensor: {}", e))?
423
+ }
424
+ };
425
+
426
+ // Optionally apply sigmoid activation
427
+ let scores = if apply_sigmoid {
428
+ sigmoid(&scores)
429
+ .map_err(|e| format!("Sigmoid failed: {}", e))?
430
+ } else {
431
+ scores
432
+ };
433
+
434
+ let scores_vec: Vec<f32> = scores.to_vec1()
435
+ .map_err(|e| format!("Failed to convert scores to vec: {}", e))?;
436
+
437
+ // Create tuples with document, score, and original index
438
+ let mut ranked_docs: Vec<(String, f32, usize)> = documents
439
+ .iter()
440
+ .cloned()
441
+ .zip(scores_vec)
442
+ .enumerate()
443
+ .map(|(idx, (doc, score))| (doc, score, idx))
444
+ .collect();
445
+
446
+ // Sort documents by relevance score (descending)
447
+ ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
448
+
449
+ Ok(ranked_docs)
450
+ }
451
+
452
+ /// Get the tokenizer used by this model
453
+ pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
454
+ Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
455
+ }
456
+
457
+ /// Get the model_id
458
+ pub fn model_id(&self) -> String {
459
+ self.model_id.clone()
460
+ }
461
+
462
+ /// Get the device
463
+ pub fn device(&self) -> Device {
464
+ Device::from_device(&self.device)
465
+ }
466
+
467
+ /// Get all options as a hash
468
+ pub fn options(&self) -> std::result::Result<RHash, Error> {
469
+ let ruby = Ruby::get().unwrap();
470
+ let hash = ruby.hash_new();
471
+ hash.aset("model_id", self.model_id.clone())?;
472
+ hash.aset("device", self.device().__str__())?;
473
+ Ok(hash)
474
+ }
475
+ }
476
+
477
+ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
478
+ let ruby = Ruby::get().unwrap();
479
+ let c_reranker = rb_candle.define_class("Reranker", ruby.class_object())?;
480
+ c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
481
+ c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
482
+ c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
483
+ c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
484
+ c_reranker.define_method("model_id", method!(Reranker::model_id, 0))?;
485
+ c_reranker.define_method("device", method!(Reranker::device, 0))?;
486
+ c_reranker.define_method("options", method!(Reranker::options, 0))?;
487
+ Ok(())
488
+ }
@@ -0,0 +1,3 @@
1
+ use magnus::Error;
2
+
3
+ pub type Result<T> = std::result::Result<T, Error>;
@@ -0,0 +1,92 @@
1
+ use magnus::{Error, Module, RModule, function, Object, Ruby};
2
+ use std::sync::Arc;
3
+
4
+ use crate::structured::{SchemaProcessor, VocabularyAdapter, Index, Vocabulary};
5
+ use crate::ruby::{Result, tokenizer::Tokenizer};
6
+
7
+ /// Ruby wrapper for structured generation constraints
8
+ #[derive(Clone, Debug)]
9
+ #[magnus::wrap(class = "Candle::StructuredConstraint", mark, free_immediately)]
10
+ pub struct StructuredConstraint {
11
+ pub(crate) index: Arc<Index>,
12
+ }
13
+
14
+ impl StructuredConstraint {
15
+ /// Create a constraint from a JSON schema using a model ID
16
+ /// This uses Vocabulary::from_pretrained which handles tokenizer byte encoding correctly
17
+ pub fn from_schema_with_model(schema: String, model_id: String) -> Result<Self> {
18
+ // Use tokio runtime for async vocabulary loading
19
+ let rt = tokio::runtime::Runtime::new()
20
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
21
+
22
+ let vocabulary = rt.block_on(async {
23
+ Vocabulary::from_pretrained(&model_id, None)
24
+ })
25
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
26
+
27
+ let processor = SchemaProcessor::new();
28
+ let index = processor.process_schema(&schema, &vocabulary)
29
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
30
+
31
+ Ok(Self { index })
32
+ }
33
+
34
+ /// Create a constraint from a regex pattern using a model ID
35
+ pub fn from_regex_with_model(pattern: String, model_id: String) -> Result<Self> {
36
+ // Use tokio runtime for async vocabulary loading
37
+ let rt = tokio::runtime::Runtime::new()
38
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
39
+
40
+ let vocabulary = rt.block_on(async {
41
+ Vocabulary::from_pretrained(&model_id, None)
42
+ })
43
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
44
+
45
+ let processor = SchemaProcessor::new();
46
+ let index = processor.process_regex(&pattern, &vocabulary)
47
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
48
+
49
+ Ok(Self { index })
50
+ }
51
+
52
+ /// Create a constraint from a JSON schema (legacy method using tokenizer directly)
53
+ /// Note: This may not handle all tokenizer byte encodings correctly
54
+ pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
55
+ let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
56
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
57
+
58
+ let processor = SchemaProcessor::new();
59
+ let index = processor.process_schema(&schema, &vocabulary)
60
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
61
+
62
+ Ok(Self { index })
63
+ }
64
+
65
+ /// Create a constraint from a regex pattern (legacy method using tokenizer directly)
66
+ /// Note: This may not handle all tokenizer byte encodings correctly
67
+ pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
68
+ let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
69
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
70
+
71
+ let processor = SchemaProcessor::new();
72
+ let index = processor.process_regex(&pattern, &vocabulary)
73
+ .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
74
+
75
+ Ok(Self { index })
76
+ }
77
+ }
78
+
79
+ pub fn init_structured(rb_candle: RModule) -> Result<()> {
80
+ let ruby = Ruby::get().unwrap();
81
+ let class = rb_candle.define_class("StructuredConstraint", ruby.class_object())?;
82
+
83
+ // New methods using model_id for proper vocabulary loading
84
+ class.define_singleton_method("from_schema_with_model", function!(StructuredConstraint::from_schema_with_model, 2))?;
85
+ class.define_singleton_method("from_regex_with_model", function!(StructuredConstraint::from_regex_with_model, 2))?;
86
+
87
+ // Legacy methods using tokenizer directly (may have byte encoding issues with some models)
88
+ class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
89
+ class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
90
+
91
+ Ok(())
92
+ }