red-candle 1.6.0 → 1.6.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 0e67b3106c236e7f579cf3bdfdf56f0b5ae99e8b33bb9bf35cfe361c19bb0514
-   data.tar.gz: 573ae31fda25ebd4f22e7338a6a94fd48ce9d869959957acbc8d77e1ed83052c
+   metadata.gz: 00c0870d599db76ba556ab880f69219bc32d72698c3570cd308df3923b1332f9
+   data.tar.gz: bffbd02fa1fe11813a304ed16771e6a72ee01f85ccb445eb67c5da6bf696c670
  SHA512:
-   metadata.gz: 382e0ab840efe730184a7baa9f7438f43b10898d94fbea2648d92fbcf30a39ae9dcf48626dbf5a63e64263c2300d88555815c95820c4a6060eaa7b5f370deb9d
-   data.tar.gz: 00b15a30260ce226ef97c4e47a193ed557147377835472df552fde8de5f60c319ac4b8d7a1ef4f6a0408232c46776503b6f4b9a30bb46209df750ce3f47dd7e7
+   metadata.gz: c9ab39f01e4777c3d9906592cf62b8b1d2d03fd04be4ba82d8ddee837678ac477be88b1a0a48fa898e443f7469681e26f25297d4d310c39fb5222137d3b17d87
+   data.tar.gz: 22fe87f0fb64754a0a19bcc632dff72b0cc43838d73797aa991c961b65c9ecb7d9bd7aacbde51baf4063011081f4111a64f53726b51c87aca3e48214304e0119
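The values above are standard SHA-256/SHA-512 hex digests of the gem's `metadata.gz` and `data.tar.gz` archives. A minimal Ruby sketch of verifying one; the helper name and file path are illustrative, not part of red-candle:

```ruby
require "digest"

# Hypothetical helper: recompute a file's SHA-256 and compare it with the
# digest recorded in checksums.yaml.
def checksum_matches?(path, expected_sha256)
  Digest::SHA256.file(path).hexdigest == expected_sha256
end

# The same digest primitive works on in-memory strings:
Digest::SHA256.hexdigest("abc")
# => "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
```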
data/README.md CHANGED
@@ -434,12 +434,26 @@ embedding = model.embedding("Hi there!")
 
  Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
 
+ ### Supported Models
+
+ Red-Candle supports both **BERT** and **XLM-RoBERTa** reranker architectures. The model type is auto-detected from `config.json`.
+
+ | Model | Architecture | Params | Notes |
+ |-------|-------------|--------|-------|
+ | `BAAI/bge-reranker-base` | XLM-RoBERTa | 278M | Recommended — strong quality, multilingual |
+ | `BAAI/bge-reranker-large` | XLM-RoBERTa | 560M | Best quality, higher resource usage |
+ | `BAAI/bge-reranker-v2-m3` | XLM-RoBERTa | 278M | Multilingual, very strong |
+ | `cross-encoder/ms-marco-MiniLM-L-12-v2` | BERT | 33M | Lightweight, English only |
+
  ### Basic Usage
 
  ```ruby
  require 'candle'
 
- # Initialize the reranker with a cross-encoder model
+ # Initialize the reranker (BGE reranker recommended for quality)
+ reranker = Candle::Reranker.from_pretrained("BAAI/bge-reranker-base")
+
+ # Or use the lighter BERT-based model
  reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 
  # Or with custom max_length for truncation (default is 512)
@@ -528,7 +542,7 @@ results = reranker.rerank_with_pooling(query, documents, "cls")
  results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
  ```
 
- Note: The default "pooler" method is recommended as it matches how cross-encoder models are trained. Other pooling methods may produce different ranking results.
+ Note: Pooling methods only apply to BERT-based models. XLM-RoBERTa models (e.g., BGE rerankers) have a built-in classification head and ignore the `pooling_method` parameter. For BERT models, the default "pooler" method is recommended as it matches how cross-encoder models are trained.
 
  ### CUDA Support
 
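For BERT models, the `mean` option averages token embeddings across the sequence dimension. A plain-Ruby toy of that arithmetic (illustrative only; the gem operates on tensors, not arrays):

```ruby
# Toy mean pooling: token_embeddings is [seq_len][hidden]; the result is a
# single [hidden] vector averaged over the sequence dimension.
def mean_pool(token_embeddings)
  seq_len = token_embeddings.length
  hidden  = token_embeddings.first.length
  (0...hidden).map do |h|
    token_embeddings.sum { |tok| tok[h] } / seq_len.to_f
  end
end

mean_pool([[1.0, 2.0], [3.0, 4.0]])
# => [2.0, 3.0]
```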
@@ -546,11 +560,11 @@ Cross-encoder reranking models differ from bi-encoder embedding models:
  - **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
  - **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
 
- The reranker uses a BERT-based architecture that:
- 1. Concatenates the query and document with special tokens: `[CLS] query [SEP] document [SEP]`
- 2. Processes them jointly through BERT layers
- 3. Applies a pooler layer (dense + tanh) to the [CLS] token
- 4. Uses a classifier layer to produce a single relevance score
+ The reranker concatenates the query and document with special tokens and processes them jointly through transformer layers to produce a single relevance score.
+
+ **BERT models** (e.g., MiniLM): Use a pooler layer (dense + tanh) on the [CLS] token, then a classifier layer. Pooling method is configurable (`pooler`, `cls`, `mean`).
+
+ **XLM-RoBERTa models** (e.g., BGE rerankers): Use a built-in classification head that returns logits directly. The `pooling_method` parameter is ignored — the model handles its own pooling internally.
 
  This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
 
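The BERT scoring head (pooler dense + tanh on the [CLS] vector, then a one-output classifier) can be illustrated with toy weights in plain Ruby; the weights, dimensions, and helper name here are entirely made up:

```ruby
# Toy dense layer: weights is [out][in], bias is [out], vec is [in].
def dense(weights, bias, vec)
  weights.map.with_index { |row, i| row.zip(vec).sum { |w, x| w * x } + bias[i] }
end

cls = [0.5, -1.0]                    # pretend [CLS] embedding
pooler_w = [[1.0, 0.0], [0.0, 1.0]]  # identity "pooler" weights for clarity
pooler_b = [0.0, 0.0]
# Pooler: dense + tanh on the [CLS] vector
pooled = dense(pooler_w, pooler_b, cls).map { |x| Math.tanh(x) }
# Classifier: single-output linear layer -> raw relevance logit
clf_w = [[1.0, 1.0]]
clf_b = [0.0]
score = dense(clf_w, clf_b, pooled).first
```

Real rerankers learn these weights during training; only the shape of the computation is the point here.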
@@ -1,18 +1,59 @@
  use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
- use candle_transformers::models::bert::{BertModel, Config};
+ use candle_transformers::models::bert::{BertModel, Config as BertConfig};
+ use candle_transformers::models::xlm_roberta::{
+     XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
+ };
+ use candle_transformers::models::debertav2::{
+     DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
+ };
+ use candle_transformers::models::modernbert::{
+     ModernBert, Config as ModernBertConfig,
+ };
+ use candle_transformers::models::qwen3::{
+     ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
+ };
  use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
  use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
  use hf_hub::{api::sync::Api, Repo, RepoType};
  use tokenizers::{EncodeInput, Tokenizer};
+ use std::cell::RefCell;
  use crate::ruby::{Device, Result};
  use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
+ enum RerankerModel {
+     Bert {
+         model: BertModel,
+         pooler: Linear,
+         classifier: Linear,
+     },
+     XLMRoberta {
+         model: XLMRobertaForSequenceClassification,
+         pad_token_id: u32,
+     },
+     DeBERTa {
+         model: DebertaV2Model,
+         pooler: DebertaV2ContextPooler,
+         classifier: Linear,
+         pad_token_id: u32,
+     },
+     ModernBert {
+         model: ModernBert,
+         head_dense: Linear,
+         head_norm: candle_nn::LayerNorm,
+         classifier: Linear,
+         pad_token_id: u32,
+     },
+     Qwen3 {
+         model: RefCell<Qwen3Model>,
+         yes_token_id: u32,
+         no_token_id: u32,
+     },
+ }
+
  #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
  pub struct Reranker {
-     model: BertModel,
+     model: RerankerModel,
      tokenizer: TokenizerWrapper,
-     pooler: Linear,
-     classifier: Linear,
      device: CoreDevice,
      model_id: String,
  }
@@ -28,7 +69,7 @@ impl Reranker {
      let ruby = Ruby::get().unwrap();
      let runtime_error = ruby.exception_runtime_error();
 
-     let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+     let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
          let api = Api::new()?;
          let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -37,9 +78,10 @@ impl Reranker {
          let tokenizer_filename = repo.get("tokenizer.json")?;
          let weights_filename = repo.get("model.safetensors")?;
 
-         // Load config
-         let config = std::fs::read_to_string(config_filename)?;
-         let config: Config = serde_json::from_str(&config)?;
+         // Read raw config to detect model type
+         let config_str = std::fs::read_to_string(&config_filename)?;
+         let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
+         let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
 
          // Setup tokenizer with padding AND truncation
          let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
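The detection step above reads `config.json`, takes its `model_type` field, and falls back to `"bert"` when the field is absent. An equivalent sketch in plain Ruby (illustrative only, not the gem's code):

```ruby
require "json"

# Mirror of the model-type detection: parse a config.json payload and
# default to "bert" when no model_type field is present.
def detect_model_type(config_json)
  JSON.parse(config_json).fetch("model_type", "bert")
end

detect_model_type('{"model_type": "xlm-roberta"}')  # => "xlm-roberta"
detect_model_type('{"hidden_size": 768}')           # => "bert"
```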
@@ -51,21 +93,71 @@ impl Reranker {
              VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
          };
 
-         // Load BERT model
-         let model = BertModel::load(vb.pp("bert"), &config)?;
-
-         // Load pooler layer (dense + tanh activation)
-         let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
-
-         // Load classifier layer for cross-encoder (single output score)
-         let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+         let model = match model_type {
+             "xlm-roberta" => {
+                 let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id;
+                 let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+                 RerankerModel::XLMRoberta { model, pad_token_id }
+             }
+             "deberta-v2" => {
+                 let config: DebertaV2Config = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
+                 let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
+                 let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
+                 let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
+                 let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
+                 let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
+                 RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
+             }
+             "qwen3" => {
+                 let config: Qwen3Config = serde_json::from_str(&config_str)?;
+                 let model = Qwen3Model::new(&config, vb)?;
+
+                 // Look up "yes" and "no" token IDs from the tokenizer
+                 let yes_token_id: u32 = tokenizer
+                     .encode("yes", false)
+                     .ok()
+                     .and_then(|enc| enc.get_ids().first().copied())
+                     .unwrap_or(9693);
+                 let no_token_id: u32 = tokenizer
+                     .encode("no", false)
+                     .ok()
+                     .and_then(|enc| enc.get_ids().first().copied())
+                     .unwrap_or(2152);
+
+                 RerankerModel::Qwen3 {
+                     model: RefCell::new(model),
+                     yes_token_id,
+                     no_token_id,
+                 }
+             }
+             "modernbert" => {
+                 let config: ModernBertConfig = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id;
+                 let model = ModernBert::load(vb.clone(), &config)?;
+                 // ModernBertHead::load is private, so load the head layers manually
+                 let head_vb = vb.pp("head");
+                 let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
+                 let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
+                 let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                 RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
+             }
+             _ => {
+                 let config: BertConfig = serde_json::from_str(&config_str)?;
+                 let model = BertModel::load(vb.pp("bert"), &config)?;
+                 let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
+                 let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                 RerankerModel::Bert { model, pooler, classifier }
+             }
+         };
 
-         Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
+         Ok((model, TokenizerWrapper::new(tokenizer)))
      })();
 
      match result {
-         Ok((model, tokenizer, pooler, classifier)) => {
-             Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
+         Ok((model, tokenizer)) => {
+             Ok(Self { model, tokenizer, device, model_id })
          }
          Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
      }
@@ -164,49 +256,161 @@ impl Reranker {
          .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
      let token_type_ids = Tensor::new(token_type_ids, &self.device)
          .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
-     let attention_mask = token_ids.ne(0u32)
-         .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
-
-     // Forward pass through BERT
-     let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
-         .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
-
-     // Apply pooling based on the specified method
-     let pooled_embeddings = match pooling_method.as_str() {
-         "pooler" => {
-             // Extract [CLS] token and apply pooler (dense + tanh)
-             let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
-             let pooled = self.pooler.forward(&cls_embeddings)
+
+     // Compute scores based on model type
+     let scores = match &self.model {
+         RerankerModel::Bert { model, pooler, classifier } => {
+             let attention_mask = token_ids.ne(0u32)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // Forward pass through BERT
+             let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // Apply pooling based on the specified method
+             let pooled_embeddings = match pooling_method.as_str() {
+                 "pooler" => {
+                     let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
+                     let pooled = pooler.forward(&cls_embeddings)
+                         .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
+                     pooled.tanh()
+                         .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
+                 },
+                 "cls" => {
+                     self.extract_cls_embeddings(&embeddings)?
+                 },
+                 "mean" => {
+                     let (_batch, seq_len, _hidden) = embeddings.dims3()
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
+                     let sum = embeddings.sum(1)
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
+                     (sum / (seq_len as f64))
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
+                 },
+                 _ => return Err(Error::new(runtime_error,
+                     format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
+             };
+
+             let pooled_embeddings = pooled_embeddings.contiguous()
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
+             let logits = classifier.forward(&pooled_embeddings)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::XLMRoberta { model, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // XLMRobertaForSequenceClassification returns logits directly
+             let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // Forward through DeBERTa encoder
+             let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // Pool and classify
+             let pooled = pooler.forward(&encoder_output)
                  .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
-             pooled.tanh()
-                 .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
-         },
-         "cls" => {
-             // Just use the [CLS] token embeddings directly (no pooler layer)
-             self.extract_cls_embeddings(&embeddings)?
-         },
-         "mean" => {
-             // Mean pooling across all tokens
-             let (_batch, seq_len, _hidden) = embeddings.dims3()
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
-             let sum = embeddings.sum(1)
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
-             (sum / (seq_len as f64))
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
-         },
-         _ => return Err(Error::new(runtime_error,
-             format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
+             let logits = classifier.forward(&pooled)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+             let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to convert attention mask: {}", e)))?;
+
+             // Forward through ModernBERT encoder
+             let encoder_output = model.forward(&token_ids, &attention_mask_f32)
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // CLS pooling, then head (dense + GELU + norm) + classifier
+             let cls = encoder_output.i((.., 0, ..))
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?
+                 .contiguous()
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to make contiguous: {}", e)))?;
+             let hidden = head_dense.forward(&cls)
+                 .map_err(|e| Error::new(runtime_error, format!("Head dense failed: {}", e)))?;
+             let hidden = hidden.gelu_erf()
+                 .map_err(|e| Error::new(runtime_error, format!("GELU activation failed: {}", e)))?;
+             let hidden = head_norm.forward(&hidden)
+                 .map_err(|e| Error::new(runtime_error, format!("Head norm failed: {}", e)))?;
+             let logits = classifier.forward(&hidden)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
+             // Qwen3 reranker: decoder-based yes/no scoring
+             // Process each document individually (causal LM, not batch encoder)
+             let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
+             let mut model = model.borrow_mut();
+
+             for doc in &documents {
+                 // Build the Qwen3 reranker prompt
+                 let prompt = format!(
+                     "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
+                     query, doc
+                 );
+
+                 // Tokenize the prompt
+                 let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
+                     .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
+                 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
+
+                 // Clear KV cache for each document
+                 model.clear_kv_cache();
+
+                 // Forward pass — get logits for the last token position
+                 let input_tensor = Tensor::new(&input_ids[..], &self.device)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?
+                     .unsqueeze(0)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
+
+                 let logits = model.forward(&input_tensor, 0)
+                     .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                 // logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
+                 let logits = logits.flatten_all()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?
+                     .to_dtype(DType::F32)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
+
+                 // Extract yes/no logits and compute score
+                 let yes_logit: f32 = logits.i(*yes_token_id as usize)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to get yes logit: {}", e)))?
+                     .to_scalar()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert yes logit: {}", e)))?;
+                 let no_logit: f32 = logits.i(*no_token_id as usize)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to get no logit: {}", e)))?
+                     .to_scalar()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert no logit: {}", e)))?;
+
+                 // softmax over [yes, no] → P(yes)
+                 let max_logit = yes_logit.max(no_logit);
+                 let yes_exp = (yes_logit - max_logit).exp();
+                 let no_exp = (no_logit - max_logit).exp();
+                 let score = yes_exp / (yes_exp + no_exp);
+
+                 scores_vec.push(score);
+             }
+
+             // Build scores tensor for uniform handling below
+             Tensor::new(scores_vec.as_slice(), &self.device)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create scores tensor: {}", e)))?
+         }
      };
 
-     // Apply classifier to get relevance scores (raw logits)
-     // Ensure tensor is contiguous before linear layer
-     let pooled_embeddings = pooled_embeddings.contiguous()
-         .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
-     let logits = self.classifier.forward(&pooled_embeddings)
-         .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
-     let scores = logits.squeeze(1)
-         .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
-
 
      // Optionally apply sigmoid activation
      let scores = if apply_sigmoid {
          sigmoid(&scores)
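The Qwen3 arm above reduces two logits to a relevance score with a numerically stable two-way softmax (the max is subtracted before exponentiating so the exponentials cannot overflow). The same arithmetic in plain Ruby, for illustration:

```ruby
# P("yes") from the yes/no logits, shifted by the max for stability.
def yes_probability(yes_logit, no_logit)
  m = [yes_logit, no_logit].max
  yes_exp = Math.exp(yes_logit - m)
  no_exp  = Math.exp(no_logit - m)
  yes_exp / (yes_exp + no_exp)
end

yes_probability(0.0, 0.0)  # => 0.5
```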
@@ -1,5 +1,5 @@
  # :nocov:
  module Candle
-   VERSION = "1.6.0"
+   VERSION = "1.6.1"
  end
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: red-candle
  version: !ruby/object:Gem::Version
-   version: 1.6.0
+   version: 1.6.1
  platform: ruby
  authors:
  - Christopher Petersen