red-candle 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,6 +32,10 @@ enum ModelType {
 }
 
 impl QuantizedGGUF {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -316,7 +320,9 @@ impl QuantizedGGUF {
         // Check model name since Mistral GGUF reports as llama architecture
         let model_lower = self.model_id.to_lowercase();
 
-        if model_lower.contains("mistral") {
+        if model_lower.contains("tinyllama") {
+            self.apply_chatml_template(messages)
+        } else if model_lower.contains("mistral") {
             self.apply_mistral_template(messages)
         } else if model_lower.contains("gemma") {
             // Always use Gemma template for Gemma models, regardless of loader used
@@ -512,6 +518,20 @@ impl QuantizedGGUF {
         Ok(prompt)
     }
 
+    fn apply_chatml_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            prompt.push_str(&format!("<|{}|>\n{}</s>\n", role, content));
+        }
+
+        prompt.push_str("<|assistant|>");
+        Ok(prompt)
+    }
+
     fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
         let mut prompt = String::new();
 
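The new template above is the Zephyr/TinyLlama flavor of ChatML (per-role `<|role|>` markers terminated with `</s>`, rather than `<|im_start|>`). As a minimal illustration of what apply_chatml_template renders (the messages are hypothetical, not from the package):

    // Input messages:
    //   [{"role": "system", "content": "You are helpful."},
    //    {"role": "user",   "content": "Hi!"}]
    //
    // Rendered prompt:
    //   <|system|>
    //   You are helpful.</s>
    //   <|user|>
    //   Hi!</s>
    //   <|assistant|>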
@@ -538,7 +558,7 @@ impl QuantizedGGUF {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -571,10 +591,7 @@ impl QuantizedGGUF {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -596,6 +613,18 @@ impl QuantizedGGUF {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
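This early-stop check matters for constrained (regex/schema) generation: without it, a model that has already completed the required pattern keeps sampling until max_length or EOS. A minimal sketch of the two config fields involved (field names are taken from the diff; the struct literal and a Default impl are assumptions):

    // Hypothetical usage sketch, not the gem's documented API.
    let config = GenerationConfig {
        stop_on_constraint_satisfaction: true, // stop once the constraint can validly end
        stop_on_match: true,                   // prefer the first match over the longest one
        ..Default::default()                   // assumes GenerationConfig: Default
    };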
@@ -16,6 +16,10 @@ pub struct Qwen {
 }
 
 impl Qwen {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -55,6 +59,9 @@ impl Qwen {
             .unwrap_or(151643); // Default Qwen3 EOS token
 
         // Download model weights
+        // NOTE: Qwen uses hardcoded shard counts based on model size rather than
+        // reading model.safetensors.index.json. This works for official Qwen models
+        // but may fail for custom configurations with different shard counts.
        let mut filenames = vec![];
        let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
        else if model_id.contains("14b") || model_id.contains("14B") { 3 }
@@ -124,7 +131,7 @@ impl Qwen {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -150,10 +157,7 @@ impl Qwen {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -173,6 +177,18 @@ impl Qwen {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -10,39 +10,29 @@ pub struct TextGeneration {
     logits_processor: LogitsProcessor,
     tokens: Vec<u32>,
     eos_token_id: Option<u32>,
+    repetition_penalty: f32,
+    repetition_penalty_last_n: usize,
     constraint: Option<Arc<Index>>,
     constraint_state: Option<u32>,
+    constraint_completed: bool,
+    tokens_since_constraint_start: usize,
 }
 
 impl TextGeneration {
-    pub fn new(
-        seed: u64,
-        temperature: Option<f64>,
-        top_p: Option<f64>,
-        _top_k: Option<usize>,
-        _repetition_penalty: f32,
-        _repetition_penalty_last_n: usize,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
-
-        Self {
+    pub fn new(config: &GenerationConfig) -> Self {
+        let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
+
+        let mut text_gen = Self {
             logits_processor,
             tokens: Vec::new(),
             eos_token_id: None,
+            repetition_penalty: config.repetition_penalty,
+            repetition_penalty_last_n: config.repetition_penalty_last_n,
             constraint: None,
             constraint_state: None,
-        }
-    }
-
-    pub fn from_config(config: &GenerationConfig) -> Self {
-        let mut text_gen = Self::new(
-            config.seed,
-            Some(config.temperature),
-            config.top_p,
-            config.top_k,
-            config.repetition_penalty,
-            config.repetition_penalty_last_n,
-        );
+            constraint_completed: false,
+            tokens_since_constraint_start: 0,
+        };
 
         // Set constraint if provided
         if let Some(ref constraint) = config.constraint {
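The old two-step path (a six-argument `new` plus a `from_config` wrapper) collapses into a single constructor, and the penalty arguments that the old `new` accepted but ignored (note the `_repetition_penalty*` underscores) are now stored on the struct. A minimal before/after sketch (the struct literal and a Default impl are assumptions):

    // Before (1.1.0): penalty args accepted but silently dropped by new()
    // let tg = TextGeneration::new(42, Some(0.7), Some(0.9), None, 1.1, 64);
    // After (1.1.2): everything flows from the config in one call
    let config = GenerationConfig { seed: 42, temperature: 0.7, ..Default::default() };
    let mut text_gen = TextGeneration::new(&config);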
@@ -72,6 +62,8 @@ impl TextGeneration {
         // Initialize with the first state
         self.constraint_state = Some(constraint.initial_state());
         self.constraint = Some(constraint);
+        self.constraint_completed = false;
+        self.tokens_since_constraint_start = self.tokens.len();
     }
 
     /// Apply constraints to logits by masking disallowed tokens
@@ -137,13 +129,12 @@
     pub fn sample_next_token(
         &mut self,
         logits: &Tensor,
-        repetition_penalty: Option<(f32, usize)>,
     ) -> CandleResult<u32> {
         let mut logits = logits.clone();
 
-        // Apply repetition penalty if specified
-        if let Some((penalty, last_n)) = repetition_penalty {
-            self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+        // Apply repetition penalty using stored parameters
+        if self.repetition_penalty != 1.0 {
+            self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
         }
 
         // Apply constraints if active
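Note the `!= 1.0` guard: a penalty of 1.0 is the identity, so the pass over recent tokens is skipped entirely in the default case. For reference, a hedged sketch of what a CTRL-style repetition penalty typically computes (the gem's apply_repetition_penalty body is not shown in this diff, so this is an assumption about the general technique, not the package's code):

    // Sketch only: divide positive logits and multiply negative logits by `penalty`
    // for every token id seen in the recent window, making repeats less likely.
    fn apply_penalty_sketch(logits: &mut [f32], recent_tokens: &[u32], penalty: f32) {
        for &tok in recent_tokens {
            let logit = &mut logits[tok as usize];
            *logit = if *logit >= 0.0 { *logit / penalty } else { *logit * penalty };
        }
    }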
@@ -155,12 +146,114 @@
 
         // Update constraint state if active
         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
-            self.constraint_state = constraint_index.next_state(&current_state, &next_token);
+            // Get the next state
+            let next_state = constraint_index.next_state(&current_state, &next_token);
+
+            // Check if we're transitioning to a state with no allowed tokens (completion)
+            if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
+                // Check if we've transitioned from a constrained state to an unconstrained state
+                // This happens when the pattern is complete and the FSM allows "anything"
+
+                let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
+                    // Consider it constrained if we have a limited set of allowed tokens
+                    allowed.len() < 1000 // Arbitrary threshold for "constrained"
+                } else {
+                    true // No tokens allowed is definitely constrained
+                };
+
+                let next_constrained = if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        allowed.is_empty() || allowed.len() < 1000
+                    } else {
+                        true
+                    }
+                } else {
+                    true
+                };
+
+                // If we're transitioning from constrained to unconstrained, we've completed the pattern
+                if current_constrained && !next_constrained {
+                    self.constraint_completed = true;
+                }
+
+                // Also check if next state has no allowed tokens at all
+                if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        if allowed.is_empty() {
+                            self.constraint_completed = true;
+                        }
+                    } else {
+                        // None means no tokens allowed - constraint is complete
+                        self.constraint_completed = true;
+                    }
+                }
+            }
+
+            self.constraint_state = next_state;
         }
 
         Ok(next_token)
     }
 
+    /// Check if the constraint is satisfied (reached a valid completion state)
+    pub fn is_constraint_satisfied(&self) -> bool {
+        // If we've explicitly marked the constraint as completed, return true
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check the current state
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if the constraint has reached a state where it could validly end
+            // This happens when:
+            // 1. We have no more allowed tokens (constraint fully satisfied)
+            // 2. The EOS token is in the allowed tokens (optional ending)
+            if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                // If no tokens are allowed, the constraint is fully satisfied
+                if allowed.is_empty() {
+                    return true;
+                }
+
+                // If EOS token is allowed, we've reached an optional completion point
+                if let Some(eos) = self.eos_token_id {
+                    if allowed.contains(&eos) {
+                        return true;
+                    }
+                }
+            } else {
+                // None means no tokens allowed - constraint is satisfied
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Check if the constraint is satisfied when stop_on_match is true
+    pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
+        // When stop_on_match is true, we stop as soon as the constraint is completed
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check if we're currently in a state that could be a valid end
+        // This is important for patterns like phone numbers where after matching
+        // the pattern, the FSM might allow any token (including more numbers)
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if we've generated at least one token since constraint start
+            if self.tokens.len() > self.tokens_since_constraint_start {
+                if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                    // If the allowed tokens set is very large (unconstrained),
+                    // it means the pattern has been satisfied
+                    if allowed.len() > 1000 {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        false
+    }
+
     /// Check if we should stop generation
     pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
         if self.tokens.len() >= max_length {
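Both predicates lean on the same heuristic: an FSM state is treated as "constrained" while its allowed-token set is small, and a jump to a near-vocabulary-sized set (the arbitrary 1000 cutoff in the code above) signals that the pattern has completed. A hypothetical walk-through for a phone-number regex (values illustrative, not from the package):

    // Constraint: \d{3}-\d{4}
    // While matching: each FSM state allows ~10 token ids (digits or '-'),
    //   so allowed.len() < 1000 and the state counts as constrained.
    // After "555-0199": the FSM permits any continuation, allowed.len() jumps
    //   to roughly vocabulary size, and is_constraint_satisfied_stop_on_match()
    //   returns true, letting the generate loop break early.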
@@ -39,13 +39,9 @@ impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
 
-        // Load model in a separate thread to avoid blocking
-        let device_clone = device.clone();
-        let model_id_clone = model_id.clone();
-
-        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
-            let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
+            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
             // Download model files
             let config_filename = repo.get("config.json")?;
@@ -92,7 +88,7 @@
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -106,10 +102,10 @@
             )?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, classifier, config))) => {
+        match result {
+            Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
                     model,
                     tokenizer,
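The loader refactor swaps `std::thread::spawn` + `join()` for an immediately invoked closure: the same `?`-friendly boxed-error scope, but no extra thread, no `device`/`model_id` clones, and one less layer of `Result` nesting to match on. A minimal standalone sketch of the pattern (names hypothetical):

    use std::error::Error;

    fn main() {
        // Immediately invoked closure: gives `?` a fallible scope without spawning a thread.
        let result = (|| -> Result<u32, Box<dyn Error + Send + Sync>> {
            let parsed: u32 = "42".parse()?; // any `?` here short-circuits the closure
            Ok(parsed)
        })();

        match result {
            Ok(v) => println!("loaded: {v}"),
            Err(e) => eprintln!("failed: {e}"),
        }
    }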
@@ -119,28 +115,20 @@
                     model_id,
                 })
             }
-            Ok(Err(e)) => Err(Error::new(
+            Err(e) => Err(Error::new(
                 magnus::exception::runtime_error(),
                 format!("Failed to load NER model: {}", e)
             )),
-            Err(_) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                "Thread panicked while loading NER model"
-            )),
         }
     }
 
-    /// Extract entities from text with confidence scores
-    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+    /// Common tokenization and prediction logic
+    fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
         // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+        let encoding = self.tokenizer.inner().encode(text, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         let token_ids = encoding.get_ids();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
 
         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@
             .to_vec2()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
 
+        Ok((encoding, probs_vec))
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
         // Extract entities with BIO decoding
         let entities = self.decode_entities(
             &text,
@@ -199,38 +200,11 @@
 
     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-        // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
 
-        let token_ids = encoding.get_ids();
         let tokens = encoding.get_tokens();
 
-        // Convert to tensors
-        let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Forward pass
-        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Get predictions
-        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
         // Build result array
         let result = RArray::new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
@@ -4,7 +4,6 @@ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{EncodeInput, Tokenizer};
-use std::thread;
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
@@ -24,8 +23,7 @@ impl Reranker {
     }
 
     fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
-        let device_clone = device.clone();
-        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -44,7 +42,7 @@
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -57,17 +55,49 @@
             let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, pooler, classifier))) => {
+        match result {
+            Ok((model, tokenizer, pooler, classifier)) => {
                 Ok(Self { model, tokenizer, pooler, classifier, device })
             }
-            Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
-            Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
+            Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
         }
     }
 
+    /// Extract CLS embeddings from the model output, handling Metal device workarounds
+    fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
+        let cls_embeddings = if self.device.is_metal() {
+            // Metal has issues with tensor indexing, use a different approach
+            let (batch_size, seq_len, hidden_size) = embeddings.dims3()
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
+
+            // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
+            let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
+
+            // Extract CLS tokens (first token of each sequence)
+            let mut cls_vecs = Vec::new();
+            for i in 0..batch_size {
+                let start_idx = i * seq_len;
+                let cls_vec = reshaped.narrow(0, start_idx, 1)
+                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
+                cls_vecs.push(cls_vec);
+            }
+
+            // Stack the CLS vectors
+            Tensor::cat(&cls_vecs, 0)
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
+        } else {
+            embeddings.i((.., 0))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
+        };
+
+        // Ensure tensor is contiguous for downstream operations
+        cls_embeddings.contiguous()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
+    }
+
     pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
         // Create query-document pair for cross-encoder
         let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
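In shape terms the new helper selects the first (CLS) token of every sequence: [batch, seq_len, hidden] -> [batch, hidden]. A minimal sketch showing that the Metal-safe reshape/narrow path and the plain `i((.., 0))` path agree (assumes candle-core's standard API; this sketch is not from the package):

    use candle_core::{Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        // [batch=2, seq_len=3, hidden=4]
        let emb = Tensor::rand(0f32, 1f32, (2, 3, 4), &Device::Cpu)?;

        // Direct indexing path: first token of every sequence -> [2, 4]
        let cls_fast = emb.i((.., 0))?;

        // Metal-safe path from the diff: flatten, then narrow out row i * seq_len
        let flat = emb.reshape((2 * 3, 4))?;
        let rows: Vec<Tensor> = (0..2).map(|i| flat.narrow(0, i * 3, 1)).collect::<Result<_>>()?;
        let cls_safe = Tensor::cat(&rows, 0)?;

        assert_eq!(cls_fast.to_vec2::<f32>()?, cls_safe.to_vec2::<f32>()?);
        Ok(())
    }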
@@ -131,37 +161,7 @@
         let pooled_embeddings = match pooling_method.as_str() {
             "pooler" => {
                 // Extract [CLS] token and apply pooler (dense + tanh)
-                // Work around Metal indexing issue by using narrow instead of i((.., 0))
-                let cls_embeddings = if self.device.is_metal() {
-                    // Metal has issues with tensor indexing, use a different approach
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    // Extract CLS tokens (first token of each sequence)
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    // Stack the CLS vectors
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure tensor is contiguous before linear layer
-                let cls_embeddings = cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
+                let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
                 let pooled = self.pooler.forward(&cls_embeddings)
                     .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
                 pooled.tanh()
@@ -169,34 +169,7 @@
             },
             "cls" => {
                 // Just use the [CLS] token embeddings directly (no pooler layer)
-                // Work around Metal indexing issue
-                let cls_embeddings = if self.device.is_metal() {
-                    // Use same approach as pooler method
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure contiguous for classifier
-                cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
+                self.extract_cls_embeddings(&embeddings)?
             },
             "mean" => {
                 // Mean pooling across all tokens
@@ -199,4 +199,5 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
     rb_device.define_method("==", method!(Device::__eq__, 1))?;
     Ok(())
-}
+}
+
@@ -35,3 +35,4 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
     Ok(())
 }
+
@@ -11,3 +11,4 @@ pub fn wrap_candle_err(err: candle_core::Error) -> Error {
 pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
     Error::new(magnus::exception::runtime_error(), err.to_string())
 }
+