red-candle 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,39 +10,29 @@ pub struct TextGeneration {
     logits_processor: LogitsProcessor,
     tokens: Vec<u32>,
     eos_token_id: Option<u32>,
+    repetition_penalty: f32,
+    repetition_penalty_last_n: usize,
     constraint: Option<Arc<Index>>,
     constraint_state: Option<u32>,
+    constraint_completed: bool,
+    tokens_since_constraint_start: usize,
 }
 
 impl TextGeneration {
-    pub fn new(
-        seed: u64,
-        temperature: Option<f64>,
-        top_p: Option<f64>,
-        _top_k: Option<usize>,
-        _repetition_penalty: f32,
-        _repetition_penalty_last_n: usize,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
-
-        Self {
+    pub fn new(config: &GenerationConfig) -> Self {
+        let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
+
+        let mut text_gen = Self {
             logits_processor,
             tokens: Vec::new(),
             eos_token_id: None,
+            repetition_penalty: config.repetition_penalty,
+            repetition_penalty_last_n: config.repetition_penalty_last_n,
             constraint: None,
             constraint_state: None,
-        }
-    }
-
-    pub fn from_config(config: &GenerationConfig) -> Self {
-        let mut text_gen = Self::new(
-            config.seed,
-            Some(config.temperature),
-            config.top_p,
-            config.top_k,
-            config.repetition_penalty,
-            config.repetition_penalty_last_n,
-        );
+            constraint_completed: false,
+            tokens_since_constraint_start: 0,
+        };
 
         // Set constraint if provided
         if let Some(ref constraint) = config.constraint {
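Note: the 1.1.0 constructor accepted `_top_k`, `_repetition_penalty`, and `_repetition_penalty_last_n` but never used them; 1.1.1 folds construction into a single config-driven path and actually stores the penalty settings on the struct. A minimal sketch of the call-site change, assuming `GenerationConfig` exposes the fields referenced in the hunk and a `Default` impl:

```rust
// Sketch only; field names come from the diff, the Default impl is an assumption.
let config = GenerationConfig {
    seed: 42,
    temperature: 0.8,
    top_p: Some(0.9),
    repetition_penalty: 1.1,
    repetition_penalty_last_n: 64,
    ..Default::default()
};

// 1.1.0: TextGeneration::new(42, Some(0.8), Some(0.9), None, 1.1, 64)
//        (penalty arguments silently ignored), or from_config(&config).
// 1.1.1: one entry point, and the penalty fields are kept for sampling time.
let text_gen = TextGeneration::new(&config);
```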
@@ -72,6 +62,8 @@ impl TextGeneration {
         // Initialize with the first state
         self.constraint_state = Some(constraint.initial_state());
         self.constraint = Some(constraint);
+        self.constraint_completed = false;
+        self.tokens_since_constraint_start = self.tokens.len();
     }
 
     /// Apply constraints to logits by masking disallowed tokens
@@ -137,13 +129,12 @@ impl TextGeneration {
     pub fn sample_next_token(
         &mut self,
         logits: &Tensor,
-        repetition_penalty: Option<(f32, usize)>,
     ) -> CandleResult<u32> {
         let mut logits = logits.clone();
 
-        // Apply repetition penalty if specified
-        if let Some((penalty, last_n)) = repetition_penalty {
-            self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+        // Apply repetition penalty using stored parameters
+        if self.repetition_penalty != 1.0 {
+            self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
         }
 
         // Apply constraints if active
@@ -155,12 +146,114 @@
 
         // Update constraint state if active
         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
-            self.constraint_state = constraint_index.next_state(&current_state, &next_token);
+            // Get the next state
+            let next_state = constraint_index.next_state(&current_state, &next_token);
+
+            // Check if we're transitioning to a state with no allowed tokens (completion)
+            if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
+                // Check if we've transitioned from a constrained state to an unconstrained state
+                // This happens when the pattern is complete and the FSM allows "anything"
+
+                let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
+                    // Consider it constrained if we have a limited set of allowed tokens
+                    allowed.len() < 1000 // Arbitrary threshold for "constrained"
+                } else {
+                    true // No tokens allowed is definitely constrained
+                };
+
+                let next_constrained = if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        allowed.is_empty() || allowed.len() < 1000
+                    } else {
+                        true
+                    }
+                } else {
+                    true
+                };
+
+                // If we're transitioning from constrained to unconstrained, we've completed the pattern
+                if current_constrained && !next_constrained {
+                    self.constraint_completed = true;
+                }
+
+                // Also check if next state has no allowed tokens at all
+                if let Some(next_state_val) = next_state {
+                    if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                        if allowed.is_empty() {
+                            self.constraint_completed = true;
+                        }
+                    } else {
+                        // None means no tokens allowed - constraint is complete
+                        self.constraint_completed = true;
+                    }
+                }
+            }
+
+            self.constraint_state = next_state;
         }
 
         Ok(next_token)
     }
 
+    /// Check if the constraint is satisfied (reached a valid completion state)
+    pub fn is_constraint_satisfied(&self) -> bool {
+        // If we've explicitly marked the constraint as completed, return true
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check the current state
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if the constraint has reached a state where it could validly end
+            // This happens when:
+            // 1. We have no more allowed tokens (constraint fully satisfied)
+            // 2. The EOS token is in the allowed tokens (optional ending)
+            if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                // If no tokens are allowed, the constraint is fully satisfied
+                if allowed.is_empty() {
+                    return true;
+                }
+
+                // If EOS token is allowed, we've reached an optional completion point
+                if let Some(eos) = self.eos_token_id {
+                    if allowed.contains(&eos) {
+                        return true;
+                    }
+                }
+            } else {
+                // None means no tokens allowed - constraint is satisfied
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Check if the constraint is satisfied when stop_on_match is true
+    pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
+        // When stop_on_match is true, we stop as soon as the constraint is completed
+        if self.constraint_completed {
+            return true;
+        }
+
+        // Also check if we're currently in a state that could be a valid end
+        // This is important for patterns like phone numbers where after matching
+        // the pattern, the FSM might allow any token (including more numbers)
+        if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+            // Check if we've generated at least one token since constraint start
+            if self.tokens.len() > self.tokens_since_constraint_start {
+                if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                    // If the allowed tokens set is very large (unconstrained),
+                    // it means the pattern has been satisfied
+                    if allowed.len() > 1000 {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        false
+    }
+
     /// Check if we should stop generation
     pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
         if self.tokens.len() >= max_length {
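The two new predicates differ in aggressiveness: `is_constraint_satisfied` treats an empty or EOS-containing allowed set as a valid stopping point, while the `stop_on_match` variant also stops once the FSM becomes effectively unconstrained (the `allowed.len() > 1000` heuristic), which matters for open-ended patterns like phone numbers. A rough sketch of how a generation driver might consume them; `forward_pass`, `max_len`, and the two `stop_*` flags are assumptions standing in for the caller's context, not part of the diff:

```rust
// Hypothetical driver loop (not from the diff).
loop {
    let logits = forward_pass(&text_gen)?;            // produce next-token logits
    let token = text_gen.sample_next_token(&logits)?; // penalty now comes from self
    if text_gen.should_stop(token, max_len) {
        break; // EOS or max length
    }
    if stop_on_match && text_gen.is_constraint_satisfied_stop_on_match() {
        break; // pattern matched; don't let the FSM drift into "anything goes"
    }
    if stop_on_constraint_satisfaction && text_gen.is_constraint_satisfied() {
        break; // constraint reached a valid completion state
    }
}
```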
@@ -39,13 +39,9 @@ impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
 
-        // Load model in a separate thread to avoid blocking
-        let device_clone = device.clone();
-        let model_id_clone = model_id.clone();
-
-        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
-            let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
+            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
             // Download model files
             let config_filename = repo.get("config.json")?;
@@ -92,7 +88,7 @@ impl NER {
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -106,10 +102,10 @@ impl NER {
             )?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, classifier, config))) => {
+        match result {
+            Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
                     model,
                     tokenizer,
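The `thread::spawn(...)` → `(|| ... )()` change here (and in `Reranker::new_with_core_device` below) swaps a throwaway loader thread for an immediately-invoked closure: `?`-style early returns stay scoped to the loading block, the `device`/`model_id` clones become unnecessary, and the `Err(_) => "Thread panicked"` arm disappears because there is no join to fail. The pattern in isolation, as a self-contained sketch:

```rust
use std::fs;

// An immediately-invoked closure gives `?` a local scope without a thread.
fn load_config() -> Result<String, String> {
    let result = (|| -> Result<String, std::io::Error> {
        let raw = fs::read_to_string("config.json")?; // `?` exits the closure, not the fn
        Ok(raw.trim().to_string())
    })();

    // One match, no Ok(Ok(..)) / Ok(Err(..)) / Err(join-panic) triple.
    result.map_err(|e| format!("Failed to load config: {}", e))
}
```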
@@ -119,28 +115,20 @@ impl NER {
                     model_id,
                 })
             }
-            Ok(Err(e)) => Err(Error::new(
+            Err(e) => Err(Error::new(
                 magnus::exception::runtime_error(),
                 format!("Failed to load NER model: {}", e)
             )),
-            Err(_) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                "Thread panicked while loading NER model"
-            )),
         }
     }
 
-    /// Extract entities from text with confidence scores
-    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+    /// Common tokenization and prediction logic
+    fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
         // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+        let encoding = self.tokenizer.inner().encode(text, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         let token_ids = encoding.get_ids();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
 
         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@ impl NER {
             .to_vec2()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
 
+        Ok((encoding, probs_vec))
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
         // Extract entities with BIO decoding
         let entities = self.decode_entities(
             &text,
@@ -199,38 +200,11 @@ impl NER {
 
     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-        // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
 
-        let token_ids = encoding.get_ids();
         let tokens = encoding.get_tokens();
 
-        // Convert to tensors
-        let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Forward pass
-        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Get predictions
-        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
         // Build result array
         let result = RArray::new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
@@ -4,7 +4,6 @@ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{EncodeInput, Tokenizer};
-use std::thread;
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
@@ -24,8 +23,7 @@ impl Reranker {
     }
 
     fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
-        let device_clone = device.clone();
-        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -44,7 +42,7 @@ impl Reranker {
 
             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
             // Load BERT model
@@ -57,17 +55,49 @@ impl Reranker {
             let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
             Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
-        });
+        })();
 
-        match handle.join() {
-            Ok(Ok((model, tokenizer, pooler, classifier))) => {
+        match result {
+            Ok((model, tokenizer, pooler, classifier)) => {
                 Ok(Self { model, tokenizer, pooler, classifier, device })
             }
-            Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
-            Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
+            Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
         }
     }
 
+    /// Extract CLS embeddings from the model output, handling Metal device workarounds
+    fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
+        let cls_embeddings = if self.device.is_metal() {
+            // Metal has issues with tensor indexing, use a different approach
+            let (batch_size, seq_len, hidden_size) = embeddings.dims3()
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
+
+            // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
+            let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
+
+            // Extract CLS tokens (first token of each sequence)
+            let mut cls_vecs = Vec::new();
+            for i in 0..batch_size {
+                let start_idx = i * seq_len;
+                let cls_vec = reshaped.narrow(0, start_idx, 1)
+                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
+                cls_vecs.push(cls_vec);
+            }
+
+            // Stack the CLS vectors
+            Tensor::cat(&cls_vecs, 0)
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
+        } else {
+            embeddings.i((.., 0))
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
+        };
+
+        // Ensure tensor is contiguous for downstream operations
+        cls_embeddings.contiguous()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))
+    }
+
     pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
         // Create query-document pair for cross-encoder
         let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
@@ -131,37 +161,7 @@ impl Reranker {
         let pooled_embeddings = match pooling_method.as_str() {
             "pooler" => {
                 // Extract [CLS] token and apply pooler (dense + tanh)
-                // Work around Metal indexing issue by using narrow instead of i((.., 0))
-                let cls_embeddings = if self.device.is_metal() {
-                    // Metal has issues with tensor indexing, use a different approach
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    // Extract CLS tokens (first token of each sequence)
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    // Stack the CLS vectors
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure tensor is contiguous before linear layer
-                let cls_embeddings = cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
+                let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
                 let pooled = self.pooler.forward(&cls_embeddings)
                     .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
                 pooled.tanh()
@@ -169,34 +169,7 @@ impl Reranker {
             },
             "cls" => {
                 // Just use the [CLS] token embeddings directly (no pooler layer)
-                // Work around Metal indexing issue
-                let cls_embeddings = if self.device.is_metal() {
-                    // Use same approach as pooler method
-                    let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
-
-                    let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
-
-                    let mut cls_vecs = Vec::new();
-                    for i in 0..batch_size {
-                        let start_idx = i * _seq_len;
-                        let cls_vec = reshaped.narrow(0, start_idx, 1)
-                            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
-                        cls_vecs.push(cls_vec);
-                    }
-
-                    Tensor::cat(&cls_vecs, 0)
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
-                        .contiguous()
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
-                } else {
-                    embeddings.i((.., 0))
-                        .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
-                };
-                // Ensure contiguous for classifier
-                cls_embeddings.contiguous()
-                    .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
+                self.extract_cls_embeddings(&embeddings)?
             },
             "mean" => {
                 // Mean pooling across all tokens
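The `pooler` and `cls` branches previously carried two verbatim copies of the Metal workaround; both now call `extract_cls_embeddings`. The workaround itself is just a reshape/narrow re-expression of `i((.., 0))`, which can be checked on CPU with small tensors (a sketch against the candle API; the dimensions are arbitrary):

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // [batch=2, seq=3, hidden=4]
    let emb = Tensor::arange(0f32, 24.0, &dev)?.reshape((2, 3, 4))?;

    // Direct route: first (CLS) token of every sequence -> [2, 4]
    let direct = emb.i((.., 0))?.contiguous()?;

    // Metal-workaround route: flatten to [batch*seq, hidden], take row i*seq
    let flat = emb.reshape((2 * 3, 4))?;
    let rows = (0..2)
        .map(|i| flat.narrow(0, i * 3, 1))
        .collect::<Result<Vec<_>>>()?;
    let workaround = Tensor::cat(&rows, 0)?.contiguous()?;

    assert_eq!(direct.to_vec2::<f32>()?, workaround.to_vec2::<f32>()?);
    Ok(())
}
```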
@@ -83,6 +83,26 @@ impl ModelType {
     }
 }
 
+// Macro to extract parameters from Ruby hash to reduce boilerplate
+macro_rules! extract_param {
+    // Basic parameter extraction
+    ($kwargs:expr, $config:expr, $param:ident) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = v;
+            }
+        }
+    };
+    // Optional parameter extraction (wraps in Some)
+    ($kwargs:expr, $config:expr, $param:ident, optional) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = Some(v);
+            }
+        }
+    };
+}
+
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
 pub struct GenerationConfig {
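`extract_param!` is straight boilerplate compression for the hand-written blocks it replaces below: the hash key is derived from the field identifier via `stringify!`, and the `optional` variant wraps the converted value in `Some`. For example, `extract_param!(kwargs, config, top_p, optional)` expands to essentially the 1.1.0 code:

```rust
if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
    if let Ok(v) = TryConvert::try_convert(value) {
        config.top_p = Some(v);
    }
}
```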
@@ -93,55 +113,20 @@ impl GenerationConfig {
     pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();
 
-        // Extract values from kwargs manually
-        if let Some(value) = kwargs.get(magnus::Symbol::new("max_length")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.max_length = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("temperature")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.temperature = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_p = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_k = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty_last_n = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.seed = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.include_prompt = v;
-            }
-        }
+        // Extract basic parameters using macro
+        extract_param!(kwargs, config, max_length);
+        extract_param!(kwargs, config, temperature);
+        extract_param!(kwargs, config, top_p, optional);
+        extract_param!(kwargs, config, top_k, optional);
+        extract_param!(kwargs, config, repetition_penalty);
+        extract_param!(kwargs, config, repetition_penalty_last_n);
+        extract_param!(kwargs, config, seed);
+        extract_param!(kwargs, config, include_prompt);
+        extract_param!(kwargs, config, debug_tokens);
+        extract_param!(kwargs, config, stop_on_constraint_satisfaction);
+        extract_param!(kwargs, config, stop_on_match);
 
+        // Handle special cases that need custom logic
         if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
             if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
                 config.stop_sequences = arr
@@ -151,13 +136,6 @@ impl GenerationConfig {
             }
         }
 
-        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.debug_tokens = v;
-            }
-        }
-
-        // Handle constraint parameter
         if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
             if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
                 config.constraint = Some(Arc::clone(&constraint.index));
@@ -209,6 +187,15 @@ impl GenerationConfig {
     pub fn debug_tokens(&self) -> bool {
         self.inner.debug_tokens
     }
+
+    pub fn stop_on_constraint_satisfaction(&self) -> bool {
+        self.inner.stop_on_constraint_satisfaction
+    }
+
+    pub fn stop_on_match(&self) -> bool {
+        self.inner.stop_on_match
+    }
+
     pub fn constraint(&self) -> Option<StructuredConstraint> {
         self.inner.constraint.as_ref().map(|c| StructuredConstraint {
             index: Arc::clone(c),
@@ -372,6 +359,42 @@ impl LLM {
             ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
         }
     }
+
+    /// Get the EOS token string for this model
+    pub fn eos_token(&self) -> Result<String> {
+        let (eos_token_id, tokenizer_clone) = {
+            let model = match self.model.lock() {
+                Ok(guard) => guard,
+                Err(poisoned) => poisoned.into_inner(),
+            };
+            let model_ref = model.borrow();
+
+            // Get both EOS token ID and tokenizer clone in one lock scope
+            let eos_id = match &*model_ref {
+                ModelType::Mistral(m) => m.eos_token_id(),
+                ModelType::Llama(m) => m.eos_token_id(),
+                ModelType::Gemma(m) => m.eos_token_id(),
+                ModelType::Qwen(m) => m.eos_token_id(),
+                ModelType::Phi(m) => m.eos_token_id(),
+                ModelType::QuantizedGGUF(m) => m.eos_token_id(),
+            };
+
+            let tokenizer = match &*model_ref {
+                ModelType::Mistral(m) => m.tokenizer().clone(),
+                ModelType::Llama(m) => m.tokenizer().clone(),
+                ModelType::Gemma(m) => m.tokenizer().clone(),
+                ModelType::Qwen(m) => m.tokenizer().clone(),
+                ModelType::Phi(m) => m.tokenizer().clone(),
+                ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
+            };
+
+            (eos_id, tokenizer)
+        }; // Lock is released here
+
+        // Convert ID to string using the tokenizer
+        let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
+        tokenizer_wrapper.id_to_token(eos_token_id as i64)
+    }
 
     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> Result<()> {
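Two details in `eos_token` are worth calling out: a poisoned mutex is recovered with `into_inner()` instead of returning an error, and both the EOS id and a tokenizer clone are taken in one lock scope so the id-to-string conversion runs after the lock is dropped. The recovery pattern in isolation:

```rust
use std::sync::Mutex;

// If another thread panicked while holding the lock, the mutex is "poisoned";
// into_inner() recovers the guard (and the data) instead of failing.
fn read(m: &Mutex<u32>) -> u32 {
    let guard = match m.lock() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard
}

fn main() {
    let m = Mutex::new(7);
    assert_eq!(read(&m), 7);
}
```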
@@ -460,6 +483,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
     rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
+    rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
+    rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
     rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
@@ -469,6 +494,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
     rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
+    rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;