red-candle 1.2.1 → 1.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +75 -0
- data/ext/candle/src/ruby/reranker.rs +40 -38
- data/lib/candle/reranker.rb +6 -4
- data/lib/candle/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d101e47090a800dddabfa1ee41971ca3f214f9a422e9408c0970b24c03aab8c7
|
4
|
+
data.tar.gz: 7136f0922548fd4bead034930b3c6d1553c498c23fc17802a6c94f7f4a0c7ce5
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 594900e14d64fc335e53e2599a6e2d81ba4fdb1a1f3d5680536a47538bdf05d12332e959b33b67bfa04cca0dd321835a737d2ddf9cdd33c42b937370dde7a2a1
|
7
|
+
data.tar.gz: 4005d21ddccce9182d8600af5b3fd3f34ee9fabdf1b8235eb37eb1be7ede74b644c23d4ea322f6c0e05808e3e98e0c2066251111f538e9e8a934e7ec069ae4c0
|
data/README.md
CHANGED
@@ -363,6 +363,12 @@ require 'candle'
|
|
363
363
|
# Initialize the reranker with a cross-encoder model
|
364
364
|
reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
|
365
365
|
|
366
|
+
# Or with custom max_length for truncation (default is 512)
|
367
|
+
reranker = Candle::Reranker.from_pretrained(
|
368
|
+
"cross-encoder/ms-marco-MiniLM-L-12-v2",
|
369
|
+
max_length: 256 # Faster processing with less context
|
370
|
+
)
|
371
|
+
|
366
372
|
# Define your query and candidate documents
|
367
373
|
query = "How many people live in London?"
|
368
374
|
documents = [
|
@@ -469,6 +475,75 @@ The reranker uses a BERT-based architecture that:
|
|
469
475
|
|
470
476
|
This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
|
471
477
|
|
478
|
+
### Performance Considerations
|
479
|
+
|
480
|
+
**Important**: The Reranker automatically truncates each tokenized query-document pair so that per-document cost stays bounded. The default maximum is 512 tokens; you can change it with the `max_length` option.
|
481
|
+
|
482
|
+
#### Configurable Truncation
|
483
|
+
|
484
|
+
You can adjust the `max_length` parameter to balance performance and context:
|
485
|
+
|
486
|
+
```ruby
|
487
|
+
# Default: 512 tokens (maximum context, ~300ms per doc on CPU)
|
488
|
+
reranker = Candle::Reranker.from_pretrained(model_id)
|
489
|
+
|
490
|
+
# Faster: 256 tokens (~60% less time, ~120ms per doc on CPU)
|
491
|
+
reranker = Candle::Reranker.from_pretrained(model_id, max_length: 256)
|
492
|
+
|
493
|
+
# Fastest: 128 tokens (~80% less time, ~60ms per doc on CPU)
|
494
|
+
reranker = Candle::Reranker.from_pretrained(model_id, max_length: 128)
|
495
|
+
```
|
496
|
+
|
497
|
+
Choose based on your needs:
|
498
|
+
- **512 tokens**: Maximum context for complex queries (default)
|
499
|
+
- **256 tokens**: Good balance of speed and context
|
500
|
+
- **128 tokens**: Fast processing for simple matching
|
501
|
+
|
502
|
+
#### Performance Guidelines
|
503
|
+
|
504
|
+
1. **Document Length**: With the default `max_length` of 512, documents longer than roughly 300-400 words will be truncated
|
505
|
+
- Only the first `max_length` tokens (512 by default, roughly 300-400 words) are used, and this limit covers the query and document together
|
506
|
+
- Consider splitting very long documents into chunks if full coverage is needed
|
507
|
+
|
508
|
+
2. **Batch Size**: Process multiple documents in one call for efficiency
|
509
|
+
```ruby
|
510
|
+
# Good: Single call with multiple documents
|
511
|
+
results = reranker.rerank(query, documents)
|
512
|
+
|
513
|
+
# Less efficient: Multiple calls
|
514
|
+
documents.map { |doc| reranker.rerank(query, [doc]) }
|
515
|
+
```
|
516
|
+
|
517
|
+
3. **Expected Performance**:
|
518
|
+
- **CPU**: ~0.3-0.5s per query-document pair
|
519
|
+
- **GPU (Metal/CUDA)**: ~0.05-0.1s per query-document pair
|
520
|
+
- Truncation caps the cost per pair, so very long documents take no longer than ones that just reach `max_length`
|
521
|
+
|
522
|
+
4. **Chunking Strategy** for long documents:
|
523
|
+
```ruby
|
524
|
+
def rerank_long_document(reranker, query, long_text, chunk_size: 300)
|
525
|
+
# Split into overlapping chunks (consecutive chunks share 50 words, so chunk_size must be greater than 50)
|
526
|
+
words = long_text.split
|
527
|
+
chunks = []
|
528
|
+
|
529
|
+
(0...words.length).step(chunk_size - 50) do |i|
|
530
|
+
chunk = words[i...(i + chunk_size)].join(" ")
|
531
|
+
chunks << chunk
|
532
|
+
end
|
533
|
+
|
534
|
+
# Rerank chunks
|
535
|
+
results = reranker.rerank(query, chunks)
|
536
|
+
|
537
|
+
# Return best chunk
|
538
|
+
results.max_by { |r| r[:score] }
|
539
|
+
end
|
540
|
+
```
|
541
|
+
|
542
|
+
5. **Memory Usage**:
|
543
|
+
- Model size: ~125MB
|
544
|
+
- Each `rerank` call tokenizes and scores all of its documents as a single batch, so memory use grows with the number of documents per call
|
545
|
+
- If you have a very large number of documents, split them across several `rerank` calls to limit peak memory
|
546
|
+
|
472
547
|
## Tokenizer
|
473
548
|
|
474
549
|
Red-Candle provides direct access to tokenizers for text preprocessing and analysis. This is useful for understanding how models process text, debugging issues, and building custom NLP pipelines.
|
@@ -18,46 +18,48 @@ pub struct Reranker {
|
|
18
18
|
}
|
19
19
|
|
20
20
|
impl Reranker {
|
21
|
-
pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
|
21
|
+
pub fn new(model_id: String, device: Option<Device>, max_length: Option<usize>) -> Result<Self> {
|
22
22
|
let device = device.unwrap_or(Device::best()).as_device()?;
|
23
|
-
|
23
|
+
let max_length = max_length.unwrap_or(512); // Default to 512
|
24
|
+
Self::new_with_core_device(model_id, device, max_length)
|
24
25
|
}
|
25
|
-
|
26
|
-
fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
|
26
|
+
|
27
|
+
fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
|
27
28
|
let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
|
28
29
|
let api = Api::new()?;
|
29
30
|
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
30
|
-
|
31
|
+
|
31
32
|
// Download model files
|
32
33
|
let config_filename = repo.get("config.json")?;
|
33
34
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
34
35
|
let weights_filename = repo.get("model.safetensors")?;
|
35
|
-
|
36
|
+
|
36
37
|
// Load config
|
37
38
|
let config = std::fs::read_to_string(config_filename)?;
|
38
39
|
let config: Config = serde_json::from_str(&config)?;
|
39
40
|
|
40
|
-
// Setup tokenizer with padding
|
41
|
+
// Setup tokenizer with padding AND truncation
|
41
42
|
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
42
43
|
let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
|
43
|
-
|
44
|
+
let tokenizer = TokenizerLoader::with_truncation(tokenizer, max_length);
|
45
|
+
|
44
46
|
// Load model weights
|
45
47
|
let vb = unsafe {
|
46
48
|
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
|
47
49
|
};
|
48
|
-
|
50
|
+
|
49
51
|
// Load BERT model
|
50
52
|
let model = BertModel::load(vb.pp("bert"), &config)?;
|
51
|
-
|
53
|
+
|
52
54
|
// Load pooler layer (dense + tanh activation)
|
53
55
|
let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
|
54
|
-
|
56
|
+
|
55
57
|
// Load classifier layer for cross-encoder (single output score)
|
56
58
|
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
57
|
-
|
59
|
+
|
58
60
|
Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
|
59
61
|
})();
|
60
|
-
|
62
|
+
|
61
63
|
match result {
|
62
64
|
Ok((model, tokenizer, pooler, classifier)) => {
|
63
65
|
Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
|
@@ -65,18 +67,18 @@ impl Reranker {
|
|
65
67
|
Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
|
66
68
|
}
|
67
69
|
}
|
68
|
-
|
70
|
+
|
69
71
|
/// Extract CLS embeddings from the model output, handling Metal device workarounds
|
70
72
|
fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
|
71
73
|
let cls_embeddings = if self.device.is_metal() {
|
72
74
|
// Metal has issues with tensor indexing, use a different approach
|
73
75
|
let (batch_size, seq_len, hidden_size) = embeddings.dims3()
|
74
76
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
75
|
-
|
77
|
+
|
76
78
|
// Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
|
77
79
|
let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
|
78
80
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
79
|
-
|
81
|
+
|
80
82
|
// Extract CLS tokens (first token of each sequence)
|
81
83
|
let mut cls_vecs = Vec::new();
|
82
84
|
for i in 0..batch_size {
|
@@ -85,7 +87,7 @@ impl Reranker {
|
|
85
87
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
86
88
|
cls_vecs.push(cls_vec);
|
87
89
|
}
|
88
|
-
|
90
|
+
|
89
91
|
// Stack the CLS vectors
|
90
92
|
Tensor::cat(&cls_vecs, 0)
|
91
93
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
@@ -93,39 +95,39 @@ impl Reranker {
|
|
93
95
|
embeddings.i((.., 0))
|
94
96
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
95
97
|
};
|
96
|
-
|
98
|
+
|
97
99
|
// Ensure tensor is contiguous for downstream operations
|
98
100
|
cls_embeddings.contiguous()
|
99
101
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
|
100
102
|
}
|
101
|
-
|
103
|
+
|
102
104
|
pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
|
103
105
|
// Create query-document pair for cross-encoder
|
104
106
|
let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
|
105
|
-
|
107
|
+
|
106
108
|
// Tokenize using the inner tokenizer for detailed info
|
107
109
|
let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
|
108
110
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
109
|
-
|
111
|
+
|
110
112
|
// Get token information
|
111
113
|
let token_ids = encoding.get_ids().to_vec();
|
112
114
|
let token_type_ids = encoding.get_type_ids().to_vec();
|
113
115
|
let attention_mask = encoding.get_attention_mask().to_vec();
|
114
116
|
let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
|
115
|
-
|
117
|
+
|
116
118
|
// Create result hash
|
117
119
|
let result = magnus::RHash::new();
|
118
120
|
result.aset("token_ids", RArray::from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
119
121
|
result.aset("token_type_ids", RArray::from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
120
122
|
result.aset("attention_mask", RArray::from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
|
121
123
|
result.aset("tokens", RArray::from_vec(tokens))?;
|
122
|
-
|
124
|
+
|
123
125
|
Ok(result)
|
124
126
|
}
|
125
|
-
|
127
|
+
|
126
128
|
pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
|
127
129
|
let documents: Vec<String> = documents.to_vec()?;
|
128
|
-
|
130
|
+
|
129
131
|
// Create query-document pairs for cross-encoder
|
130
132
|
let query_and_docs: Vec<EncodeInput> = documents
|
131
133
|
.iter()
|
@@ -135,13 +137,13 @@ impl Reranker {
|
|
135
137
|
// Tokenize batch using inner tokenizer for access to token type IDs
|
136
138
|
let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
|
137
139
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
138
|
-
|
140
|
+
|
139
141
|
// Convert to tensors
|
140
142
|
let token_ids = encodings
|
141
143
|
.iter()
|
142
144
|
.map(|e| e.get_ids().to_vec())
|
143
145
|
.collect::<Vec<_>>();
|
144
|
-
|
146
|
+
|
145
147
|
let token_type_ids = encodings
|
146
148
|
.iter()
|
147
149
|
.map(|e| e.get_type_ids().to_vec())
|
@@ -153,11 +155,11 @@ impl Reranker {
|
|
153
155
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create token type ids tensor: {}", e)))?;
|
154
156
|
let attention_mask = token_ids.ne(0u32)
|
155
157
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create attention mask: {}", e)))?;
|
156
|
-
|
158
|
+
|
157
159
|
// Forward pass through BERT
|
158
160
|
let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
|
159
161
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Model forward pass failed: {}", e)))?;
|
160
|
-
|
162
|
+
|
161
163
|
// Apply pooling based on the specified method
|
162
164
|
let pooled_embeddings = match pooling_method.as_str() {
|
163
165
|
"pooler" => {
|
@@ -181,10 +183,10 @@ impl Reranker {
|
|
181
183
|
(sum / (seq_len as f64))
|
182
184
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to compute mean: {}", e)))?
|
183
185
|
},
|
184
|
-
_ => return Err(Error::new(magnus::exception::runtime_error(),
|
186
|
+
_ => return Err(Error::new(magnus::exception::runtime_error(),
|
185
187
|
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
|
186
188
|
};
|
187
|
-
|
189
|
+
|
188
190
|
// Apply classifier to get relevance scores (raw logits)
|
189
191
|
// Ensure tensor is contiguous before linear layer
|
190
192
|
let pooled_embeddings = pooled_embeddings.contiguous()
|
@@ -193,7 +195,7 @@ impl Reranker {
|
|
193
195
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Classifier forward failed: {}", e)))?;
|
194
196
|
let scores = logits.squeeze(1)
|
195
197
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to squeeze tensor: {}", e)))?;
|
196
|
-
|
198
|
+
|
197
199
|
// Optionally apply sigmoid activation
|
198
200
|
let scores = if apply_sigmoid {
|
199
201
|
sigmoid(&scores)
|
@@ -201,7 +203,7 @@ impl Reranker {
|
|
201
203
|
} else {
|
202
204
|
scores
|
203
205
|
};
|
204
|
-
|
206
|
+
|
205
207
|
let scores_vec: Vec<f32> = scores.to_vec1()
|
206
208
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to convert scores to vec: {}", e)))?;
|
207
209
|
|
@@ -212,7 +214,7 @@ impl Reranker {
|
|
212
214
|
.enumerate()
|
213
215
|
.map(|(idx, (doc, score))| (doc, score, idx))
|
214
216
|
.collect();
|
215
|
-
|
217
|
+
|
216
218
|
// Sort documents by relevance score (descending)
|
217
219
|
ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
218
220
|
|
@@ -232,17 +234,17 @@ impl Reranker {
|
|
232
234
|
pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
|
233
235
|
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
234
236
|
}
|
235
|
-
|
237
|
+
|
236
238
|
/// Get the model_id
|
237
239
|
pub fn model_id(&self) -> String {
|
238
240
|
self.model_id.clone()
|
239
241
|
}
|
240
|
-
|
242
|
+
|
241
243
|
/// Get the device
|
242
244
|
pub fn device(&self) -> Device {
|
243
245
|
Device::from_device(&self.device)
|
244
246
|
}
|
245
|
-
|
247
|
+
|
246
248
|
/// Get all options as a hash
|
247
249
|
pub fn options(&self) -> std::result::Result<magnus::RHash, Error> {
|
248
250
|
let hash = magnus::RHash::new();
|
@@ -254,7 +256,7 @@ impl Reranker {
|
|
254
256
|
|
255
257
|
pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
256
258
|
let c_reranker = rb_candle.define_class("Reranker", class::object())?;
|
257
|
-
c_reranker.define_singleton_method("_create", function!(Reranker::new,
|
259
|
+
c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
|
258
260
|
c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
|
259
261
|
c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
|
260
262
|
c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
|
data/lib/candle/reranker.rb
CHANGED
@@ -6,18 +6,20 @@ module Candle
|
|
6
6
|
# Load a pre-trained reranker model from HuggingFace
|
7
7
|
# @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
|
8
8
|
# @param device [Candle::Device] The device to use for computation (defaults to best available)
|
9
|
+
# @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
|
9
10
|
# @return [Reranker] A new Reranker instance
|
10
|
-
def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
|
11
|
-
_create(model_id, device)
|
11
|
+
def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
|
12
|
+
_create(model_id, device, max_length)
|
12
13
|
end
|
13
14
|
|
14
15
|
# Constructor for creating a new Reranker with optional parameters
|
15
16
|
# @deprecated Use {.from_pretrained} instead
|
16
17
|
# @param model_path [String, nil] The path to the model on Hugging Face
|
17
18
|
# @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
|
18
|
-
|
19
|
+
# @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
|
20
|
+
def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
|
19
21
|
$stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
|
20
|
-
_create(model_path, device)
|
22
|
+
_create(model_path, device, max_length)
|
21
23
|
end
|
22
24
|
|
23
25
|
# Returns documents ranked by relevance using the specified pooling method.
|
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-candle
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.1
|
4
|
+
version: 1.2.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Christopher Petersen
|
@@ -9,7 +9,7 @@ authors:
|
|
9
9
|
autorequire:
|
10
10
|
bindir: bin
|
11
11
|
cert_chain: []
|
12
|
-
date: 2025-08-
|
12
|
+
date: 2025-08-13 00:00:00.000000000 Z
|
13
13
|
dependencies:
|
14
14
|
- !ruby/object:Gem::Dependency
|
15
15
|
name: rb_sys
|