red-candle 1.7.0 → 1.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +1 -0
- data/ext/candle/src/ruby/embedding_model.rs +16 -14
- data/ext/candle/src/ruby/llm.rs +9 -3
- data/ext/candle/src/ruby/ner.rs +8 -4
- data/ext/candle/src/ruby/reranker.rs +89 -85
- data/lib/candle/version.rb +1 -1
- metadata +2 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 2a4feba5be08ae2cb7a70b4e61671a5cd434c65a1fa29d98b77841187caf2893
|
|
4
|
+
data.tar.gz: 4c453e5cbf48837fa3b68af52f1f409ddaf039fe3f0b49e13a3c66693614f18f
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 9c5b31367d0efd7d99ec866ab44615d94cbbd98cd401edb3d8393b892fdac0d1d3657a2903c02f60b20b2f8e3649a3eeac29d9f6348f3a0a51ec1fd97ad60757
|
|
7
|
+
data.tar.gz: 60da5b3c14e101b44e53321276c17e3a80074fb56e5e927469908ffb1deb17443b215c317405c526bb79a2d22df9e8837a32a3f7d486851b89e51a9cdc2ccf77
|
|
data/ext/candle/src/gvl.rs
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
/// GVL (Global VM Lock) release support for Ruby.
|
|
2
|
+
///
|
|
3
|
+
/// Ruby's GVL prevents other Ruby threads from running while native code
|
|
4
|
+
/// executes. For long-running operations (LLM inference, reranking, embedding),
|
|
5
|
+
/// we release the GVL so other threads (TUI render loops, HTTP servers, etc.)
|
|
6
|
+
/// can run concurrently.
|
|
7
|
+
///
|
|
8
|
+
/// SAFETY: Code running without the GVL must NOT call any Ruby API.
|
|
9
|
+
|
|
10
|
+
use std::os::raw::c_void;
|
|
11
|
+
|
|
12
|
+
type UnblockFn = unsafe extern "C" fn(*mut c_void);
|
|
13
|
+
|
|
14
|
+
extern "C" {
|
|
15
|
+
fn rb_thread_call_without_gvl(
|
|
16
|
+
func: unsafe extern "C" fn(*mut c_void) -> *mut c_void,
|
|
17
|
+
data1: *mut c_void,
|
|
18
|
+
ubf: Option<UnblockFn>,
|
|
19
|
+
data2: *mut c_void,
|
|
20
|
+
) -> *mut c_void;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
/// Run a closure without the GVL. The closure must not call any Ruby API.
|
|
24
|
+
pub fn without_gvl<F, R>(f: F) -> R
|
|
25
|
+
where
|
|
26
|
+
F: FnOnce() -> R,
|
|
27
|
+
{
|
|
28
|
+
struct CallData<F, R> {
|
|
29
|
+
func: Option<F>,
|
|
30
|
+
result: Option<R>,
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
unsafe extern "C" fn call_func<F, R>(data: *mut c_void) -> *mut c_void
|
|
34
|
+
where
|
|
35
|
+
F: FnOnce() -> R,
|
|
36
|
+
{
|
|
37
|
+
let data = &mut *(data as *mut CallData<F, R>);
|
|
38
|
+
let func = data.func.take().unwrap();
|
|
39
|
+
data.result = Some(func());
|
|
40
|
+
std::ptr::null_mut()
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
let mut data = CallData {
|
|
44
|
+
func: Some(f),
|
|
45
|
+
result: None,
|
|
46
|
+
};
|
|
47
|
+
|
|
48
|
+
unsafe {
|
|
49
|
+
rb_thread_call_without_gvl(
|
|
50
|
+
call_func::<F, R>,
|
|
51
|
+
&mut data as *mut _ as *mut c_void,
|
|
52
|
+
None,
|
|
53
|
+
std::ptr::null_mut(),
|
|
54
|
+
);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
data.result.unwrap()
|
|
58
|
+
}
|
data/ext/candle/src/lib.rs
CHANGED
data/ext/candle/src/ruby/embedding_model.rs
CHANGED
|
@@ -104,14 +104,15 @@ impl EmbeddingModel {
|
|
|
104
104
|
/// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
|
|
105
105
|
pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
|
|
106
106
|
let ruby = Ruby::get().unwrap();
|
|
107
|
-
match &self.0.model {
|
|
108
|
-
Some(model) => {
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
107
|
+
match (&self.0.model, &self.0.tokenizer) {
|
|
108
|
+
(Some(model), Some(tokenizer)) => {
|
|
109
|
+
let result = crate::gvl::without_gvl(|| {
|
|
110
|
+
self.compute_embedding(input, model, tokenizer, &pooling_method)
|
|
111
|
+
});
|
|
112
|
+
Ok(Tensor(result?))
|
|
113
113
|
}
|
|
114
|
-
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
|
|
114
|
+
(None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
|
|
115
|
+
(_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
|
|
115
116
|
}
|
|
116
117
|
}
|
|
117
118
|
|
|
@@ -119,14 +120,15 @@ impl EmbeddingModel {
|
|
|
119
120
|
/// &RETURNS&: Tensor
|
|
120
121
|
pub fn embeddings(&self, input: String) -> Result<Tensor> {
|
|
121
122
|
let ruby = Ruby::get().unwrap();
|
|
122
|
-
match &self.0.model {
|
|
123
|
-
Some(model) => {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
123
|
+
match (&self.0.model, &self.0.tokenizer) {
|
|
124
|
+
(Some(model), Some(tokenizer)) => {
|
|
125
|
+
let result = crate::gvl::without_gvl(|| {
|
|
126
|
+
self.compute_embeddings(input, model, tokenizer)
|
|
127
|
+
});
|
|
128
|
+
Ok(Tensor(result?))
|
|
128
129
|
}
|
|
129
|
-
None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
|
|
130
|
+
(None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
|
|
131
|
+
(_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
|
|
130
132
|
}
|
|
131
133
|
}
|
|
132
134
|
|
data/ext/candle/src/ruby/llm.rs
CHANGED
|
@@ -5,6 +5,7 @@ use std::sync::Arc;
|
|
|
5
5
|
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, qwen3::Qwen3 as RustQwen3, phi::Phi as RustPhi, granite::Granite as RustGranite, granitemoehybrid::GraniteMoeHybrid as RustGraniteMoeHybrid, glm4::Glm4 as RustGlm4, QuantizedGGUF as RustQuantizedGGUF};
|
|
6
6
|
use crate::ruby::{Result, Device};
|
|
7
7
|
use crate::ruby::structured::StructuredConstraint;
|
|
8
|
+
use crate::gvl;
|
|
8
9
|
|
|
9
10
|
// Use an enum to handle different model types instead of trait objects
|
|
10
11
|
enum ModelType {
|
|
@@ -422,7 +423,7 @@ impl LLM {
|
|
|
422
423
|
})
|
|
423
424
|
}
|
|
424
425
|
|
|
425
|
-
/// Generate text from a prompt
|
|
426
|
+
/// Generate text from a prompt (releases GVL during inference)
|
|
426
427
|
pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
|
|
427
428
|
let ruby = Ruby::get().unwrap();
|
|
428
429
|
let config = config
|
|
@@ -435,8 +436,13 @@ impl LLM {
|
|
|
435
436
|
};
|
|
436
437
|
let mut model_ref = model.borrow_mut();
|
|
437
438
|
|
|
438
|
-
|
|
439
|
-
|
|
439
|
+
// Release the GVL during inference so other Ruby threads can run
|
|
440
|
+
// (e.g., TUI render loops, HTTP servers, etc.)
|
|
441
|
+
let result = gvl::without_gvl(|| {
|
|
442
|
+
model_ref.generate(&prompt, &config)
|
|
443
|
+
});
|
|
444
|
+
|
|
445
|
+
result.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
|
|
440
446
|
}
|
|
441
447
|
|
|
442
448
|
/// Generate text with streaming output
|
data/ext/candle/src/ruby/ner.rs
CHANGED
|
@@ -173,8 +173,10 @@ impl NER {
|
|
|
173
173
|
let ruby = Ruby::get().unwrap();
|
|
174
174
|
let threshold = confidence_threshold.unwrap_or(0.9) as f32;
|
|
175
175
|
|
|
176
|
-
//
|
|
177
|
-
let (encoding, probs_vec) =
|
|
176
|
+
// Release GVL during tokenization + model forward pass
|
|
177
|
+
let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
|
|
178
|
+
self.tokenize_and_predict(&text)
|
|
179
|
+
})?;
|
|
178
180
|
|
|
179
181
|
let tokens = encoding.get_tokens();
|
|
180
182
|
let offsets = encoding.get_offsets();
|
|
@@ -208,8 +210,10 @@ impl NER {
|
|
|
208
210
|
/// Get token-level predictions with labels and confidence scores
|
|
209
211
|
pub fn predict_tokens(&self, text: String) -> Result<RArray> {
|
|
210
212
|
let ruby = Ruby::get().unwrap();
|
|
211
|
-
//
|
|
212
|
-
let (encoding, probs_vec) =
|
|
213
|
+
// Release GVL during tokenization + model forward pass
|
|
214
|
+
let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
|
|
215
|
+
self.tokenize_and_predict(&text)
|
|
216
|
+
})?;
|
|
213
217
|
|
|
214
218
|
let tokens = encoding.get_tokens();
|
|
215
219
|
|
|
data/ext/candle/src/ruby/reranker.rs
CHANGED
|
@@ -18,6 +18,7 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
|
18
18
|
use tokenizers::{EncodeInput, Tokenizer};
|
|
19
19
|
use std::cell::RefCell;
|
|
20
20
|
use crate::ruby::{Device, Result};
|
|
21
|
+
use crate::gvl;
|
|
21
22
|
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
22
23
|
|
|
23
24
|
enum RerankerModel {
|
|
@@ -164,39 +165,27 @@ impl Reranker {
|
|
|
164
165
|
}
|
|
165
166
|
|
|
166
167
|
/// Extract CLS embeddings from the model output, handling Metal device workarounds
|
|
167
|
-
fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor,
|
|
168
|
-
let ruby = Ruby::get().unwrap();
|
|
169
|
-
let runtime_error = ruby.exception_runtime_error();
|
|
170
|
-
|
|
168
|
+
fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, String> {
|
|
171
169
|
let cls_embeddings = if self.device.is_metal() {
|
|
172
|
-
// Metal has issues with tensor indexing, use a different approach
|
|
173
170
|
let (batch_size, seq_len, hidden_size) = embeddings.dims3()
|
|
174
|
-
.map_err(|e|
|
|
175
|
-
|
|
176
|
-
// Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
|
|
171
|
+
.map_err(|e| format!("Failed to get dims: {}", e))?;
|
|
177
172
|
let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
|
|
178
|
-
.map_err(|e|
|
|
179
|
-
|
|
180
|
-
// Extract CLS tokens (first token of each sequence)
|
|
173
|
+
.map_err(|e| format!("Failed to reshape: {}", e))?;
|
|
181
174
|
let mut cls_vecs = Vec::new();
|
|
182
175
|
for i in 0..batch_size {
|
|
183
176
|
let start_idx = i * seq_len;
|
|
184
177
|
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
|
185
|
-
.map_err(|e|
|
|
178
|
+
.map_err(|e| format!("Failed to extract CLS: {}", e))?;
|
|
186
179
|
cls_vecs.push(cls_vec);
|
|
187
180
|
}
|
|
188
|
-
|
|
189
|
-
// Stack the CLS vectors
|
|
190
181
|
Tensor::cat(&cls_vecs, 0)
|
|
191
|
-
.map_err(|e|
|
|
182
|
+
.map_err(|e| format!("Failed to cat CLS tokens: {}", e))?
|
|
192
183
|
} else {
|
|
193
184
|
embeddings.i((.., 0))
|
|
194
|
-
.map_err(|e|
|
|
185
|
+
.map_err(|e| format!("Failed to extract CLS token: {}", e))?
|
|
195
186
|
};
|
|
196
|
-
|
|
197
|
-
// Ensure tensor is contiguous for downstream operations
|
|
198
187
|
cls_embeddings.contiguous()
|
|
199
|
-
.map_err(|e|
|
|
188
|
+
.map_err(|e| format!("Failed to make CLS embeddings contiguous: {}", e))
|
|
200
189
|
}
|
|
201
190
|
|
|
202
191
|
pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
|
|
@@ -231,124 +220,147 @@ impl Reranker {
|
|
|
231
220
|
let runtime_error = ruby.exception_runtime_error();
|
|
232
221
|
let documents: Vec<String> = documents.to_vec()?;
|
|
233
222
|
|
|
223
|
+
// Release the GVL for the entire compute portion (tokenization + inference + scoring).
|
|
224
|
+
// None of this calls Ruby API.
|
|
225
|
+
let ranked_docs = gvl::without_gvl(|| -> std::result::Result<Vec<(String, f32, usize)>, String> {
|
|
226
|
+
self.compute_rerank(&query, &documents, &pooling_method, apply_sigmoid)
|
|
227
|
+
});
|
|
228
|
+
|
|
229
|
+
let ranked_docs = ranked_docs
|
|
230
|
+
.map_err(|e| Error::new(runtime_error, e))?;
|
|
231
|
+
|
|
232
|
+
// Build result array (requires GVL for Ruby object creation)
|
|
233
|
+
let result_array = ruby.ary_new();
|
|
234
|
+
for (doc, score, doc_id) in ranked_docs {
|
|
235
|
+
let tuple = ruby.ary_new();
|
|
236
|
+
tuple.push(doc)?;
|
|
237
|
+
tuple.push(ruby.float_from_f64(score as f64))?;
|
|
238
|
+
tuple.push(doc_id)?;
|
|
239
|
+
result_array.push(tuple)?;
|
|
240
|
+
}
|
|
241
|
+
Ok(result_array)
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
/// Pure compute portion of reranking — no Ruby API calls.
|
|
245
|
+
/// Returns ranked (document, score, original_index) tuples.
|
|
246
|
+
fn compute_rerank(&self, query: &str, documents: &[String], pooling_method: &str, apply_sigmoid: bool) -> std::result::Result<Vec<(String, f32, usize)>, String> {
|
|
234
247
|
// Create query-document pairs for cross-encoder
|
|
235
248
|
let query_and_docs: Vec<EncodeInput> = documents
|
|
236
249
|
.iter()
|
|
237
|
-
.map(|d| (query.
|
|
250
|
+
.map(|d| (query.to_string(), d.clone()).into())
|
|
238
251
|
.collect();
|
|
239
252
|
|
|
240
|
-
// Tokenize batch
|
|
253
|
+
// Tokenize batch
|
|
241
254
|
let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
|
|
242
|
-
.map_err(|e|
|
|
255
|
+
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
|
243
256
|
|
|
244
|
-
|
|
245
|
-
let token_ids = encodings
|
|
257
|
+
let token_ids_vec = encodings
|
|
246
258
|
.iter()
|
|
247
259
|
.map(|e| e.get_ids().to_vec())
|
|
248
260
|
.collect::<Vec<_>>();
|
|
249
261
|
|
|
250
|
-
let
|
|
262
|
+
let token_type_ids_vec = encodings
|
|
251
263
|
.iter()
|
|
252
264
|
.map(|e| e.get_type_ids().to_vec())
|
|
253
265
|
.collect::<Vec<_>>();
|
|
254
266
|
|
|
255
|
-
let token_ids = Tensor::new(
|
|
256
|
-
.map_err(|e|
|
|
257
|
-
let token_type_ids = Tensor::new(
|
|
258
|
-
.map_err(|e|
|
|
267
|
+
let token_ids = Tensor::new(token_ids_vec, &self.device)
|
|
268
|
+
.map_err(|e| format!("Failed to create tensor: {}", e))?;
|
|
269
|
+
let token_type_ids = Tensor::new(token_type_ids_vec, &self.device)
|
|
270
|
+
.map_err(|e| format!("Failed to create token type ids tensor: {}", e))?;
|
|
259
271
|
|
|
260
272
|
// Compute scores based on model type
|
|
261
273
|
let scores = match &self.model {
|
|
262
274
|
RerankerModel::Bert { model, pooler, classifier } => {
|
|
263
275
|
let attention_mask = token_ids.ne(0u32)
|
|
264
|
-
.map_err(|e|
|
|
276
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
265
277
|
|
|
266
278
|
// Forward pass through BERT
|
|
267
279
|
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
|
|
268
|
-
.map_err(|e|
|
|
280
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
269
281
|
|
|
270
282
|
// Apply pooling based on the specified method
|
|
271
|
-
let pooled_embeddings = match pooling_method
|
|
283
|
+
let pooled_embeddings = match pooling_method {
|
|
272
284
|
"pooler" => {
|
|
273
285
|
let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
|
|
274
286
|
let pooled = pooler.forward(&cls_embeddings)
|
|
275
|
-
.map_err(|e|
|
|
287
|
+
.map_err(|e| format!("Pooler forward failed: {}", e))?;
|
|
276
288
|
pooled.tanh()
|
|
277
|
-
.map_err(|e|
|
|
289
|
+
.map_err(|e| format!("Tanh activation failed: {}", e))?
|
|
278
290
|
},
|
|
279
291
|
"cls" => {
|
|
280
292
|
self.extract_cls_embeddings(&embeddings)?
|
|
281
293
|
},
|
|
282
294
|
"mean" => {
|
|
283
295
|
let (_batch, seq_len, _hidden) = embeddings.dims3()
|
|
284
|
-
.map_err(|e|
|
|
296
|
+
.map_err(|e| format!("Failed to get tensor dimensions: {}", e))?;
|
|
285
297
|
let sum = embeddings.sum(1)
|
|
286
|
-
.map_err(|e|
|
|
298
|
+
.map_err(|e| format!("Failed to sum embeddings: {}", e))?;
|
|
287
299
|
(sum / (seq_len as f64))
|
|
288
|
-
.map_err(|e|
|
|
300
|
+
.map_err(|e| format!("Failed to compute mean: {}", e))?
|
|
289
301
|
},
|
|
290
|
-
_ => return Err(
|
|
291
|
-
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method))
|
|
302
|
+
_ => return Err(
|
|
303
|
+
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method))
|
|
292
304
|
};
|
|
293
305
|
|
|
294
306
|
let pooled_embeddings = pooled_embeddings.contiguous()
|
|
295
|
-
.map_err(|e|
|
|
307
|
+
.map_err(|e| format!("Failed to make pooled_embeddings contiguous: {}", e))?;
|
|
296
308
|
let logits = classifier.forward(&pooled_embeddings)
|
|
297
|
-
.map_err(|e|
|
|
309
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
298
310
|
logits.squeeze(1)
|
|
299
|
-
.map_err(|e|
|
|
311
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
300
312
|
}
|
|
301
313
|
RerankerModel::XLMRoberta { model, pad_token_id } => {
|
|
302
314
|
let attention_mask = token_ids.ne(*pad_token_id)
|
|
303
|
-
.map_err(|e|
|
|
315
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
304
316
|
|
|
305
317
|
// XLMRobertaForSequenceClassification returns logits directly
|
|
306
318
|
let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
|
|
307
|
-
.map_err(|e|
|
|
319
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
308
320
|
logits.squeeze(1)
|
|
309
|
-
.map_err(|e|
|
|
321
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
310
322
|
}
|
|
311
323
|
RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
|
|
312
324
|
let attention_mask = token_ids.ne(*pad_token_id)
|
|
313
|
-
.map_err(|e|
|
|
325
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
314
326
|
|
|
315
327
|
// Forward through DeBERTa encoder
|
|
316
328
|
let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
|
|
317
|
-
.map_err(|e|
|
|
329
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
318
330
|
|
|
319
331
|
// Pool and classify
|
|
320
332
|
let pooled = pooler.forward(&encoder_output)
|
|
321
|
-
.map_err(|e|
|
|
333
|
+
.map_err(|e| format!("Pooler forward failed: {}", e))?;
|
|
322
334
|
let logits = classifier.forward(&pooled)
|
|
323
|
-
.map_err(|e|
|
|
335
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
324
336
|
logits.squeeze(1)
|
|
325
|
-
.map_err(|e|
|
|
337
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
326
338
|
}
|
|
327
339
|
RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
|
|
328
340
|
let attention_mask = token_ids.ne(*pad_token_id)
|
|
329
|
-
.map_err(|e|
|
|
341
|
+
.map_err(|e| format!("Failed to create attention mask: {}", e))?;
|
|
330
342
|
let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
|
|
331
|
-
.map_err(|e|
|
|
343
|
+
.map_err(|e| format!("Failed to convert attention mask: {}", e))?;
|
|
332
344
|
|
|
333
345
|
// Forward through ModernBERT encoder
|
|
334
346
|
let encoder_output = model.forward(&token_ids, &attention_mask_f32)
|
|
335
|
-
.map_err(|e|
|
|
347
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
336
348
|
|
|
337
349
|
// CLS pooling, then head (dense + GELU + norm) + classifier
|
|
338
350
|
let cls = encoder_output.i((.., 0, ..))
|
|
339
|
-
.map_err(|e|
|
|
351
|
+
.map_err(|e| format!("Failed to extract CLS: {}", e))?
|
|
340
352
|
.contiguous()
|
|
341
|
-
.map_err(|e|
|
|
353
|
+
.map_err(|e| format!("Failed to make contiguous: {}", e))?;
|
|
342
354
|
let hidden = head_dense.forward(&cls)
|
|
343
|
-
.map_err(|e|
|
|
355
|
+
.map_err(|e| format!("Head dense failed: {}", e))?;
|
|
344
356
|
let hidden = hidden.gelu_erf()
|
|
345
|
-
.map_err(|e|
|
|
357
|
+
.map_err(|e| format!("GELU activation failed: {}", e))?;
|
|
346
358
|
let hidden = head_norm.forward(&hidden)
|
|
347
|
-
.map_err(|e|
|
|
359
|
+
.map_err(|e| format!("Head norm failed: {}", e))?;
|
|
348
360
|
let logits = classifier.forward(&hidden)
|
|
349
|
-
.map_err(|e|
|
|
361
|
+
.map_err(|e| format!("Classifier forward failed: {}", e))?;
|
|
350
362
|
logits.squeeze(1)
|
|
351
|
-
.map_err(|e|
|
|
363
|
+
.map_err(|e| format!("Failed to squeeze tensor: {}", e))?
|
|
352
364
|
}
|
|
353
365
|
RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
|
|
354
366
|
// Qwen3 reranker: decoder-based yes/no scoring
|
|
@@ -356,7 +368,7 @@ impl Reranker {
|
|
|
356
368
|
let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
|
|
357
369
|
let mut model = model.borrow_mut();
|
|
358
370
|
|
|
359
|
-
for doc in
|
|
371
|
+
for doc in documents.iter() {
|
|
360
372
|
// Build the Qwen3 reranker prompt
|
|
361
373
|
let prompt = format!(
|
|
362
374
|
"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
|
@@ -365,7 +377,7 @@ impl Reranker {
|
|
|
365
377
|
|
|
366
378
|
// Tokenize the prompt
|
|
367
379
|
let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
|
|
368
|
-
.map_err(|e|
|
|
380
|
+
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
|
369
381
|
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
370
382
|
|
|
371
383
|
// Clear KV cache for each document
|
|
@@ -373,28 +385,28 @@ impl Reranker {
|
|
|
373
385
|
|
|
374
386
|
// Forward pass — get logits for the last token position
|
|
375
387
|
let input_tensor = Tensor::new(&input_ids[..], &self.device)
|
|
376
|
-
.map_err(|e|
|
|
388
|
+
.map_err(|e| format!("Failed to create tensor: {}", e))?
|
|
377
389
|
.unsqueeze(0)
|
|
378
|
-
.map_err(|e|
|
|
390
|
+
.map_err(|e| format!("Failed to unsqueeze: {}", e))?;
|
|
379
391
|
|
|
380
392
|
let logits = model.forward(&input_tensor, 0)
|
|
381
|
-
.map_err(|e|
|
|
393
|
+
.map_err(|e| format!("Model forward pass failed: {}", e))?;
|
|
382
394
|
|
|
383
395
|
// logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
|
|
384
396
|
let logits = logits.flatten_all()
|
|
385
|
-
.map_err(|e|
|
|
397
|
+
.map_err(|e| format!("Failed to flatten: {}", e))?
|
|
386
398
|
.to_dtype(DType::F32)
|
|
387
|
-
.map_err(|e|
|
|
399
|
+
.map_err(|e| format!("Failed to convert dtype: {}", e))?;
|
|
388
400
|
|
|
389
401
|
// Extract yes/no logits and compute score
|
|
390
402
|
let yes_logit: f32 = logits.i(*yes_token_id as usize)
|
|
391
|
-
.map_err(|e|
|
|
403
|
+
.map_err(|e| format!("Failed to get yes logit: {}", e))?
|
|
392
404
|
.to_scalar()
|
|
393
|
-
.map_err(|e|
|
|
405
|
+
.map_err(|e| format!("Failed to convert yes logit: {}", e))?;
|
|
394
406
|
let no_logit: f32 = logits.i(*no_token_id as usize)
|
|
395
|
-
.map_err(|e|
|
|
407
|
+
.map_err(|e| format!("Failed to get no logit: {}", e))?
|
|
396
408
|
.to_scalar()
|
|
397
|
-
.map_err(|e|
|
|
409
|
+
.map_err(|e| format!("Failed to convert no logit: {}", e))?;
|
|
398
410
|
|
|
399
411
|
// softmax over [yes, no] → P(yes)
|
|
400
412
|
let max_logit = yes_logit.max(no_logit);
|
|
@@ -407,24 +419,25 @@ impl Reranker {
|
|
|
407
419
|
|
|
408
420
|
// Build scores tensor for uniform handling below
|
|
409
421
|
Tensor::new(scores_vec.as_slice(), &self.device)
|
|
410
|
-
.map_err(|e|
|
|
422
|
+
.map_err(|e| format!("Failed to create scores tensor: {}", e))?
|
|
411
423
|
}
|
|
412
424
|
};
|
|
413
425
|
|
|
414
426
|
// Optionally apply sigmoid activation
|
|
415
427
|
let scores = if apply_sigmoid {
|
|
416
428
|
sigmoid(&scores)
|
|
417
|
-
.map_err(|e|
|
|
429
|
+
.map_err(|e| format!("Sigmoid failed: {}", e))?
|
|
418
430
|
} else {
|
|
419
431
|
scores
|
|
420
432
|
};
|
|
421
433
|
|
|
422
434
|
let scores_vec: Vec<f32> = scores.to_vec1()
|
|
423
|
-
.map_err(|e|
|
|
435
|
+
.map_err(|e| format!("Failed to convert scores to vec: {}", e))?;
|
|
424
436
|
|
|
425
437
|
// Create tuples with document, score, and original index
|
|
426
438
|
let mut ranked_docs: Vec<(String, f32, usize)> = documents
|
|
427
|
-
.
|
|
439
|
+
.iter()
|
|
440
|
+
.cloned()
|
|
428
441
|
.zip(scores_vec)
|
|
429
442
|
.enumerate()
|
|
430
443
|
.map(|(idx, (doc, score))| (doc, score, idx))
|
|
@@ -433,16 +446,7 @@ impl Reranker {
|
|
|
433
446
|
// Sort documents by relevance score (descending)
|
|
434
447
|
ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
435
448
|
|
|
436
|
-
|
|
437
|
-
let result_array = ruby.ary_new();
|
|
438
|
-
for (doc, score, doc_id) in ranked_docs {
|
|
439
|
-
let tuple = ruby.ary_new();
|
|
440
|
-
tuple.push(doc)?;
|
|
441
|
-
tuple.push(ruby.float_from_f64(score as f64))?;
|
|
442
|
-
tuple.push(doc_id)?;
|
|
443
|
-
result_array.push(tuple)?;
|
|
444
|
-
}
|
|
445
|
-
Ok(result_array)
|
|
449
|
+
Ok(ranked_docs)
|
|
446
450
|
}
|
|
447
451
|
|
|
448
452
|
/// Get the tokenizer used by this model
|
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: red-candle
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 1.7.0
|
|
4
|
+
version: 1.7.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Christopher Petersen
|
|
@@ -215,6 +215,7 @@ files:
|
|
|
215
215
|
- ext/candle/build.rs
|
|
216
216
|
- ext/candle/extconf.rb
|
|
217
217
|
- ext/candle/rustfmt.toml
|
|
218
|
+
- ext/candle/src/gvl.rs
|
|
218
219
|
- ext/candle/src/lib.rs
|
|
219
220
|
- ext/candle/src/llm/constrained_generation_test.rs
|
|
220
221
|
- ext/candle/src/llm/gemma.rs
|