red-candle 1.2.0 → 1.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 3f2d005688d7b0253060d087a9800ea8c1d2c7bbb6ff6c92cca3ebc238d99be3
4
- data.tar.gz: 6296db628c2d13a39ef035fe45c41be39de2404e6ab72d9735109ad18879f65c
3
+ metadata.gz: d101e47090a800dddabfa1ee41971ca3f214f9a422e9408c0970b24c03aab8c7
4
+ data.tar.gz: 7136f0922548fd4bead034930b3c6d1553c498c23fc17802a6c94f7f4a0c7ce5
5
5
  SHA512:
6
- metadata.gz: 4511b6f96d1356101e10547f740702479c894fb6d1e6f8cb04213b49b624ac8dc73d83b8b91bbab396e095fd31bbb4dd019ca967ce7f790447a4b77dd25d3356
7
- data.tar.gz: 5a1ef095e2bbd9967317e0c416fb018da2673e18d76e00c98177cdba5dd2f9c5723fc5404b0d4221800227aab8b74e5560d420aff79f94e216365e6d3cee6f1e
6
+ metadata.gz: 594900e14d64fc335e53e2599a6e2d81ba4fdb1a1f3d5680536a47538bdf05d12332e959b33b67bfa04cca0dd321835a737d2ddf9cdd33c42b937370dde7a2a1
7
+ data.tar.gz: 4005d21ddccce9182d8600af5b3fd3f34ee9fabdf1b8235eb37eb1be7ede74b644c23d4ea322f6c0e05808e3e98e0c2066251111f538e9e8a934e7ec069ae4c0
data/README.md CHANGED
@@ -363,6 +363,12 @@ require 'candle'
363
363
  # Initialize the reranker with a cross-encoder model
364
364
  reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
365
365
 
366
+ # Or with custom max_length for truncation (default is 512)
367
+ reranker = Candle::Reranker.from_pretrained(
368
+ "cross-encoder/ms-marco-MiniLM-L-12-v2",
369
+ max_length: 256 # Faster processing with less context
370
+ )
371
+
366
372
  # Define your query and candidate documents
367
373
  query = "How many people live in London?"
368
374
  documents = [
@@ -469,6 +475,75 @@ The reranker uses a BERT-based architecture that:
469
475
 
470
476
  This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
471
477
 
478
+ ### Performance Considerations
479
+
480
+ **Important**: The Reranker automatically truncates documents to ensure stable performance. The default maximum is 512 tokens, but this is configurable.
481
+
482
+ #### Configurable Truncation
483
+
484
+ You can adjust the `max_length` parameter to balance performance and context:
485
+
486
+ ```ruby
487
+ # Default: 512 tokens (maximum context, ~300ms per doc on CPU)
488
+ reranker = Candle::Reranker.from_pretrained(model_id)
489
+
490
+ # Faster: 256 tokens (~60% faster, ~120ms per doc on CPU)
491
+ reranker = Candle::Reranker.from_pretrained(model_id, max_length: 256)
492
+
493
+ # Fastest: 128 tokens (~80% faster, ~60ms per doc on CPU)
494
+ reranker = Candle::Reranker.from_pretrained(model_id, max_length: 128)
495
+ ```
496
+
497
+ Choose based on your needs:
498
+ - **512 tokens**: Maximum context for complex queries (default)
499
+ - **256 tokens**: Good balance of speed and context
500
+ - **128 tokens**: Fast processing for simple matching
501
+
502
+ #### Performance Guidelines
503
+
504
+ 1. **Document Length**: With the default `max_length` of 512, documents longer than ~400 words will be truncated
505
+ - The first 512 tokens (roughly 300-400 words) are used
506
+ - Consider splitting very long documents into chunks if full coverage is needed
507
+
508
+ 2. **Batch Size**: Process multiple documents in one call for efficiency
509
+ ```ruby
510
+ # Good: Single call with multiple documents
511
+ results = reranker.rerank(query, documents)
512
+
513
+ # Less efficient: Multiple calls
514
+ documents.map { |doc| reranker.rerank(query, [doc]) }
515
+ ```
516
+
517
+ 3. **Expected Performance**:
518
+ - **CPU**: ~0.3-0.5s per query-document pair
519
+ - **GPU (Metal/CUDA)**: ~0.05-0.1s per query-document pair
520
+ - Per-pair latency is bounded regardless of document length, because inputs are truncated to `max_length` tokens
521
+
522
+ 4. **Chunking Strategy** for long documents:
523
+ ```ruby
524
+ def rerank_long_document(query, long_text, chunk_size: 300)
525
+ # Split into overlapping chunks
526
+ words = long_text.split
527
+ chunks = []
528
+
529
+ (0...words.length).step(chunk_size - 50) do |i|
530
+ chunk = words[i...(i + chunk_size)].join(" ")
531
+ chunks << chunk
532
+ end
533
+
534
+ # Rerank chunks
535
+ results = reranker.rerank(query, chunks)
536
+
537
+ # Return best chunk
538
+ results.max_by { |r| r[:score] }
539
+ end
540
+ ```
541
+
542
+ 5. **Memory Usage**:
543
+ - Model size: ~125MB
544
+ - Each batch processes all documents simultaneously
545
+ - Consider splitting very large document sets into smaller batches (several `rerank` calls) to limit memory use
546
+
472
547
  ## Tokenizer
473
548
 
474
549
  Red-Candle provides direct access to tokenizers for text preprocessing and analysis. This is useful for understanding how models process text, debugging issues, and building custom NLP pipelines.
@@ -699,7 +774,7 @@ Perfect for specialized domains:
699
774
  ```ruby
700
775
  # Biomedical entities
701
776
  gene_patterns = [
702
- /\b[A-Z][A-Z0-9]{2,}\b/, # TP53, BRCA1, EGFR
777
+ /\b[A-Z][A-Z0-9]{2,10}\b/, # TP53, BRCA1, EGFR (bounded for safety)
703
778
  /\bCD\d+\b/, # CD4, CD8, CD34
704
779
  /\b[A-Z]+\d[A-Z]\d*\b/ # RAD51C, PALB2
705
780
  ]
@@ -18,46 +18,48 @@ pub struct Reranker {
18
18
  }
19
19
 
20
20
  impl Reranker {
21
- pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
21
+ pub fn new(model_id: String, device: Option<Device>, max_length: Option<usize>) -> Result<Self> {
22
22
  let device = device.unwrap_or(Device::best()).as_device()?;
23
- Self::new_with_core_device(model_id, device)
23
+ let max_length = max_length.unwrap_or(512); // Default to 512
24
+ Self::new_with_core_device(model_id, device, max_length)
24
25
  }
25
-
26
- fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
26
+
27
+ fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
27
28
  let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
28
29
  let api = Api::new()?;
29
30
  let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
30
-
31
+
31
32
  // Download model files
32
33
  let config_filename = repo.get("config.json")?;
33
34
  let tokenizer_filename = repo.get("tokenizer.json")?;
34
35
  let weights_filename = repo.get("model.safetensors")?;
35
-
36
+
36
37
  // Load config
37
38
  let config = std::fs::read_to_string(config_filename)?;
38
39
  let config: Config = serde_json::from_str(&config)?;
39
40
 
40
- // Setup tokenizer with padding
41
+ // Setup tokenizer with padding AND truncation
41
42
  let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
42
43
  let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
43
-
44
+ let tokenizer = TokenizerLoader::with_truncation(tokenizer, max_length);
45
+
44
46
  // Load model weights
45
47
  let vb = unsafe {
46
48
  VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
47
49
  };
48
-
50
+
49
51
  // Load BERT model
50
52
  let model = BertModel::load(vb.pp("bert"), &config)?;
51
-
53
+
52
54
  // Load pooler layer (dense + tanh activation)
53
55
  let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
54
-
56
+
55
57
  // Load classifier layer for cross-encoder (single output score)
56
58
  let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
57
-
59
+
58
60
  Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
59
61
  })();
60
-
62
+
61
63
  match result {
62
64
  Ok((model, tokenizer, pooler, classifier)) => {
63
65
  Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
@@ -65,18 +67,18 @@ impl Reranker {
65
67
  Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
66
68
  }
67
69
  }
68
-
70
+
69
71
  /// Extract CLS embeddings from the model output, handling Metal device workarounds
70
72
  fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
71
73
  let cls_embeddings = if self.device.is_metal() {
72
74
  // Metal has issues with tensor indexing, use a different approach
73
75
  let (batch_size, seq_len, hidden_size) = embeddings.dims3()
74
76
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
75
-
77
+
76
78
  // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
77
79
  let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
78
80
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
79
-
81
+
80
82
  // Extract CLS tokens (first token of each sequence)
81
83
  let mut cls_vecs = Vec::new();
82
84
  for i in 0..batch_size {
@@ -85,7 +87,7 @@ impl Reranker {
85
87
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
86
88
  cls_vecs.push(cls_vec);
87
89
  }
88
-
90
+
89
91
  // Stack the CLS vectors
90
92
  Tensor::cat(&cls_vecs, 0)
91
93
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
@@ -93,39 +95,39 @@ impl Reranker {
93
95
  embeddings.i((.., 0))
94
96
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
95
97
  };
96
-
98
+
97
99
  // Ensure tensor is contiguous for downstream operations
98
100
  cls_embeddings.contiguous()
99
101
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
100
102
  }
101
-
103
+
102
104
  pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
103
105
  // Create query-document pair for cross-encoder
104
106
  let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
105
-
107
+
106
108
  // Tokenize using the inner tokenizer for detailed info
107
109
  let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
108
110
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
109
-
111
+
110
112
  // Get token information
111
113
  let token_ids = encoding.get_ids().to_vec();
112
114
  let token_type_ids = encoding.get_type_ids().to_vec();
113
115
  let attention_mask = encoding.get_attention_mask().to_vec();
114
116
  let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
115
-
117
+
116
118
  // Create result hash
117
119
  let result = magnus::RHash::new();
118
120
  result.aset("token_ids", RArray::from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
119
121
  result.aset("token_type_ids", RArray::from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
120
122
  result.aset("attention_mask", RArray::from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
121
123
  result.aset("tokens", RArray::from_vec(tokens))?;
122
-
124
+
123
125
  Ok(result)
124
126
  }
125
-
127
+
126
128
  pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
127
129
  let documents: Vec<String> = documents.to_vec()?;
128
-
130
+
129
131
  // Create query-document pairs for cross-encoder
130
132
  let query_and_docs: Vec<EncodeInput> = documents
131
133
  .iter()
@@ -135,13 +137,13 @@ impl Reranker {
135
137
  // Tokenize batch using inner tokenizer for access to token type IDs
136
138
  let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
137
139
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
138
-
140
+
139
141
  // Convert to tensors
140
142
  let token_ids = encodings
141
143
  .iter()
142
144
  .map(|e| e.get_ids().to_vec())
143
145
  .collect::<Vec<_>>();
144
-
146
+
145
147
  let token_type_ids = encodings
146
148
  .iter()
147
149
  .map(|e| e.get_type_ids().to_vec())
@@ -153,11 +155,11 @@ impl Reranker {
153
155
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create token type ids tensor: {}", e)))?;
154
156
  let attention_mask = token_ids.ne(0u32)
155
157
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create attention mask: {}", e)))?;
156
-
158
+
157
159
  // Forward pass through BERT
158
160
  let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
159
161
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Model forward pass failed: {}", e)))?;
160
-
162
+
161
163
  // Apply pooling based on the specified method
162
164
  let pooled_embeddings = match pooling_method.as_str() {
163
165
  "pooler" => {
@@ -181,10 +183,10 @@ impl Reranker {
181
183
  (sum / (seq_len as f64))
182
184
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to compute mean: {}", e)))?
183
185
  },
184
- _ => return Err(Error::new(magnus::exception::runtime_error(),
186
+ _ => return Err(Error::new(magnus::exception::runtime_error(),
185
187
  format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
186
188
  };
187
-
189
+
188
190
  // Apply classifier to get relevance scores (raw logits)
189
191
  // Ensure tensor is contiguous before linear layer
190
192
  let pooled_embeddings = pooled_embeddings.contiguous()
@@ -193,7 +195,7 @@ impl Reranker {
193
195
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Classifier forward failed: {}", e)))?;
194
196
  let scores = logits.squeeze(1)
195
197
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to squeeze tensor: {}", e)))?;
196
-
198
+
197
199
  // Optionally apply sigmoid activation
198
200
  let scores = if apply_sigmoid {
199
201
  sigmoid(&scores)
@@ -201,7 +203,7 @@ impl Reranker {
201
203
  } else {
202
204
  scores
203
205
  };
204
-
206
+
205
207
  let scores_vec: Vec<f32> = scores.to_vec1()
206
208
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to convert scores to vec: {}", e)))?;
207
209
 
@@ -212,7 +214,7 @@ impl Reranker {
212
214
  .enumerate()
213
215
  .map(|(idx, (doc, score))| (doc, score, idx))
214
216
  .collect();
215
-
217
+
216
218
  // Sort documents by relevance score (descending)
217
219
  ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
218
220
 
@@ -232,17 +234,17 @@ impl Reranker {
232
234
  pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
233
235
  Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
234
236
  }
235
-
237
+
236
238
  /// Get the model_id
237
239
  pub fn model_id(&self) -> String {
238
240
  self.model_id.clone()
239
241
  }
240
-
242
+
241
243
  /// Get the device
242
244
  pub fn device(&self) -> Device {
243
245
  Device::from_device(&self.device)
244
246
  }
245
-
247
+
246
248
  /// Get all options as a hash
247
249
  pub fn options(&self) -> std::result::Result<magnus::RHash, Error> {
248
250
  let hash = magnus::RHash::new();
@@ -254,7 +256,7 @@ impl Reranker {
254
256
 
255
257
  pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
256
258
  let c_reranker = rb_candle.define_class("Reranker", class::object())?;
257
- c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
259
+ c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
258
260
  c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
259
261
  c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
260
262
  c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
data/lib/candle/ner.rb CHANGED
@@ -1,5 +1,8 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ # Pattern validation available but not forced
4
+ # require_relative 'pattern_validator' # Uncomment if needed
5
+
3
6
  module Candle
4
7
  # Named Entity Recognition (NER) for token classification
5
8
  #
@@ -189,6 +192,14 @@ module Candle
189
192
  def recognize(text, tokenizer = nil)
190
193
  entities = []
191
194
 
195
+ # Limit text length to prevent ReDoS on very long strings
196
+ # This is especially important for Ruby < 3.2
197
+ max_length = 1_000_000 # 1 million characters (~1MB for ASCII text)
198
+ if text.length > max_length
199
+ warn "PatternEntityRecognizer: Text truncated from #{text.length} to #{max_length} chars for safety"
200
+ text = text[0...max_length]
201
+ end
202
+
192
203
  @patterns.each do |pattern|
193
204
  regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)
194
205
 
@@ -6,18 +6,20 @@ module Candle
6
6
  # Load a pre-trained reranker model from HuggingFace
7
7
  # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
8
8
  # @param device [Candle::Device] The device to use for computation (defaults to best available)
9
+ # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
9
10
  # @return [Reranker] A new Reranker instance
10
- def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
11
- _create(model_id, device)
11
+ def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
12
+ _create(model_id, device, max_length)
12
13
  end
13
14
 
14
15
  # Constructor for creating a new Reranker with optional parameters
15
16
  # @deprecated Use {.from_pretrained} instead
16
17
  # @param model_path [String, nil] The path to the model on Hugging Face
17
18
  # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
18
- def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
19
+ # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
20
+ def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
19
21
  $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
20
- _create(model_path, device)
22
+ _create(model_path, device, max_length)
21
23
  end
22
24
 
23
25
  # Returns documents ranked by relevance using the specified pooling method.
@@ -1,5 +1,5 @@
1
1
  # :nocov:
2
2
  module Candle
3
- VERSION = "1.2.0"
3
+ VERSION = "1.2.2"
4
4
  end
5
5
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.0
4
+ version: 1.2.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2025-08-10 00:00:00.000000000 Z
12
+ date: 2025-08-13 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rb_sys