red-candle 1.1.0 → 1.1.2
- checksums.yaml +4 -4
- data/README.md +65 -1
- data/Rakefile +40 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +199 -6
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +6 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/phi.rs +21 -5
- data/ext/candle/src/llm/quantized_gguf.rs +35 -6
- data/ext/candle/src/llm/qwen.rs +21 -5
- data/ext/candle/src/llm/text_generation.rs +121 -28
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +2 -1
- data/ext/candle/src/ruby/dtype.rs +1 -0
- data/ext/candle/src/ruby/errors.rs +1 -0
- data/ext/candle/src/ruby/llm.rs +81 -55
- data/ext/candle/src/ruby/tensor.rs +2 -1
- data/ext/candle/src/tokenizer/mod.rs +2 -1
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/llm.rb +129 -34
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
data/ext/candle/src/llm/quantized_gguf.rs
CHANGED
@@ -32,6 +32,10 @@ enum ModelType {
 }
 
 impl QuantizedGGUF {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -316,7 +320,9 @@ impl QuantizedGGUF {
         // Check model name since Mistral GGUF reports as llama architecture
         let model_lower = self.model_id.to_lowercase();
 
-        if model_lower.contains("mistral") {
+        if model_lower.contains("tinyllama") {
+            self.apply_chatml_template(messages)
+        } else if model_lower.contains("mistral") {
             self.apply_mistral_template(messages)
         } else if model_lower.contains("gemma") {
             // Always use Gemma template for Gemma models, regardless of loader used
@@ -512,6 +518,20 @@ impl QuantizedGGUF {
         Ok(prompt)
     }
 
+    fn apply_chatml_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            prompt.push_str(&format!("<|{}|>\n{}</s>\n", role, content));
+        }
+
+        prompt.push_str("<|assistant|>");
+        Ok(prompt)
+    }
+
     fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
         let mut prompt = String::new();
 
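For reference, a standalone sketch of what the new ChatML-style template produces. The helper mirrors `apply_chatml_template` above, minus the `CandleResult` wrapper; the sample messages are illustrative:

```rust
use serde_json::json;

// Mirrors the template logic from the hunk above.
fn chatml_prompt(messages: &[serde_json::Value]) -> String {
    let mut prompt = String::new();
    for message in messages {
        let role = message["role"].as_str().unwrap_or("");
        let content = message["content"].as_str().unwrap_or("");
        prompt.push_str(&format!("<|{}|>\n{}</s>\n", role, content));
    }
    prompt.push_str("<|assistant|>");
    prompt
}

fn main() {
    let messages = [
        json!({"role": "system", "content": "You are helpful."}),
        json!({"role": "user", "content": "Hello!"}),
    ];
    // Prints:
    // <|system|>
    // You are helpful.</s>
    // <|user|>
    // Hello!</s>
    // <|assistant|>
    println!("{}", chatml_prompt(&messages));
}
```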
@@ -538,7 +558,7 @@ impl QuantizedGGUF {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -571,10 +591,7 @@ impl QuantizedGGUF {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -596,6 +613,18 @@ impl QuantizedGGUF {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
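This early-stopping block is inserted into each model's generation loop; qwen.rs below receives the identical hunk, and the matching +21/-5 counts for gemma.rs, llama.rs, mistral.rs, and phi.rs in the file list suggest the same change there. Condensed, the decision it implements is:

```rust
// Self-contained restatement of the early-stop check; the flag names come
// from GenerationConfig in this diff.
fn early_stop(
    stop_on_constraint_satisfaction: bool,
    stop_on_match: bool,
    satisfied: bool,
    satisfied_on_match: bool,
) -> bool {
    stop_on_constraint_satisfaction
        && if stop_on_match { satisfied_on_match } else { satisfied }
}

fn main() {
    // With the feature flag off, generation never stops early.
    assert!(!early_stop(false, true, true, true));
    // With it on, stop_on_match selects which satisfaction check applies.
    assert!(early_stop(true, false, true, false));
}
```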
data/ext/candle/src/llm/qwen.rs
CHANGED
@@ -16,6 +16,10 @@ pub struct Qwen {
 }
 
 impl Qwen {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -55,6 +59,9 @@ impl Qwen {
             .unwrap_or(151643); // Default Qwen3 EOS token
 
         // Download model weights
+        // NOTE: Qwen uses hardcoded shard counts based on model size rather than
+        // reading model.safetensors.index.json. This works for official Qwen models
+        // but may fail for custom configurations with different shard counts.
         let mut filenames = vec![];
         let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
         else if model_id.contains("14b") || model_id.contains("14B") { 3 }
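The scheme that NOTE describes boils down to mapping the model id to a shard count and deriving filenames from it. A sketch, assuming the usual `model-XXXXX-of-XXXXX.safetensors` convention; the branches beyond 72B and 14B are cut off in this hunk, so the single-shard fallback here is an assumption:

```rust
// Hypothetical condensation of the shard selection shown above.
fn shard_filenames(model_id: &str) -> Vec<String> {
    let id = model_id.to_lowercase();
    let num_shards = if id.contains("72b") { 8 }
        else if id.contains("14b") { 3 }
        else { 1 }; // assumed fallback; the real chain has more branches
    if num_shards == 1 {
        vec!["model.safetensors".to_string()]
    } else {
        (1..=num_shards)
            .map(|i| format!("model-{:05}-of-{:05}.safetensors", i, num_shards))
            .collect()
    }
}

fn main() {
    assert_eq!(shard_filenames("Qwen/Qwen1.5-14B").len(), 3);
}
```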
@@ -124,7 +131,7 @@ impl Qwen {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -150,10 +157,7 @@ impl Qwen {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -173,6 +177,18 @@ impl Qwen {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
data/ext/candle/src/llm/text_generation.rs
CHANGED
@@ -10,39 +10,29 @@ pub struct TextGeneration {
     logits_processor: LogitsProcessor,
     tokens: Vec<u32>,
     eos_token_id: Option<u32>,
+    repetition_penalty: f32,
+    repetition_penalty_last_n: usize,
     constraint: Option<Arc<Index>>,
     constraint_state: Option<u32>,
+    constraint_completed: bool,
+    tokens_since_constraint_start: usize,
 }
 
 impl TextGeneration {
-    pub fn new(
-        seed: u64,
-        temperature: Option<f64>,
-        top_p: Option<f64>,
-        _top_k: Option<usize>,
-        _repetition_penalty: f32,
-        _repetition_penalty_last_n: usize,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
-
-        Self {
+    pub fn new(config: &GenerationConfig) -> Self {
+        let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
+
+        let mut text_gen = Self {
             logits_processor,
             tokens: Vec::new(),
             eos_token_id: None,
+            repetition_penalty: config.repetition_penalty,
+            repetition_penalty_last_n: config.repetition_penalty_last_n,
             constraint: None,
             constraint_state: None,
-        }
-    }
-
-    pub fn from_config(config: &GenerationConfig) -> Self {
-        let mut text_gen = Self::new(
-            config.seed,
-            Some(config.temperature),
-            config.top_p,
-            config.top_k,
-            config.repetition_penalty,
-            config.repetition_penalty_last_n,
-        );
+            constraint_completed: false,
+            tokens_since_constraint_start: 0,
+        };
 
         // Set constraint if provided
         if let Some(ref constraint) = config.constraint {
@@ -72,6 +62,8 @@ impl TextGeneration {
         // Initialize with the first state
         self.constraint_state = Some(constraint.initial_state());
         self.constraint = Some(constraint);
+        self.constraint_completed = false;
+        self.tokens_since_constraint_start = self.tokens.len();
     }
 
     /// Apply constraints to logits by masking disallowed tokens
@@ -137,13 +129,12 @@ impl TextGeneration {
     pub fn sample_next_token(
         &mut self,
         logits: &Tensor,
-        repetition_penalty: Option<(f32, usize)>,
     ) -> CandleResult<u32> {
         let mut logits = logits.clone();
 
-        // Apply repetition penalty
-        if let Some((penalty, last_n)) = repetition_penalty {
-            self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+        // Apply repetition penalty using stored parameters
+        if self.repetition_penalty != 1.0 {
+            self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
         }
 
         // Apply constraints if active
@@ -155,12 +146,114 @@ impl TextGeneration {
 
         // Update constraint state if active
         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
-            self.constraint_state = constraint_index.next_state(&current_state, &next_token);
+            // Get the next state
+            let next_state = constraint_index.next_state(&current_state, &next_token);
+
+            // Check if we're transitioning to a state with no allowed tokens (completion)
+            if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
+                // Check if we've transitioned from a constrained state to an unconstrained state
+                // This happens when the pattern is complete and the FSM allows "anything"
+
+                let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
+                    // Consider it constrained if we have a limited set of allowed tokens
+                    allowed.len() < 1000 // Arbitrary threshold for "constrained"
+                } else {
+                    true // No tokens allowed is definitely constrained
+                };
+
+                let next_constrained = if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        allowed.is_empty() || allowed.len() < 1000
+                    } else {
+                        true
+                    }
+                } else {
+                    true
+                };
+
+                // If we're transitioning from constrained to unconstrained, we've completed the pattern
+                if current_constrained && !next_constrained {
+                    self.constraint_completed = true;
+                }
+
+                // Also check if next state has no allowed tokens at all
+                if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        if allowed.is_empty() {
+                            self.constraint_completed = true;
+                        }
+                    } else {
+                        // None means no tokens allowed - constraint is complete
+                        self.constraint_completed = true;
+                    }
+                }
+            }
+
+            self.constraint_state = next_state;
         }
 
         Ok(next_token)
     }
 
+    /// Check if the constraint is satisfied (reached a valid completion state)
+    pub fn is_constraint_satisfied(&self) -> bool {
+        // If we've explicitly marked the constraint as completed, return true
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check the current state
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if the constraint has reached a state where it could validly end
+            // This happens when:
+            // 1. We have no more allowed tokens (constraint fully satisfied)
+            // 2. The EOS token is in the allowed tokens (optional ending)
+            if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                // If no tokens are allowed, the constraint is fully satisfied
+                if allowed.is_empty() {
+                    return true;
+                }
+
+                // If EOS token is allowed, we've reached an optional completion point
+                if let Some(eos) = self.eos_token_id {
+                    if allowed.contains(&eos) {
+                        return true;
+                    }
+                }
+            } else {
+                // None means no tokens allowed - constraint is satisfied
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Check if the constraint is satisfied when stop_on_match is true
+    pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
+        // When stop_on_match is true, we stop as soon as the constraint is completed
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check if we're currently in a state that could be a valid end
+        // This is important for patterns like phone numbers where after matching
+        // the pattern, the FSM might allow any token (including more numbers)
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if we've generated at least one token since constraint start
+            if self.tokens.len() > self.tokens_since_constraint_start {
+                if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                    // If the allowed tokens set is very large (unconstrained),
+                    // it means the pattern has been satisfied
+                    if allowed.len() > 1000 {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        false
+    }
+
     /// Check if we should stop generation
     pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
         if self.tokens.len() >= max_length {
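Both checks lean on one heuristic: a state whose allowed-token set is small (the arbitrary 1000-token threshold) is treated as still inside the pattern, and completion is a constrained-to-unconstrained transition or an empty allowed set. A self-contained toy of that rule, with illustrative token sets rather than real outlines-core states:

```rust
// Toy restatement of the completion heuristic; `None` stands for a state
// with no allowed tokens at all.
fn is_constrained(allowed: Option<&[u32]>) -> bool {
    match allowed {
        Some(tokens) => tokens.is_empty() || tokens.len() < 1000,
        None => true,
    }
}

// The pattern counts as complete on a constrained -> unconstrained
// transition, or when the next state allows nothing.
fn completes(current: Option<&[u32]>, next: Option<&[u32]>) -> bool {
    let next_empty = next.map_or(true, |t| t.is_empty());
    (is_constrained(current) && !is_constrained(next)) || next_empty
}

fn main() {
    let big: Vec<u32> = (0..50_000).collect();
    let mid_pattern = Some(&[17u32, 42][..]); // few continuations: constrained
    let after_match = Some(&big[..]);         // huge set: effectively unconstrained
    assert!(completes(mid_pattern, after_match)); // pattern just finished
    assert!(!completes(after_match, after_match));
    assert!(completes(mid_pattern, None));        // dead end: complete
}
```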
data/ext/candle/src/ner.rs
CHANGED
@@ -39,13 +39,9 @@ impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
 
-
-        let device_clone = device.clone();
-        let model_id_clone = model_id.clone();
-
-        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
-            let repo = api.repo(Repo::new(model_id_clone, RepoType::Model));
+            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
             // Download model files
             let config_filename = repo.get("config.json")?;
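This hunk (and the matching one in reranker.rs below) replaces a spawned loader thread with an immediately-invoked closure: `?` still propagates errors over a boxed error type, but failures surface as an `Err` value instead of a thread panic and join. A minimal standalone sketch of the pattern:

```rust
// The closure runs on the spot; `?` propagates any error into `result`
// instead of unwinding across a thread boundary.
fn load() -> Result<u32, String> {
    let result = (|| -> Result<u32, Box<dyn std::error::Error + Send + Sync>> {
        let parsed: u32 = "42".parse()?; // fallible steps read naturally
        Ok(parsed)
    })();

    match result {
        Ok(value) => Ok(value),
        Err(e) => Err(format!("Failed to load: {}", e)),
    }
}

fn main() {
    assert_eq!(load(), Ok(42));
}
```

This also removes the second failure mode the old code had to handle (a panicked thread), which is why both files drop their "Thread panicked while loading" error arms below.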
@@ -92,7 +88,7 @@ impl NER {
 
         // Load model weights
         let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
         };
 
         // Load BERT model
@@ -106,10 +102,10 @@ impl NER {
             )?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, classifier, config))) => {
+        match result {
+            Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
                     model,
                     tokenizer,
@@ -119,28 +115,20 @@ impl NER {
                     model_id,
                 })
             }
-            Ok(Err(e)) => Err(Error::new(
+            Err(e) => Err(Error::new(
                 magnus::exception::runtime_error(),
                 format!("Failed to load NER model: {}", e)
             )),
-            Err(_) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                "Thread panicked while loading NER model"
-            )),
         }
     }
 
-    /// Extract entities from text with confidence scores
-    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+    /// Common tokenization and prediction logic
+    fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
         // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+        let encoding = self.tokenizer.inner().encode(text, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         let token_ids = encoding.get_ids();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
 
         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@ impl NER {
             .to_vec2()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
 
+        Ok((encoding, probs_vec))
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
         // Extract entities with BIO decoding
         let entities = self.decode_entities(
             &text,
@@ -199,38 +200,11 @@ impl NER {
 
     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-        // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
 
-        let token_ids = encoding.get_ids();
         let tokens = encoding.get_tokens();
 
-        // Convert to tensors
-        let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Forward pass
-        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Get predictions
-        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
         // Build result array
         let result = RArray::new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
data/ext/candle/src/reranker.rs
CHANGED
@@ -4,7 +4,6 @@ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{EncodeInput, Tokenizer};
-use std::thread;
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
@@ -24,8 +23,7 @@ impl Reranker {
     }
 
     fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
-        let device_clone = device.clone();
-        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -44,7 +42,7 @@ impl Reranker {
 
         // Load model weights
         let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
         };
 
         // Load BERT model
@@ -57,17 +55,49 @@ impl Reranker {
         let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
         Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, pooler, classifier))) => {
+        match result {
+            Ok((model, tokenizer, pooler, classifier)) => {
                 Ok(Self { model, tokenizer, pooler, classifier, device })
             }
-            Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
-            Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
+            Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
         }
     }
 
+    /// Extract CLS embeddings from the model output, handling Metal device workarounds
+    fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
+        let cls_embeddings = if self.device.is_metal() {
+            // Metal has issues with tensor indexing, use a different approach
+            let (batch_size, seq_len, hidden_size) = embeddings.dims3()
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
+
+            // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
+            let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
+
+            // Extract CLS tokens (first token of each sequence)
+            let mut cls_vecs = Vec::new();
+            for i in 0..batch_size {
+                let start_idx = i * seq_len;
+                let cls_vec = reshaped.narrow(0, start_idx, 1)
+                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
+                cls_vecs.push(cls_vec);
+            }
+
+            // Stack the CLS vectors
+            Tensor::cat(&cls_vecs, 0)
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
+        } else {
+            embeddings.i((.., 0))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
+        };
+
+        // Ensure tensor is contiguous for downstream operations
+        cls_embeddings.contiguous()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
+    }
+
     pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
         // Create query-document pair for cross-encoder
         let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
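A runnable sketch (assuming a candle-core dependency) showing that the two paths inside `extract_cls_embeddings` agree: plain indexing via `i((.., 0))`, and the reshape/narrow/cat workaround used on Metal. Both produce a `[batch, hidden]` tensor of first-token embeddings:

```rust
use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let (batch, seq, hidden) = (2, 3, 4);
    let embeddings = Tensor::arange(0f32, (batch * seq * hidden) as f32, &Device::Cpu)?
        .reshape((batch, seq, hidden))?;

    // Direct path: first token of each sequence.
    let direct = embeddings.i((.., 0))?;

    // Metal workaround: flatten, then pull row i * seq for each batch element.
    let reshaped = embeddings.reshape((batch * seq, hidden))?;
    let rows: Vec<Tensor> = (0..batch)
        .map(|i| reshaped.narrow(0, i * seq, 1))
        .collect::<candle_core::Result<_>>()?;
    let workaround = Tensor::cat(&rows, 0)?;

    assert_eq!(direct.to_vec2::<f32>()?, workaround.to_vec2::<f32>()?);
    Ok(())
}
```

Factoring this into one helper removes the two near-identical inline copies that the next two hunks delete from the "pooler" and "cls" pooling branches.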
@@ -131,37 +161,7 @@ impl Reranker {
         let pooled_embeddings = match pooling_method.as_str() {
             "pooler" => {
                 // Extract [CLS] token and apply pooler (dense + tanh)
-
-                let cls_embeddings = if self.device.is_metal() {
-                    // Metal has issues with tensor indexing, use a different approach
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    // Extract CLS tokens (first token of each sequence)
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    // Stack the CLS vectors
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure tensor is contiguous before linear layer
-                let cls_embeddings = cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
+                let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
                 let pooled = self.pooler.forward(&cls_embeddings)
                     .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
                 pooled.tanh()
@@ -169,34 +169,7 @@ impl Reranker {
             },
             "cls" => {
                 // Just use the [CLS] token embeddings directly (no pooler layer)
-
-                let cls_embeddings = if self.device.is_metal() {
-                    // Use same approach as pooler method
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure contiguous for classifier
-                cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
+                self.extract_cls_embeddings(&embeddings)?
             },
             "mean" => {
                 // Mean pooling across all tokens