red-candle 1.1.0 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -1
- data/ext/candle/src/llm/constrained_generation_test.rs +199 -6
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +6 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/phi.rs +21 -5
- data/ext/candle/src/llm/quantized_gguf.rs +18 -5
- data/ext/candle/src/llm/qwen.rs +21 -5
- data/ext/candle/src/llm/text_generation.rs +121 -28
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/llm.rs +81 -55
- data/lib/candle/llm.rb +129 -34
- data/lib/candle/version.rb +1 -1
- metadata +2 -2
data/ext/candle/src/llm/text_generation.rs
CHANGED
@@ -10,39 +10,29 @@ pub struct TextGeneration {
     logits_processor: LogitsProcessor,
     tokens: Vec<u32>,
     eos_token_id: Option<u32>,
+    repetition_penalty: f32,
+    repetition_penalty_last_n: usize,
     constraint: Option<Arc<Index>>,
     constraint_state: Option<u32>,
+    constraint_completed: bool,
+    tokens_since_constraint_start: usize,
 }
 
 impl TextGeneration {
-    pub fn new(
-        seed: u64,
-        temperature: Option<f64>,
-        top_p: Option<f64>,
-        _top_k: Option<usize>,
-        _repetition_penalty: f32,
-        _repetition_penalty_last_n: usize,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
-
-        Self {
+    pub fn new(config: &GenerationConfig) -> Self {
+        let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
+
+        let mut text_gen = Self {
             logits_processor,
             tokens: Vec::new(),
             eos_token_id: None,
+            repetition_penalty: config.repetition_penalty,
+            repetition_penalty_last_n: config.repetition_penalty_last_n,
             constraint: None,
             constraint_state: None,
-        }
-    }
-
-    pub fn from_config(config: &GenerationConfig) -> Self {
-        let mut text_gen = Self::new(
-            config.seed,
-            Some(config.temperature),
-            config.top_p,
-            config.top_k,
-            config.repetition_penalty,
-            config.repetition_penalty_last_n,
-        );
+            constraint_completed: false,
+            tokens_since_constraint_start: 0,
+        };
 
         // Set constraint if provided
         if let Some(ref constraint) = config.constraint {
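For orientation: TextGeneration::new now takes the whole config, replacing both the old positional constructor and from_config. A sketch of the GenerationConfig fields this diff relies on (the real definition lives in generation_config.rs, which also changed in this release; the exact types here are inferred from usage, not confirmed):

// Inferred shape of GenerationConfig; field types are assumptions based on
// how the fields are used elsewhere in this diff.
pub struct GenerationConfig {
    pub max_length: usize,
    pub temperature: f64,
    pub top_p: Option<f64>,
    pub top_k: Option<usize>,
    pub repetition_penalty: f32,
    pub repetition_penalty_last_n: usize,
    pub seed: u64,
    pub stop_sequences: Vec<String>,
    pub include_prompt: bool,
    pub debug_tokens: bool,
    pub stop_on_constraint_satisfaction: bool,
    pub stop_on_match: bool,
    pub constraint: Option<Arc<Index>>, // outlines-core FSM index
}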
@@ -72,6 +62,8 @@ impl TextGeneration {
         // Initialize with the first state
         self.constraint_state = Some(constraint.initial_state());
         self.constraint = Some(constraint);
+        self.constraint_completed = false;
+        self.tokens_since_constraint_start = self.tokens.len();
     }
 
     /// Apply constraints to logits by masking disallowed tokens
@@ -137,13 +129,12 @@ impl TextGeneration {
     pub fn sample_next_token(
         &mut self,
         logits: &Tensor,
-        repetition_penalty: Option<(f32, usize)>,
     ) -> CandleResult<u32> {
         let mut logits = logits.clone();
 
-        // Apply repetition penalty
-        if let Some((penalty, last_n)) = repetition_penalty {
-            self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+        // Apply repetition penalty using stored parameters
+        if self.repetition_penalty != 1.0 {
+            self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
         }
 
         // Apply constraints if active
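The penalty parameters now live on the struct instead of being threaded through every call, and a penalty of exactly 1.0 (a no-op) skips the pass entirely. apply_repetition_penalty itself is not shown in this diff; for reference, a CTRL-style penalty over the last n generated tokens typically looks like the following sketch (plain slices instead of candle tensors; the signature is hypothetical):

// Hypothetical sketch of a CTRL-style repetition penalty, operating on a
// plain logits slice rather than a candle Tensor. Tokens seen in the last
// `last_n` positions get their logits pushed toward "less likely".
fn apply_repetition_penalty(logits: &mut [f32], tokens: &[u32], penalty: f32, last_n: usize) {
    let start = tokens.len().saturating_sub(last_n);
    for &tok in &tokens[start..] {
        let idx = tok as usize;
        if idx < logits.len() {
            if logits[idx] > 0.0 {
                logits[idx] /= penalty; // shrink positive logits
            } else {
                logits[idx] *= penalty; // push negative logits further down
            }
        }
    }
}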
@@ -155,12 +146,114 @@ impl TextGeneration {
 
         // Update constraint state if active
         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
-            self.constraint_state = constraint_index.next_state(&current_state, &next_token);
+            // Get the next state
+            let next_state = constraint_index.next_state(&current_state, &next_token);
+
+            // Check if we're transitioning to a state with no allowed tokens (completion)
+            if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
+                // Check if we've transitioned from a constrained state to an unconstrained state
+                // This happens when the pattern is complete and the FSM allows "anything"
+
+                let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
+                    // Consider it constrained if we have a limited set of allowed tokens
+                    allowed.len() < 1000 // Arbitrary threshold for "constrained"
+                } else {
+                    true // No tokens allowed is definitely constrained
+                };
+
+                let next_constrained = if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        allowed.is_empty() || allowed.len() < 1000
+                    } else {
+                        true
+                    }
+                } else {
+                    true
+                };
+
+                // If we're transitioning from constrained to unconstrained, we've completed the pattern
+                if current_constrained && !next_constrained {
+                    self.constraint_completed = true;
+                }
+
+                // Also check if next state has no allowed tokens at all
+                if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        if allowed.is_empty() {
+                            self.constraint_completed = true;
+                        }
+                    } else {
+                        // None means no tokens allowed - constraint is complete
+                        self.constraint_completed = true;
+                    }
+                }
+            }
+
+            self.constraint_state = next_state;
         }
 
         Ok(next_token)
     }
 
+    /// Check if the constraint is satisfied (reached a valid completion state)
+    pub fn is_constraint_satisfied(&self) -> bool {
+        // If we've explicitly marked the constraint as completed, return true
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check the current state
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if the constraint has reached a state where it could validly end
+            // This happens when:
+            // 1. We have no more allowed tokens (constraint fully satisfied)
+            // 2. The EOS token is in the allowed tokens (optional ending)
+            if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                // If no tokens are allowed, the constraint is fully satisfied
+                if allowed.is_empty() {
+                    return true;
+                }
+
+                // If EOS token is allowed, we've reached an optional completion point
+                if let Some(eos) = self.eos_token_id {
+                    if allowed.contains(&eos) {
+                        return true;
+                    }
+                }
+            } else {
+                // None means no tokens allowed - constraint is satisfied
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Check if the constraint is satisfied when stop_on_match is true
+    pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
+        // When stop_on_match is true, we stop as soon as the constraint is completed
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check if we're currently in a state that could be a valid end
+        // This is important for patterns like phone numbers where after matching
+        // the pattern, the FSM might allow any token (including more numbers)
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if we've generated at least one token since constraint start
+            if self.tokens.len() > self.tokens_since_constraint_start {
+                if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                    // If the allowed tokens set is very large (unconstrained),
+                    // it means the pattern has been satisfied
+                    if allowed.len() > 1000 {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        false
+    }
+
     /// Check if we should stop generation
     pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
         if self.tokens.len() >= max_length {
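The completion heuristic above treats a state as "constrained" when its allowed-token set is small and as "unconstrained" when it is large, and marks the pattern complete on a constrained-to-unconstrained transition. A minimal standalone restatement of that decision, with plain functions in place of the outlines-core Index (the threshold of 1000 is the same arbitrary cutoff the diff uses):

// Standalone restatement of the diff's heuristic. `allowed` mirrors
// Index::allowed_tokens: None means no token is permitted in this state.
fn is_constrained(allowed: Option<&[u32]>) -> bool {
    match allowed {
        Some(a) => a.is_empty() || a.len() < 1000, // small set: still inside the pattern
        None => true,                              // nothing allowed: terminal/constrained
    }
}

// The pattern counts as complete when we step from a constrained state
// into an unconstrained one (the FSM starts allowing "anything").
fn pattern_completed(current: Option<&[u32]>, next: Option<&[u32]>) -> bool {
    is_constrained(current) && !is_constrained(next)
}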
data/ext/candle/src/ner.rs
CHANGED
@@ -39,13 +39,9 @@ impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
 
-
-        let device_clone = device.clone();
-        let model_id_clone = model_id.clone();
-
-        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
-            let repo = api.repo(Repo::new(model_id_clone, RepoType::Model));
+            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
             // Download model files
             let config_filename = repo.get("config.json")?;
@@ -92,7 +88,7 @@ impl NER {
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -106,10 +102,10 @@ impl NER {
             )?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, classifier, config))) => {
+        match result {
+            Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
                     model,
                     tokenizer,
@@ -119,28 +115,20 @@ impl NER {
                     model_id,
                 })
             }
-            Ok(Err(e)) => Err(Error::new(
+            Err(e) => Err(Error::new(
                 magnus::exception::runtime_error(),
                 format!("Failed to load NER model: {}", e)
             )),
-            Err(_) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                "Thread panicked while loading NER model"
-            )),
         }
     }
 
-    /// Extract entities from text with confidence scores
-    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+    /// Common tokenization and prediction logic
+    fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
         // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+        let encoding = self.tokenizer.inner().encode(text, true)
            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         let token_ids = encoding.get_ids();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
 
         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@ impl NER {
             .to_vec2()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
 
+        Ok((encoding, probs_vec))
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
         // Extract entities with BIO decoding
         let entities = self.decode_entities(
             &text,
@@ -199,38 +200,11 @@ impl NER {
 
     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-        // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
 
-        let token_ids = encoding.get_ids();
         let tokens = encoding.get_tokens();
 
-        // Convert to tensors
-        let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Forward pass
-        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Get predictions
-        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
         // Build result array
         let result = RArray::new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
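Both loaders in this release drop the std::thread::spawn / handle.join() dance in favor of an immediately-invoked closure, which keeps ?-style error handling while removing the clones and the "thread panicked" error arm. The pattern in isolation (a minimal sketch, unrelated to the actual model-loading code):

// Minimal sketch of the immediately-invoked-closure pattern adopted above:
// the closure provides a scope where `?` works, and is called on the spot.
fn parse_port(s: &str) -> Result<u16, String> {
    let result = (|| -> Result<u16, Box<dyn std::error::Error>> {
        let port: u16 = s.trim().parse()?; // `?` propagates to the closure's return
        Ok(port)
    })();

    result.map_err(|e| format!("Failed to parse port: {}", e))
}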
data/ext/candle/src/reranker.rs
CHANGED
@@ -4,7 +4,6 @@ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{EncodeInput, Tokenizer};
-use std::thread;
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
@@ -24,8 +23,7 @@ impl Reranker {
     }
 
     fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
-        let device_clone = device.clone();
-        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -44,7 +42,7 @@ impl Reranker {
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -57,17 +55,49 @@ impl Reranker {
         let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
         Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, pooler, classifier))) => {
+        match result {
+            Ok((model, tokenizer, pooler, classifier)) => {
                 Ok(Self { model, tokenizer, pooler, classifier, device })
             }
-            Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
-            Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
+            Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
         }
     }
 
+    /// Extract CLS embeddings from the model output, handling Metal device workarounds
+    fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
+        let cls_embeddings = if self.device.is_metal() {
+            // Metal has issues with tensor indexing, use a different approach
+            let (batch_size, seq_len, hidden_size) = embeddings.dims3()
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
+
+            // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
+            let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
+
+            // Extract CLS tokens (first token of each sequence)
+            let mut cls_vecs = Vec::new();
+            for i in 0..batch_size {
+                let start_idx = i * seq_len;
+                let cls_vec = reshaped.narrow(0, start_idx, 1)
+                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
+                cls_vecs.push(cls_vec);
+            }
+
+            // Stack the CLS vectors
+            Tensor::cat(&cls_vecs, 0)
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
+        } else {
+            embeddings.i((.., 0))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
+        };
+
+        // Ensure tensor is contiguous for downstream operations
+        cls_embeddings.contiguous()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
+    }
+
     pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
         // Create query-document pair for cross-encoder
         let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
@@ -131,37 +161,7 @@ impl Reranker {
         let pooled_embeddings = match pooling_method.as_str() {
             "pooler" => {
                 // Extract [CLS] token and apply pooler (dense + tanh)
-
-                let cls_embeddings = if self.device.is_metal() {
-                    // Metal has issues with tensor indexing, use a different approach
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    // Extract CLS tokens (first token of each sequence)
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    // Stack the CLS vectors
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure tensor is contiguous before linear layer
-                let cls_embeddings = cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
+                let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
                 let pooled = self.pooler.forward(&cls_embeddings)
                     .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
                 pooled.tanh()
@@ -169,34 +169,7 @@ impl Reranker {
             },
             "cls" => {
                 // Just use the [CLS] token embeddings directly (no pooler layer)
-
-                let cls_embeddings = if self.device.is_metal() {
-                    // Use same approach as pooler method
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure contiguous for classifier
-                cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
+                self.extract_cls_embeddings(&embeddings)?
             },
             "mean" => {
                 // Mean pooling across all tokens
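The two inlined copies of the Metal workaround collapse into the single extract_cls_embeddings helper above. What it computes, restated on plain nested Vecs (a sketch, not the candle API): from a [batch, seq_len, hidden] tensor, keep position 0 (the [CLS] token) of every sequence.

// Sketch of what extract_cls_embeddings produces, using nested Vecs in
// place of a candle Tensor: embeddings[b][t] is the hidden vector for
// token t of batch item b; we keep t == 0 for each b.
fn cls_rows(embeddings: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    embeddings.iter().map(|seq| seq[0].clone()).collect()
}

On non-Metal devices the same result comes from a single embeddings.i((.., 0)) index; the reshape/narrow/cat loop exists only because, per the code comments, Metal has issues with that tensor indexing.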
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -83,6 +83,26 @@ impl ModelType {
     }
 }
 
+// Macro to extract parameters from Ruby hash to reduce boilerplate
+macro_rules! extract_param {
+    // Basic parameter extraction
+    ($kwargs:expr, $config:expr, $param:ident) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = v;
+            }
+        }
+    };
+    // Optional parameter extraction (wraps in Some)
+    ($kwargs:expr, $config:expr, $param:ident, optional) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = Some(v);
+            }
+        }
+    };
+}
+
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
 pub struct GenerationConfig {
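Each extract_param! call expands to the same lookup/convert/assign block that was previously written out by hand, with stringify!($param) supplying the Ruby symbol name. A toy, self-contained demo of the same macro shape, with a HashMap standing in for magnus::RHash (all names here are hypothetical):

use std::collections::HashMap;

#[derive(Default, Debug)]
struct Config {
    temperature: f64,
    top_p: Option<f64>,
}

// Same two-arm shape as the diff's extract_param!, minus the magnus types.
macro_rules! extract_param {
    ($kwargs:expr, $config:expr, $param:ident) => {
        if let Some(v) = $kwargs.get(stringify!($param)) {
            $config.$param = *v;
        }
    };
    ($kwargs:expr, $config:expr, $param:ident, optional) => {
        if let Some(v) = $kwargs.get(stringify!($param)) {
            $config.$param = Some(*v);
        }
    };
}

fn main() {
    let mut kwargs: HashMap<&str, f64> = HashMap::new();
    kwargs.insert("temperature", 0.7);
    kwargs.insert("top_p", 0.9);

    let mut config = Config::default();
    extract_param!(kwargs, config, temperature);
    extract_param!(kwargs, config, top_p, optional);
    // Prints: Config { temperature: 0.7, top_p: Some(0.9) }
    println!("{:?}", config);
}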
@@ -93,55 +113,20 @@ impl GenerationConfig {
     pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();
 
-        // Extract parameters
-        if let Some(value) = kwargs.get(magnus::Symbol::new("max_length")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.max_length = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("temperature")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.temperature = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_p = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_k = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty_last_n = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.seed = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.include_prompt = v;
-            }
-        }
+        // Extract basic parameters using macro
+        extract_param!(kwargs, config, max_length);
+        extract_param!(kwargs, config, temperature);
+        extract_param!(kwargs, config, top_p, optional);
+        extract_param!(kwargs, config, top_k, optional);
+        extract_param!(kwargs, config, repetition_penalty);
+        extract_param!(kwargs, config, repetition_penalty_last_n);
+        extract_param!(kwargs, config, seed);
+        extract_param!(kwargs, config, include_prompt);
+        extract_param!(kwargs, config, debug_tokens);
+        extract_param!(kwargs, config, stop_on_constraint_satisfaction);
+        extract_param!(kwargs, config, stop_on_match);
 
+        // Handle special cases that need custom logic
         if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
             if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
                 config.stop_sequences = arr
|
@@ -151,13 +136,6 @@ impl GenerationConfig {
|
|
151
136
|
}
|
152
137
|
}
|
153
138
|
|
154
|
-
if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
|
155
|
-
if let Ok(v) = TryConvert::try_convert(value) {
|
156
|
-
config.debug_tokens = v;
|
157
|
-
}
|
158
|
-
}
|
159
|
-
|
160
|
-
// Handle constraint parameter
|
161
139
|
if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
|
162
140
|
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
163
141
|
config.constraint = Some(Arc::clone(&constraint.index));
|
@@ -209,6 +187,15 @@ impl GenerationConfig {
     pub fn debug_tokens(&self) -> bool {
         self.inner.debug_tokens
     }
+
+    pub fn stop_on_constraint_satisfaction(&self) -> bool {
+        self.inner.stop_on_constraint_satisfaction
+    }
+
+    pub fn stop_on_match(&self) -> bool {
+        self.inner.stop_on_match
+    }
+
     pub fn constraint(&self) -> Option<StructuredConstraint> {
         self.inner.constraint.as_ref().map(|c| StructuredConstraint {
             index: Arc::clone(c),
@@ -372,6 +359,42 @@ impl LLM {
             ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
         }
     }
+
+    /// Get the EOS token string for this model
+    pub fn eos_token(&self) -> Result<String> {
+        let (eos_token_id, tokenizer_clone) = {
+            let model = match self.model.lock() {
+                Ok(guard) => guard,
+                Err(poisoned) => poisoned.into_inner(),
+            };
+            let model_ref = model.borrow();
+
+            // Get both EOS token ID and tokenizer clone in one lock scope
+            let eos_id = match &*model_ref {
+                ModelType::Mistral(m) => m.eos_token_id(),
+                ModelType::Llama(m) => m.eos_token_id(),
+                ModelType::Gemma(m) => m.eos_token_id(),
+                ModelType::Qwen(m) => m.eos_token_id(),
+                ModelType::Phi(m) => m.eos_token_id(),
+                ModelType::QuantizedGGUF(m) => m.eos_token_id(),
+            };
+
+            let tokenizer = match &*model_ref {
+                ModelType::Mistral(m) => m.tokenizer().clone(),
+                ModelType::Llama(m) => m.tokenizer().clone(),
+                ModelType::Gemma(m) => m.tokenizer().clone(),
+                ModelType::Qwen(m) => m.tokenizer().clone(),
+                ModelType::Phi(m) => m.tokenizer().clone(),
+                ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
+            };
+
+            (eos_id, tokenizer)
+        }; // Lock is released here
+
+        // Convert ID to string using the tokenizer
+        let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
+        tokenizer_wrapper.id_to_token(eos_token_id as i64)
+    }
 
     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> Result<()> {
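eos_token recovers from a poisoned lock rather than propagating the panic to Ruby: if another thread panicked while holding the model mutex, into_inner() on the PoisonError still hands back the guard. The pattern in isolation (a minimal sketch with a plain Mutex<u32>):

use std::sync::Mutex;

// A poisoned Mutex still yields its data: into_inner() on the PoisonError
// returns the guard, so a reader can proceed instead of bubbling the panic.
fn read_value(m: &Mutex<u32>) -> u32 {
    let guard = match m.lock() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard
}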
@@ -460,6 +483,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
     rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
+    rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
+    rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
     rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
@@ -469,6 +494,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
     rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
+    rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;