red-candle 1.6.0 → 1.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +21 -7
- data/ext/candle/src/ruby/reranker.rs +263 -59
- data/lib/candle/version.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 00c0870d599db76ba556ab880f69219bc32d72698c3570cd308df3923b1332f9
|
|
4
|
+
data.tar.gz: bffbd02fa1fe11813a304ed16771e6a72ee01f85ccb445eb67c5da6bf696c670
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: c9ab39f01e4777c3d9906592cf62b8b1d2d03fd04be4ba82d8ddee837678ac477be88b1a0a48fa898e443f7469681e26f25297d4d310c39fb5222137d3b17d87
|
|
7
|
+
data.tar.gz: 22fe87f0fb64754a0a19bcc632dff72b0cc43838d73797aa991c961b65c9ecb7d9bd7aacbde51baf4063011081f4111a64f53726b51c87aca3e48214304e0119
|
data/README.md
CHANGED
|
@@ -434,12 +434,26 @@ embedding = model.embedding("Hi there!")
|
|
|
434
434
|
|
|
435
435
|
Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
|
|
436
436
|
|
|
437
|
+
### Supported Models
|
|
438
|
+
|
|
439
|
+
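Red-Candle auto-detects the reranker architecture from the `model_type` field in `config.json`. The models listed below use **BERT** or **XLM-RoBERTa** architectures. DeBERTa-v2, ModernBERT, and Qwen3 reranker configs are also recognized. Any other `model_type`, or a config without one, is loaded as BERT.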
|
|
440
|
+
|
|
441
|
+
| Model | Architecture | Params | Notes |
|
|
442
|
+
|-------|-------------|--------|-------|
|
|
443
|
+
| `BAAI/bge-reranker-base` | XLM-RoBERTa | 278M | Recommended — strong quality, multilingual |
|
|
444
|
+
| `BAAI/bge-reranker-large` | XLM-RoBERTa | 560M | Best quality, higher resource usage |
|
|
445
|
+
| `BAAI/bge-reranker-v2-m3` | XLM-RoBERTa | 568M | Multilingual, very strong |
|
|
446
|
+
| `cross-encoder/ms-marco-MiniLM-L-12-v2` | BERT | 33M | Lightweight, English only |
|
|
447
|
+
|
|
437
448
|
### Basic Usage
|
|
438
449
|
|
|
439
450
|
```ruby
|
|
440
451
|
require 'candle'
|
|
441
452
|
|
|
442
|
-
# Initialize the reranker
|
|
453
|
+
# Initialize the reranker (BGE reranker recommended for quality)
|
|
454
|
+
reranker = Candle::Reranker.from_pretrained("BAAI/bge-reranker-base")
|
|
455
|
+
|
|
456
|
+
# Or use the lighter BERT-based model
|
|
443
457
|
reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
|
|
444
458
|
|
|
445
459
|
# Or with custom max_length for truncation (default is 512)
|
|
@@ -528,7 +542,7 @@ results = reranker.rerank_with_pooling(query, documents, "cls")
|
|
|
528
542
|
results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
|
|
529
543
|
```
|
|
530
544
|
|
|
531
|
-
Note:
|
|
545
|
+
Note: Pooling methods only apply to BERT-based models. All other architectures (XLM-RoBERTa, DeBERTa-v2, ModernBERT, Qwen3) ignore the `pooling_method` parameter. For example, XLM-RoBERTa models such as the BGE rerankers have a built-in classification head. For BERT models, the default `"pooler"` method is recommended because it matches how cross-encoder models are trained.
|
|
532
546
|
|
|
533
547
|
### CUDA Support
|
|
534
548
|
|
|
@@ -546,11 +560,11 @@ Cross-encoder reranking models differ from bi-encoder embedding models:
|
|
|
546
560
|
- **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
|
|
547
561
|
- **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
|
|
548
562
|
|
|
549
|
-
The reranker
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
563
|
+
The reranker concatenates the query and document with special tokens and processes them jointly through transformer layers to produce a single relevance score.
|
|
564
|
+
|
|
565
|
+
**BERT models** (e.g., MiniLM): Use a pooler layer (dense + tanh) on the [CLS] token, then a classifier layer. Pooling method is configurable (`pooler`, `cls`, `mean`).
|
|
566
|
+
|
|
567
|
+
**XLM-RoBERTa models** (e.g., BGE rerankers): Use a built-in classification head that returns logits directly. The `pooling_method` parameter is ignored — the model handles its own pooling internally.
|
|
554
568
|
|
|
555
569
|
This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
|
|
556
570
|
|
|
@@ -1,18 +1,59 @@
|
|
|
1
1
|
use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
|
|
2
|
-
use candle_transformers::models::bert::{BertModel, Config};
|
|
2
|
+
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
|
|
3
|
+
use candle_transformers::models::xlm_roberta::{
|
|
4
|
+
XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
|
|
5
|
+
};
|
|
6
|
+
use candle_transformers::models::debertav2::{
|
|
7
|
+
DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
|
|
8
|
+
};
|
|
9
|
+
use candle_transformers::models::modernbert::{
|
|
10
|
+
ModernBert, Config as ModernBertConfig,
|
|
11
|
+
};
|
|
12
|
+
use candle_transformers::models::qwen3::{
|
|
13
|
+
ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
|
|
14
|
+
};
|
|
3
15
|
use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
|
|
4
16
|
use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
|
|
5
17
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
6
18
|
use tokenizers::{EncodeInput, Tokenizer};
|
|
19
|
+
use std::cell::RefCell;
|
|
7
20
|
use crate::ruby::{Device, Result};
|
|
8
21
|
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
9
22
|
|
|
23
|
+
enum RerankerModel {
|
|
24
|
+
Bert {
|
|
25
|
+
model: BertModel,
|
|
26
|
+
pooler: Linear,
|
|
27
|
+
classifier: Linear,
|
|
28
|
+
},
|
|
29
|
+
XLMRoberta {
|
|
30
|
+
model: XLMRobertaForSequenceClassification,
|
|
31
|
+
pad_token_id: u32,
|
|
32
|
+
},
|
|
33
|
+
DeBERTa {
|
|
34
|
+
model: DebertaV2Model,
|
|
35
|
+
pooler: DebertaV2ContextPooler,
|
|
36
|
+
classifier: Linear,
|
|
37
|
+
pad_token_id: u32,
|
|
38
|
+
},
|
|
39
|
+
ModernBert {
|
|
40
|
+
model: ModernBert,
|
|
41
|
+
head_dense: Linear,
|
|
42
|
+
head_norm: candle_nn::LayerNorm,
|
|
43
|
+
classifier: Linear,
|
|
44
|
+
pad_token_id: u32,
|
|
45
|
+
},
|
|
46
|
+
Qwen3 {
|
|
47
|
+
model: RefCell<Qwen3Model>,
|
|
48
|
+
yes_token_id: u32,
|
|
49
|
+
no_token_id: u32,
|
|
50
|
+
},
|
|
51
|
+
}
|
|
52
|
+
|
|
10
53
|
#[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
|
|
11
54
|
pub struct Reranker {
|
|
12
|
-
model:
|
|
55
|
+
model: RerankerModel,
|
|
13
56
|
tokenizer: TokenizerWrapper,
|
|
14
|
-
pooler: Linear,
|
|
15
|
-
classifier: Linear,
|
|
16
57
|
device: CoreDevice,
|
|
17
58
|
model_id: String,
|
|
18
59
|
}
|
|
@@ -28,7 +69,7 @@ impl Reranker {
|
|
|
28
69
|
let ruby = Ruby::get().unwrap();
|
|
29
70
|
let runtime_error = ruby.exception_runtime_error();
|
|
30
71
|
|
|
31
|
-
let result = (|| -> std::result::Result<(
|
|
72
|
+
let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
|
|
32
73
|
let api = Api::new()?;
|
|
33
74
|
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
|
34
75
|
|
|
@@ -37,9 +78,10 @@ impl Reranker {
|
|
|
37
78
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
|
38
79
|
let weights_filename = repo.get("model.safetensors")?;
|
|
39
80
|
|
|
40
|
-
//
|
|
41
|
-
let
|
|
42
|
-
let
|
|
81
|
+
// Read raw config to detect model type
|
|
82
|
+
let config_str = std::fs::read_to_string(&config_filename)?;
|
|
83
|
+
let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
|
|
84
|
+
let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
|
|
43
85
|
|
|
44
86
|
// Setup tokenizer with padding AND truncation
|
|
45
87
|
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
|
@@ -51,21 +93,71 @@ impl Reranker {
|
|
|
51
93
|
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
|
|
52
94
|
};
|
|
53
95
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
96
|
+
let model = match model_type {
|
|
97
|
+
"xlm-roberta" => {
|
|
98
|
+
let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
|
|
99
|
+
let pad_token_id = config.pad_token_id;
|
|
100
|
+
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
|
|
101
|
+
RerankerModel::XLMRoberta { model, pad_token_id }
|
|
102
|
+
}
|
|
103
|
+
"deberta-v2" => {
|
|
104
|
+
let config: DebertaV2Config = serde_json::from_str(&config_str)?;
|
|
105
|
+
let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
|
|
106
|
+
let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
|
|
107
|
+
let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
|
|
108
|
+
let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
|
|
109
|
+
let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
|
|
110
|
+
let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
|
|
111
|
+
RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
|
|
112
|
+
}
|
|
113
|
+
"qwen3" => {
|
|
114
|
+
let config: Qwen3Config = serde_json::from_str(&config_str)?;
|
|
115
|
+
let model = Qwen3Model::new(&config, vb)?;
|
|
116
|
+
|
|
117
|
+
// Look up "yes" and "no" token IDs from the tokenizer
|
|
118
|
+
let yes_token_id: u32 = tokenizer
|
|
119
|
+
.encode("yes", false)
|
|
120
|
+
.ok()
|
|
121
|
+
.and_then(|enc| enc.get_ids().first().copied())
|
|
122
|
+
.unwrap_or(9693);
|
|
123
|
+
let no_token_id: u32 = tokenizer
|
|
124
|
+
.encode("no", false)
|
|
125
|
+
.ok()
|
|
126
|
+
.and_then(|enc| enc.get_ids().first().copied())
|
|
127
|
+
.unwrap_or(2152);
|
|
128
|
+
|
|
129
|
+
RerankerModel::Qwen3 {
|
|
130
|
+
model: RefCell::new(model),
|
|
131
|
+
yes_token_id,
|
|
132
|
+
no_token_id,
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
"modernbert" => {
|
|
136
|
+
let config: ModernBertConfig = serde_json::from_str(&config_str)?;
|
|
137
|
+
let pad_token_id = config.pad_token_id;
|
|
138
|
+
let model = ModernBert::load(vb.clone(), &config)?;
|
|
139
|
+
// ModernBertHead::load is private, so load the head layers manually
|
|
140
|
+
let head_vb = vb.pp("head");
|
|
141
|
+
let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
|
|
142
|
+
let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
|
|
143
|
+
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
|
144
|
+
RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
|
|
145
|
+
}
|
|
146
|
+
_ => {
|
|
147
|
+
let config: BertConfig = serde_json::from_str(&config_str)?;
|
|
148
|
+
let model = BertModel::load(vb.pp("bert"), &config)?;
|
|
149
|
+
let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
|
|
150
|
+
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
|
151
|
+
RerankerModel::Bert { model, pooler, classifier }
|
|
152
|
+
}
|
|
153
|
+
};
|
|
62
154
|
|
|
63
|
-
Ok((model, TokenizerWrapper::new(tokenizer)
|
|
155
|
+
Ok((model, TokenizerWrapper::new(tokenizer)))
|
|
64
156
|
})();
|
|
65
157
|
|
|
66
158
|
match result {
|
|
67
|
-
Ok((model, tokenizer
|
|
68
|
-
Ok(Self { model, tokenizer,
|
|
159
|
+
Ok((model, tokenizer)) => {
|
|
160
|
+
Ok(Self { model, tokenizer, device, model_id })
|
|
69
161
|
}
|
|
70
162
|
Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
|
|
71
163
|
}
|
|
@@ -164,49 +256,161 @@ impl Reranker {
|
|
|
164
256
|
.map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
|
|
165
257
|
let token_type_ids = Tensor::new(token_type_ids, &self.device)
|
|
166
258
|
.map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
let
|
|
259
|
+
|
|
260
|
+
// Compute scores based on model type
|
|
261
|
+
let scores = match &self.model {
|
|
262
|
+
RerankerModel::Bert { model, pooler, classifier } => {
|
|
263
|
+
let attention_mask = token_ids.ne(0u32)
|
|
264
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
|
|
265
|
+
|
|
266
|
+
// Forward pass through BERT
|
|
267
|
+
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
|
|
268
|
+
.map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
|
|
269
|
+
|
|
270
|
+
// Apply pooling based on the specified method
|
|
271
|
+
let pooled_embeddings = match pooling_method.as_str() {
|
|
272
|
+
"pooler" => {
|
|
273
|
+
let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
|
|
274
|
+
let pooled = pooler.forward(&cls_embeddings)
|
|
275
|
+
.map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
|
|
276
|
+
pooled.tanh()
|
|
277
|
+
.map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
|
|
278
|
+
},
|
|
279
|
+
"cls" => {
|
|
280
|
+
self.extract_cls_embeddings(&embeddings)?
|
|
281
|
+
},
|
|
282
|
+
"mean" => {
|
|
283
|
+
let (_batch, seq_len, _hidden) = embeddings.dims3()
|
|
284
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
|
|
285
|
+
let sum = embeddings.sum(1)
|
|
286
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
|
|
287
|
+
(sum / (seq_len as f64))
|
|
288
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
|
|
289
|
+
},
|
|
290
|
+
_ => return Err(Error::new(runtime_error,
|
|
291
|
+
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
|
|
292
|
+
};
|
|
293
|
+
|
|
294
|
+
let pooled_embeddings = pooled_embeddings.contiguous()
|
|
295
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
|
|
296
|
+
let logits = classifier.forward(&pooled_embeddings)
|
|
297
|
+
.map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
|
|
298
|
+
logits.squeeze(1)
|
|
299
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
|
|
300
|
+
}
|
|
301
|
+
RerankerModel::XLMRoberta { model, pad_token_id } => {
|
|
302
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
303
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
|
|
304
|
+
|
|
305
|
+
// XLMRobertaForSequenceClassification returns logits directly
|
|
306
|
+
let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
|
|
307
|
+
.map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
|
|
308
|
+
logits.squeeze(1)
|
|
309
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
|
|
310
|
+
}
|
|
311
|
+
RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
|
|
312
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
313
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
|
|
314
|
+
|
|
315
|
+
// Forward through DeBERTa encoder
|
|
316
|
+
let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
|
|
317
|
+
.map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
|
|
318
|
+
|
|
319
|
+
// Pool and classify
|
|
320
|
+
let pooled = pooler.forward(&encoder_output)
|
|
180
321
|
.map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
|
|
181
|
-
|
|
182
|
-
.map_err(|e| Error::new(runtime_error, format!("
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
322
|
+
let logits = classifier.forward(&pooled)
|
|
323
|
+
.map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
|
|
324
|
+
logits.squeeze(1)
|
|
325
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
|
|
326
|
+
}
|
|
327
|
+
RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
|
|
328
|
+
let attention_mask = token_ids.ne(*pad_token_id)
|
|
329
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
|
|
330
|
+
let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
|
|
331
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to convert attention mask: {}", e)))?;
|
|
332
|
+
|
|
333
|
+
// Forward through ModernBERT encoder
|
|
334
|
+
let encoder_output = model.forward(&token_ids, &attention_mask_f32)
|
|
335
|
+
.map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
|
|
336
|
+
|
|
337
|
+
// CLS pooling, then head (dense + GELU + norm) + classifier
|
|
338
|
+
let cls = encoder_output.i((.., 0, ..))
|
|
339
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?
|
|
340
|
+
.contiguous()
|
|
341
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to make contiguous: {}", e)))?;
|
|
342
|
+
let hidden = head_dense.forward(&cls)
|
|
343
|
+
.map_err(|e| Error::new(runtime_error, format!("Head dense failed: {}", e)))?;
|
|
344
|
+
let hidden = hidden.gelu_erf()
|
|
345
|
+
.map_err(|e| Error::new(runtime_error, format!("GELU activation failed: {}", e)))?;
|
|
346
|
+
let hidden = head_norm.forward(&hidden)
|
|
347
|
+
.map_err(|e| Error::new(runtime_error, format!("Head norm failed: {}", e)))?;
|
|
348
|
+
let logits = classifier.forward(&hidden)
|
|
349
|
+
.map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
|
|
350
|
+
logits.squeeze(1)
|
|
351
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
|
|
352
|
+
}
|
|
353
|
+
RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
|
|
354
|
+
// Qwen3 reranker: decoder-based yes/no scoring
|
|
355
|
+
// Process each document individually (causal LM, not batch encoder)
|
|
356
|
+
let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
|
|
357
|
+
let mut model = model.borrow_mut();
|
|
358
|
+
|
|
359
|
+
for doc in &documents {
|
|
360
|
+
// Build the Qwen3 reranker prompt
|
|
361
|
+
let prompt = format!(
|
|
362
|
+
"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
|
363
|
+
query, doc
|
|
364
|
+
);
|
|
365
|
+
|
|
366
|
+
// Tokenize the prompt
|
|
367
|
+
let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
|
|
368
|
+
.map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
|
|
369
|
+
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
370
|
+
|
|
371
|
+
// Clear KV cache for each document
|
|
372
|
+
model.clear_kv_cache();
|
|
373
|
+
|
|
374
|
+
// Forward pass — get logits for the last token position
|
|
375
|
+
let input_tensor = Tensor::new(&input_ids[..], &self.device)
|
|
376
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?
|
|
377
|
+
.unsqueeze(0)
|
|
378
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
|
|
379
|
+
|
|
380
|
+
let logits = model.forward(&input_tensor, 0)
|
|
381
|
+
.map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
|
|
382
|
+
|
|
383
|
+
// logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
|
|
384
|
+
let logits = logits.flatten_all()
|
|
385
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?
|
|
386
|
+
.to_dtype(DType::F32)
|
|
387
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
|
|
388
|
+
|
|
389
|
+
// Extract yes/no logits and compute score
|
|
390
|
+
let yes_logit: f32 = logits.i(*yes_token_id as usize)
|
|
391
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to get yes logit: {}", e)))?
|
|
392
|
+
.to_scalar()
|
|
393
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to convert yes logit: {}", e)))?;
|
|
394
|
+
let no_logit: f32 = logits.i(*no_token_id as usize)
|
|
395
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to get no logit: {}", e)))?
|
|
396
|
+
.to_scalar()
|
|
397
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to convert no logit: {}", e)))?;
|
|
398
|
+
|
|
399
|
+
// softmax over [yes, no] → P(yes)
|
|
400
|
+
let max_logit = yes_logit.max(no_logit);
|
|
401
|
+
let yes_exp = (yes_logit - max_logit).exp();
|
|
402
|
+
let no_exp = (no_logit - max_logit).exp();
|
|
403
|
+
let score = yes_exp / (yes_exp + no_exp);
|
|
404
|
+
|
|
405
|
+
scores_vec.push(score);
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
// Build scores tensor for uniform handling below
|
|
409
|
+
Tensor::new(scores_vec.as_slice(), &self.device)
|
|
410
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create scores tensor: {}", e)))?
|
|
411
|
+
}
|
|
199
412
|
};
|
|
200
413
|
|
|
201
|
-
// Apply classifier to get relevance scores (raw logits)
|
|
202
|
-
// Ensure tensor is contiguous before linear layer
|
|
203
|
-
let pooled_embeddings = pooled_embeddings.contiguous()
|
|
204
|
-
.map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
|
|
205
|
-
let logits = self.classifier.forward(&pooled_embeddings)
|
|
206
|
-
.map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
|
|
207
|
-
let scores = logits.squeeze(1)
|
|
208
|
-
.map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
|
|
209
|
-
|
|
210
414
|
// Optionally apply sigmoid activation
|
|
211
415
|
let scores = if apply_sigmoid {
|
|
212
416
|
sigmoid(&scores)
|
data/lib/candle/version.rb
CHANGED