red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
|
|
2
|
+
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
|
|
3
|
+
use candle_transformers::models::xlm_roberta::{
|
|
4
|
+
XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
|
|
5
|
+
};
|
|
6
|
+
use candle_transformers::models::debertav2::{
|
|
7
|
+
DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
|
|
8
|
+
};
|
|
9
|
+
use candle_transformers::models::modernbert::{
|
|
10
|
+
ModernBert, Config as ModernBertConfig,
|
|
11
|
+
};
|
|
12
|
+
use candle_transformers::models::qwen3::{
|
|
13
|
+
ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
|
|
14
|
+
};
|
|
15
|
+
use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
|
|
16
|
+
use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
|
|
17
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
18
|
+
use tokenizers::{EncodeInput, Tokenizer};
|
|
19
|
+
use std::cell::RefCell;
|
|
20
|
+
use crate::ruby::{Device, Result};
|
|
21
|
+
use crate::gvl;
|
|
22
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
23
|
+
|
|
24
|
+
/// The concrete cross-encoder / reranker architecture backing a `Reranker`.
///
/// Each variant bundles the candle model with the extra layers and token ids
/// that its scoring path needs.
enum RerankerModel {
    /// BERT encoder followed by a separately loaded pooler and classifier layer.
    Bert {
        // Base encoder producing per-token embeddings.
        model: BertModel,
        // Dense pooler layer; applied to the CLS embedding for the "pooler" pooling method.
        pooler: Linear,
        // Final linear layer mapping the pooled embedding to the relevance logit.
        classifier: Linear,
    },
    /// XLM-RoBERTa sequence classifier; its forward pass yields logits directly.
    XLMRoberta {
        model: XLMRobertaForSequenceClassification,
        // Token id treated as padding when the attention mask is derived from the input ids.
        pad_token_id: u32,
    },
    /// DeBERTa-v2 encoder with its context pooler and a classifier layer.
    DeBERTa {
        model: DebertaV2Model,
        pooler: DebertaV2ContextPooler,
        classifier: Linear,
        // Token id treated as padding when the attention mask is derived from the input ids.
        pad_token_id: u32,
    },
    /// ModernBERT encoder with a manually assembled classification head
    /// (dense -> GELU -> layer norm -> classifier).
    ModernBert {
        model: ModernBert,
        head_dense: Linear,
        head_norm: candle_nn::LayerNorm,
        classifier: Linear,
        // Token id treated as padding when the attention mask is derived from the input ids.
        pad_token_id: u32,
    },
    /// Qwen3 causal LM used as a yes/no judge.
    Qwen3 {
        // Wrapped in a RefCell because scoring mutably borrows the model
        // (KV-cache reset and forward pass) from a `&self` method.
        model: RefCell<Qwen3Model>,
        // Vocabulary id whose logit represents the answer "yes".
        yes_token_id: u32,
        // Vocabulary id whose logit represents the answer "no".
        no_token_id: u32,
    },
}
|
|
53
|
+
|
|
54
|
+
/// Ruby-visible reranker (`Candle::Reranker`) that scores documents against a query.
#[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
pub struct Reranker {
    // Loaded model together with any architecture-specific head layers.
    model: RerankerModel,
    // Tokenizer loaded from the model repository's tokenizer.json.
    tokenizer: TokenizerWrapper,
    // Candle device the weights live on and on which input tensors are created.
    device: CoreDevice,
    // Hub repository id the model was loaded from; reported by `model_id` and `options`.
    model_id: String,
}
|
|
61
|
+
|
|
62
|
+
impl Reranker {
|
|
63
|
+
pub fn new(model_id: String, device: Option<Device>, max_length: Option<usize>) -> Result<Self> {
|
|
64
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
|
65
|
+
let max_length = max_length.unwrap_or(512); // Default to 512
|
|
66
|
+
Self::new_with_core_device(model_id, device, max_length)
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
|
|
70
|
+
let ruby = Ruby::get().unwrap();
|
|
71
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
72
|
+
|
|
73
|
+
let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
|
|
74
|
+
let api = Api::new()?;
|
|
75
|
+
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
|
76
|
+
|
|
77
|
+
// Download model files
|
|
78
|
+
let config_filename = repo.get("config.json")?;
|
|
79
|
+
let tokenizer_filename = repo.get("tokenizer.json")?;
|
|
80
|
+
let weights_filename = repo.get("model.safetensors")?;
|
|
81
|
+
|
|
82
|
+
// Read raw config to detect model type
|
|
83
|
+
let config_str = std::fs::read_to_string(&config_filename)?;
|
|
84
|
+
let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
|
|
85
|
+
let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
|
|
86
|
+
|
|
87
|
+
// Setup tokenizer with padding AND truncation
|
|
88
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
|
89
|
+
let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
|
|
90
|
+
let tokenizer = TokenizerLoader::with_truncation(tokenizer, max_length);
|
|
91
|
+
|
|
92
|
+
// Load model weights
|
|
93
|
+
let vb = unsafe {
|
|
94
|
+
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
|
|
95
|
+
};
|
|
96
|
+
|
|
97
|
+
let model = match model_type {
|
|
98
|
+
"xlm-roberta" => {
|
|
99
|
+
let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
|
|
100
|
+
let pad_token_id = config.pad_token_id;
|
|
101
|
+
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
|
|
102
|
+
RerankerModel::XLMRoberta { model, pad_token_id }
|
|
103
|
+
}
|
|
104
|
+
"deberta-v2" => {
|
|
105
|
+
let config: DebertaV2Config = serde_json::from_str(&config_str)?;
|
|
106
|
+
let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
|
|
107
|
+
let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
|
|
108
|
+
let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
|
|
109
|
+
let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
|
|
110
|
+
let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
|
|
111
|
+
let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
|
|
112
|
+
RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
|
|
113
|
+
}
|
|
114
|
+
"qwen3" => {
|
|
115
|
+
let config: Qwen3Config = serde_json::from_str(&config_str)?;
|
|
116
|
+
let model = Qwen3Model::new(&config, vb)?;
|
|
117
|
+
|
|
118
|
+
// Look up "yes" and "no" token IDs from the tokenizer
|
|
119
|
+
let yes_token_id: u32 = tokenizer
|
|
120
|
+
.encode("yes", false)
|
|
121
|
+
.ok()
|
|
122
|
+
.and_then(|enc| enc.get_ids().first().copied())
|
|
123
|
+
.unwrap_or(9693);
|
|
124
|
+
let no_token_id: u32 = tokenizer
|
|
125
|
+
.encode("no", false)
|
|
126
|
+
.ok()
|
|
127
|
+
.and_then(|enc| enc.get_ids().first().copied())
|
|
128
|
+
.unwrap_or(2152);
|
|
129
|
+
|
|
130
|
+
RerankerModel::Qwen3 {
|
|
131
|
+
model: RefCell::new(model),
|
|
132
|
+
yes_token_id,
|
|
133
|
+
no_token_id,
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
"modernbert" => {
|
|
137
|
+
let config: ModernBertConfig = serde_json::from_str(&config_str)?;
|
|
138
|
+
let pad_token_id = config.pad_token_id;
|
|
139
|
+
let model = ModernBert::load(vb.clone(), &config)?;
|
|
140
|
+
// ModernBertHead::load is private, so load the head layers manually
|
|
141
|
+
let head_vb = vb.pp("head");
|
|
142
|
+
let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
|
|
143
|
+
let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
|
|
144
|
+
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
|
145
|
+
RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
|
|
146
|
+
}
|
|
147
|
+
_ => {
|
|
148
|
+
let config: BertConfig = serde_json::from_str(&config_str)?;
|
|
149
|
+
let model = BertModel::load(vb.pp("bert"), &config)?;
|
|
150
|
+
let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
|
|
151
|
+
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
|
152
|
+
RerankerModel::Bert { model, pooler, classifier }
|
|
153
|
+
}
|
|
154
|
+
};
|
|
155
|
+
|
|
156
|
+
Ok((model, TokenizerWrapper::new(tokenizer)))
|
|
157
|
+
})();
|
|
158
|
+
|
|
159
|
+
match result {
|
|
160
|
+
Ok((model, tokenizer)) => {
|
|
161
|
+
Ok(Self { model, tokenizer, device, model_id })
|
|
162
|
+
}
|
|
163
|
+
Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
/// Extract CLS embeddings from the model output, handling Metal device workarounds
|
|
168
|
+
fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, String> {
|
|
169
|
+
let cls_embeddings = if self.device.is_metal() {
|
|
170
|
+
let (batch_size, seq_len, hidden_size) = embeddings.dims3()
|
|
171
|
+
.map_err(|e| format!("Failed to get dims: {}", e))?;
|
|
172
|
+
let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
|
|
173
|
+
.map_err(|e| format!("Failed to reshape: {}", e))?;
|
|
174
|
+
let mut cls_vecs = Vec::new();
|
|
175
|
+
for i in 0..batch_size {
|
|
176
|
+
let start_idx = i * seq_len;
|
|
177
|
+
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
|
178
|
+
.map_err(|e| format!("Failed to extract CLS: {}", e))?;
|
|
179
|
+
cls_vecs.push(cls_vec);
|
|
180
|
+
}
|
|
181
|
+
Tensor::cat(&cls_vecs, 0)
|
|
182
|
+
.map_err(|e| format!("Failed to cat CLS tokens: {}", e))?
|
|
183
|
+
} else {
|
|
184
|
+
embeddings.i((.., 0))
|
|
185
|
+
.map_err(|e| format!("Failed to extract CLS token: {}", e))?
|
|
186
|
+
};
|
|
187
|
+
cls_embeddings.contiguous()
|
|
188
|
+
.map_err(|e| format!("Failed to make CLS embeddings contiguous: {}", e))
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
|
|
192
|
+
let ruby = Ruby::get().unwrap();
|
|
193
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
194
|
+
|
|
195
|
+
// Create query-document pair for cross-encoder
|
|
196
|
+
let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
|
|
197
|
+
|
|
198
|
+
// Tokenize using the inner tokenizer for detailed info
|
|
199
|
+
let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
|
|
200
|
+
.map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
|
|
201
|
+
|
|
202
|
+
// Get token information
|
|
203
|
+
let token_ids = encoding.get_ids().to_vec();
|
|
204
|
+
let token_type_ids = encoding.get_type_ids().to_vec();
|
|
205
|
+
let attention_mask = encoding.get_attention_mask().to_vec();
|
|
206
|
+
let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
|
|
207
|
+
|
|
208
|
+
// Create result hash
|
|
209
|
+
let result = ruby.hash_new();
|
|
210
|
+
result.aset("token_ids", ruby.ary_from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
|
211
|
+
result.aset("token_type_ids", ruby.ary_from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
|
212
|
+
result.aset("attention_mask", ruby.ary_from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
|
|
213
|
+
result.aset("tokens", ruby.ary_from_vec(tokens))?;
|
|
214
|
+
|
|
215
|
+
Ok(result)
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
|
|
219
|
+
let ruby = Ruby::get().unwrap();
|
|
220
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
221
|
+
let documents: Vec<String> = documents.to_vec()?;
|
|
222
|
+
|
|
223
|
+
// Release the GVL for the entire compute portion (tokenization + inference + scoring).
|
|
224
|
+
// None of this calls Ruby API.
|
|
225
|
+
let ranked_docs = gvl::without_gvl(|| -> std::result::Result<Vec<(String, f32, usize)>, String> {
|
|
226
|
+
self.compute_rerank(&query, &documents, &pooling_method, apply_sigmoid)
|
|
227
|
+
});
|
|
228
|
+
|
|
229
|
+
let ranked_docs = ranked_docs
|
|
230
|
+
.map_err(|e| Error::new(runtime_error, e))?;
|
|
231
|
+
|
|
232
|
+
// Build result array (requires GVL for Ruby object creation)
|
|
233
|
+
let result_array = ruby.ary_new();
|
|
234
|
+
for (doc, score, doc_id) in ranked_docs {
|
|
235
|
+
let tuple = ruby.ary_new();
|
|
236
|
+
tuple.push(doc)?;
|
|
237
|
+
tuple.push(ruby.float_from_f64(score as f64))?;
|
|
238
|
+
tuple.push(doc_id)?;
|
|
239
|
+
result_array.push(tuple)?;
|
|
240
|
+
}
|
|
241
|
+
Ok(result_array)
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
/// Pure compute portion of reranking — no Ruby API calls.
|
|
245
|
+
/// Returns ranked (document, score, original_index) tuples.
|
|
246
|
+
fn compute_rerank(&self, query: &str, documents: &[String], pooling_method: &str, apply_sigmoid: bool) -> std::result::Result<Vec<(String, f32, usize)>, String> {
|
|
247
|
+
// Create query-document pairs for cross-encoder
|
|
248
|
+
let query_and_docs: Vec<EncodeInput> = documents
|
|
249
|
+
.iter()
|
|
250
|
+
.map(|d| (query.to_string(), d.clone()).into())
|
|
251
|
+
.collect();
|
|
252
|
+
|
|
253
|
+
// Tokenize batch
|
|
254
|
+
let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
|
|
255
|
+
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
|
256
|
+
|
|
257
|
+
let token_ids_vec = encodings
|
|
258
|
+
.iter()
|
|
259
|
+
.map(|e| e.get_ids().to_vec())
|
|
260
|
+
.collect::<Vec<_>>();
|
|
261
|
+
|
|
262
|
+
let token_type_ids_vec = encodings
|
|
263
|
+
.iter()
|
|
264
|
+
.map(|e| e.get_type_ids().to_vec())
|
|
265
|
+
.collect::<Vec<_>>();
|
|
266
|
+
|
|
267
|
+
let token_ids = Tensor::new(token_ids_vec, &self.device)
|
|
268
|
+
.map_err(|e| format!("Failed to create tensor: {}", e))?;
|
|
269
|
+
let token_type_ids = Tensor::new(token_type_ids_vec, &self.device)
|
|
270
|
+
.map_err(|e| format!("Failed to create token type ids tensor: {}", e))?;
|
|
271
|
+
|
|
272
|
+
// Compute scores based on model type
|
|
273
|
+
let scores = match &self.model {
|
|
274
|
+
RerankerModel::Bert { model, pooler, classifier } => {
|
|
275
|
+
let attention_mask = token_ids.ne(0u32)
|
|
276
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
277
|
+
|
|
278
|
+
// Forward pass through BERT
|
|
279
|
+
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
|
|
280
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
281
|
+
|
|
282
|
+
// Apply pooling based on the specified method
|
|
283
|
+
let pooled_embeddings = match pooling_method {
|
|
284
|
+
"pooler" => {
|
|
285
|
+
let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
|
|
286
|
+
let pooled = pooler.forward(&cls_embeddings)
|
|
287
|
+
.map_err(|e| format!("Pooler forward failed: {}", e))?;
|
|
288
|
+
pooled.tanh()
|
|
289
|
+
.map_err(|e| format!("Tanh activation failed: {}", e))?
|
|
290
|
+
},
|
|
291
|
+
"cls" => {
|
|
292
|
+
self.extract_cls_embeddings(&embeddings)?
|
|
293
|
+
},
|
|
294
|
+
"mean" => {
|
|
295
|
+
let (_batch, seq_len, _hidden) = embeddings.dims3()
|
|
296
|
+
.map_err(|e| format!("Failed to get tensor dimensions: {}", e))?;
|
|
297
|
+
let sum = embeddings.sum(1)
|
|
298
|
+
.map_err(|e| format!("Failed to sum embeddings: {}", e))?;
|
|
299
|
+
(sum / (seq_len as f64))
|
|
300
|
+
.map_err(|e| format!("Failed to compute mean: {}", e))?
|
|
301
|
+
},
|
|
302
|
+
_ => return Err(
|
|
303
|
+
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method))
|
|
304
|
+
};
|
|
305
|
+
|
|
306
|
+
let pooled_embeddings = pooled_embeddings.contiguous()
|
|
307
|
+
.map_err(|e| format!("Failed to make pooled_embeddings contiguous: {}", e))?;
|
|
308
|
+
let logits = classifier.forward(&pooled_embeddings)
|
|
309
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
310
|
+
logits.squeeze(1)
|
|
311
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
312
|
+
}
|
|
313
|
+
RerankerModel::XLMRoberta { model, pad_token_id } => {
|
|
314
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
315
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
316
|
+
|
|
317
|
+
// XLMRobertaForSequenceClassification returns logits directly
|
|
318
|
+
let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
|
|
319
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
320
|
+
logits.squeeze(1)
|
|
321
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
322
|
+
}
|
|
323
|
+
RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
|
|
324
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
325
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
326
|
+
|
|
327
|
+
// Forward through DeBERTa encoder
|
|
328
|
+
let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
|
|
329
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
330
|
+
|
|
331
|
+
// Pool and classify
|
|
332
|
+
let pooled = pooler.forward(&encoder_output)
|
|
333
|
+
.map_err(|e| format!("Pooler forward failed: {}", e))?;
|
|
334
|
+
let logits = classifier.forward(&pooled)
|
|
335
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
336
|
+
logits.squeeze(1)
|
|
337
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
338
|
+
}
|
|
339
|
+
RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
|
|
340
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
341
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
342
|
+
let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
|
|
343
|
+
.map_err(|e| format!("Failed to convert attention mask: {}", e))?;
|
|
344
|
+
|
|
345
|
+
// Forward through ModernBERT encoder
|
|
346
|
+
let encoder_output = model.forward(&token_ids, &attention_mask_f32)
|
|
347
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
348
|
+
|
|
349
|
+
// CLS pooling, then head (dense + GELU + norm) + classifier
|
|
350
|
+
let cls = encoder_output.i((.., 0, ..))
|
|
351
|
+
.map_err(|e| format!("Failed to extract CLS: {}", e))?
|
|
352
|
+
.contiguous()
|
|
353
|
+
.map_err(|e| format!("Failed to make contiguous: {}", e))?;
|
|
354
|
+
let hidden = head_dense.forward(&cls)
|
|
355
|
+
.map_err(|e| format!("Head dense failed: {}", e))?;
|
|
356
|
+
let hidden = hidden.gelu_erf()
|
|
357
|
+
.map_err(|e| format!("GELU activation failed: {}", e))?;
|
|
358
|
+
let hidden = head_norm.forward(&hidden)
|
|
359
|
+
.map_err(|e| format!("Head norm failed: {}", e))?;
|
|
360
|
+
let logits = classifier.forward(&hidden)
|
|
361
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
362
|
+
logits.squeeze(1)
|
|
363
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
364
|
+
}
|
|
365
|
+
RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
|
|
366
|
+
// Qwen3 reranker: decoder-based yes/no scoring
|
|
367
|
+
// Process each document individually (causal LM, not batch encoder)
|
|
368
|
+
let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
|
|
369
|
+
let mut model = model.borrow_mut();
|
|
370
|
+
|
|
371
|
+
for doc in documents.iter() {
|
|
372
|
+
// Build the Qwen3 reranker prompt
|
|
373
|
+
let prompt = format!(
|
|
374
|
+
"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
|
375
|
+
query, doc
|
|
376
|
+
);
|
|
377
|
+
|
|
378
|
+
// Tokenize the prompt
|
|
379
|
+
let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
|
|
380
|
+
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
|
381
|
+
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
382
|
+
|
|
383
|
+
// Clear KV cache for each document
|
|
384
|
+
model.clear_kv_cache();
|
|
385
|
+
|
|
386
|
+
// Forward pass — get logits for the last token position
|
|
387
|
+
let input_tensor = Tensor::new(&input_ids[..], &self.device)
|
|
388
|
+
.map_err(|e| format!("Failed to create tensor: {}", e))?
|
|
389
|
+
.unsqueeze(0)
|
|
390
|
+
.map_err(|e| format!("Failed to unsqueeze: {}", e))?;
|
|
391
|
+
|
|
392
|
+
let logits = model.forward(&input_tensor, 0)
|
|
393
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
394
|
+
|
|
395
|
+
// logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
|
|
396
|
+
let logits = logits.flatten_all()
|
|
397
|
+
.map_err(|e| format!("Failed to flatten: {}", e))?
|
|
398
|
+
.to_dtype(DType::F32)
|
|
399
|
+
.map_err(|e| format!("Failed to convert dtype: {}", e))?;
|
|
400
|
+
|
|
401
|
+
// Extract yes/no logits and compute score
|
|
402
|
+
let yes_logit: f32 = logits.i(*yes_token_id as usize)
|
|
403
|
+
.map_err(|e| format!("Failed to get yes logit: {}", e))?
|
|
404
|
+
.to_scalar()
|
|
405
|
+
.map_err(|e| format!("Failed to convert yes logit: {}", e))?;
|
|
406
|
+
let no_logit: f32 = logits.i(*no_token_id as usize)
|
|
407
|
+
.map_err(|e| format!("Failed to get no logit: {}", e))?
|
|
408
|
+
.to_scalar()
|
|
409
|
+
.map_err(|e| format!("Failed to convert no logit: {}", e))?;
|
|
410
|
+
|
|
411
|
+
// softmax over [yes, no] → P(yes)
|
|
412
|
+
let max_logit = yes_logit.max(no_logit);
|
|
413
|
+
let yes_exp = (yes_logit - max_logit).exp();
|
|
414
|
+
let no_exp = (no_logit - max_logit).exp();
|
|
415
|
+
let score = yes_exp / (yes_exp + no_exp);
|
|
416
|
+
|
|
417
|
+
scores_vec.push(score);
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
// Build scores tensor for uniform handling below
|
|
421
|
+
Tensor::new(scores_vec.as_slice(), &self.device)
|
|
422
|
+
.map_err(|e| format!("Failed to create scores tensor: {}", e))?
|
|
423
|
+
}
|
|
424
|
+
};
|
|
425
|
+
|
|
426
|
+
// Optionally apply sigmoid activation
|
|
427
|
+
let scores = if apply_sigmoid {
|
|
428
|
+
sigmoid(&scores)
|
|
429
|
+
.map_err(|e| format!("Sigmoid failed: {}", e))?
|
|
430
|
+
} else {
|
|
431
|
+
scores
|
|
432
|
+
};
|
|
433
|
+
|
|
434
|
+
let scores_vec: Vec<f32> = scores.to_vec1()
|
|
435
|
+
.map_err(|e| format!("Failed to convert scores to vec: {}", e))?;
|
|
436
|
+
|
|
437
|
+
// Create tuples with document, score, and original index
|
|
438
|
+
let mut ranked_docs: Vec<(String, f32, usize)> = documents
|
|
439
|
+
.iter()
|
|
440
|
+
.cloned()
|
|
441
|
+
.zip(scores_vec)
|
|
442
|
+
.enumerate()
|
|
443
|
+
.map(|(idx, (doc, score))| (doc, score, idx))
|
|
444
|
+
.collect();
|
|
445
|
+
|
|
446
|
+
// Sort documents by relevance score (descending)
|
|
447
|
+
ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
448
|
+
|
|
449
|
+
Ok(ranked_docs)
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
/// Get the tokenizer used by this model
|
|
453
|
+
pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
|
|
454
|
+
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
/// Get the model_id
|
|
458
|
+
pub fn model_id(&self) -> String {
|
|
459
|
+
self.model_id.clone()
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
/// Get the device
|
|
463
|
+
pub fn device(&self) -> Device {
|
|
464
|
+
Device::from_device(&self.device)
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
/// Get all options as a hash
|
|
468
|
+
pub fn options(&self) -> std::result::Result<RHash, Error> {
|
|
469
|
+
let ruby = Ruby::get().unwrap();
|
|
470
|
+
let hash = ruby.hash_new();
|
|
471
|
+
hash.aset("model_id", self.model_id.clone())?;
|
|
472
|
+
hash.aset("device", self.device().__str__())?;
|
|
473
|
+
Ok(hash)
|
|
474
|
+
}
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
|
478
|
+
let ruby = Ruby::get().unwrap();
|
|
479
|
+
let c_reranker = rb_candle.define_class("Reranker", ruby.class_object())?;
|
|
480
|
+
c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
|
|
481
|
+
c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
|
|
482
|
+
c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
|
|
483
|
+
c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
|
|
484
|
+
c_reranker.define_method("model_id", method!(Reranker::model_id, 0))?;
|
|
485
|
+
c_reranker.define_method("device", method!(Reranker::device, 0))?;
|
|
486
|
+
c_reranker.define_method("options", method!(Reranker::options, 0))?;
|
|
487
|
+
Ok(())
|
|
488
|
+
}
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
use magnus::{Error, Module, RModule, function, Object, Ruby};
|
|
2
|
+
use std::sync::Arc;
|
|
3
|
+
|
|
4
|
+
use crate::structured::{SchemaProcessor, VocabularyAdapter, Index, Vocabulary};
|
|
5
|
+
use crate::ruby::{Result, tokenizer::Tokenizer};
|
|
6
|
+
|
|
7
|
+
/// Ruby wrapper for structured generation constraints
///
/// Exposed to Ruby as `Candle::StructuredConstraint`. Cloning is cheap because the
/// compiled index is shared behind an `Arc`.
// NOTE(review): the `mark` flag is set although the struct holds no Ruby values here —
// confirm whether it is needed.
#[derive(Clone, Debug)]
#[magnus::wrap(class = "Candle::StructuredConstraint", mark, free_immediately)]
pub struct StructuredConstraint {
    // Index produced by `SchemaProcessor` from a JSON schema or regex and a vocabulary.
    pub(crate) index: Arc<Index>,
}
|
|
13
|
+
|
|
14
|
+
impl StructuredConstraint {
|
|
15
|
+
/// Create a constraint from a JSON schema using a model ID
|
|
16
|
+
/// This uses Vocabulary::from_pretrained which handles tokenizer byte encoding correctly
|
|
17
|
+
pub fn from_schema_with_model(schema: String, model_id: String) -> Result<Self> {
|
|
18
|
+
// Use tokio runtime for async vocabulary loading
|
|
19
|
+
let rt = tokio::runtime::Runtime::new()
|
|
20
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
|
|
21
|
+
|
|
22
|
+
let vocabulary = rt.block_on(async {
|
|
23
|
+
Vocabulary::from_pretrained(&model_id, None)
|
|
24
|
+
})
|
|
25
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
|
|
26
|
+
|
|
27
|
+
let processor = SchemaProcessor::new();
|
|
28
|
+
let index = processor.process_schema(&schema, &vocabulary)
|
|
29
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
|
|
30
|
+
|
|
31
|
+
Ok(Self { index })
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
/// Create a constraint from a regex pattern using a model ID
|
|
35
|
+
pub fn from_regex_with_model(pattern: String, model_id: String) -> Result<Self> {
|
|
36
|
+
// Use tokio runtime for async vocabulary loading
|
|
37
|
+
let rt = tokio::runtime::Runtime::new()
|
|
38
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
|
|
39
|
+
|
|
40
|
+
let vocabulary = rt.block_on(async {
|
|
41
|
+
Vocabulary::from_pretrained(&model_id, None)
|
|
42
|
+
})
|
|
43
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
|
|
44
|
+
|
|
45
|
+
let processor = SchemaProcessor::new();
|
|
46
|
+
let index = processor.process_regex(&pattern, &vocabulary)
|
|
47
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
|
|
48
|
+
|
|
49
|
+
Ok(Self { index })
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
/// Create a constraint from a JSON schema (legacy method using tokenizer directly)
|
|
53
|
+
/// Note: This may not handle all tokenizer byte encodings correctly
|
|
54
|
+
pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
|
|
55
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
|
56
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
|
57
|
+
|
|
58
|
+
let processor = SchemaProcessor::new();
|
|
59
|
+
let index = processor.process_schema(&schema, &vocabulary)
|
|
60
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
|
|
61
|
+
|
|
62
|
+
Ok(Self { index })
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/// Create a constraint from a regex pattern (legacy method using tokenizer directly)
|
|
66
|
+
/// Note: This may not handle all tokenizer byte encodings correctly
|
|
67
|
+
pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
|
|
68
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
|
|
69
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
|
|
70
|
+
|
|
71
|
+
let processor = SchemaProcessor::new();
|
|
72
|
+
let index = processor.process_regex(&pattern, &vocabulary)
|
|
73
|
+
.map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
|
|
74
|
+
|
|
75
|
+
Ok(Self { index })
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
pub fn init_structured(rb_candle: RModule) -> Result<()> {
|
|
80
|
+
let ruby = Ruby::get().unwrap();
|
|
81
|
+
let class = rb_candle.define_class("StructuredConstraint", ruby.class_object())?;
|
|
82
|
+
|
|
83
|
+
// New methods using model_id for proper vocabulary loading
|
|
84
|
+
class.define_singleton_method("from_schema_with_model", function!(StructuredConstraint::from_schema_with_model, 2))?;
|
|
85
|
+
class.define_singleton_method("from_regex_with_model", function!(StructuredConstraint::from_regex_with_model, 2))?;
|
|
86
|
+
|
|
87
|
+
// Legacy methods using tokenizer directly (may have byte encoding issues with some models)
|
|
88
|
+
class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
|
|
89
|
+
class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
|
|
90
|
+
|
|
91
|
+
Ok(())
|
|
92
|
+
}
|