red-candle 1.0.1 → 1.1.0

This diff reflects the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
@@ -2,6 +2,9 @@ use candle_core::{DType, Device, Result as CandleResult, Tensor};
  use candle_core::quantized::gguf_file;
  use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
  use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
+ use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
+ use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
+ use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
  use hf_hub::api::tokio::{Api, ApiRepo};
  use tokenizers::Tokenizer;
  use std::io::Seek;
@@ -9,7 +12,6 @@ use std::io::Seek;
  use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};

  /// Unified GGUF model that can load any GGUF file and detect the architecture
- #[derive(Debug)]
  pub struct QuantizedGGUF {
      model: ModelType,
      tokenizer: TokenizerWrapper,
@@ -20,10 +22,12 @@ pub struct QuantizedGGUF {
      _chat_template: Option<String>,
  }

- #[derive(Debug)]
  enum ModelType {
      Llama(QuantizedLlamaModel),
      Gemma(QuantizedGemmaModel),
+     Qwen(QuantizedQwenModel),
+     Phi(QuantizedPhiModel),
+     Phi3(QuantizedPhi3Model),
      // Mistral uses Llama loader due to tensor naming compatibility
  }

@@ -97,6 +101,34 @@ impl QuantizedGGUF {
                  let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
                  ModelType::Llama(model)
              }
+             "qwen" | "qwen2" | "qwen3" => {
+                 // Try different loaders based on what metadata is available
+                 if content.metadata.contains_key("llama.attention.head_count") {
+                     let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Llama(model)
+                 } else if content.metadata.contains_key("qwen2.attention.head_count") {
+                     let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Qwen(model)
+                 } else if content.metadata.contains_key("qwen3.attention.head_count") {
+                     // Qwen3 GGUF files use a different metadata format
+                     // The quantized_qwen3 module is not yet in the released version of candle-transformers
+                     return Err(candle_core::Error::Msg(format!(
+                         "Qwen3 GGUF format detected but not yet fully supported.\n\n\
+                          The file contains qwen3.* metadata keys which require candle-transformers > 0.9.1.\n\n\
+                          Current alternatives:\n\
+                          1. Use Qwen2.5 GGUF models which work well:\n\
+                             - Qwen/Qwen2.5-7B-Instruct-GGUF (recommended)\n\
+                             - Qwen/Qwen2.5-32B-Instruct-GGUF\n\
+                          2. Use non-quantized Qwen models with safetensors\n\
+                          3. Wait for candle-transformers update with quantized_qwen3 support\n\n\
+                          Note: Qwen2.5 models have similar capabilities to Qwen3."
+                     )));
+                 } else {
+                     // Last resort: try llama loader anyway, as it's the most common
+                     let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Llama(model)
+                 }
+             }
              "gemma" | "gemma2" | "gemma3" => {
                  // Try Gemma-specific loader first, fall back to Llama if it fails
                  match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
@@ -112,9 +144,20 @@ impl QuantizedGGUF {
                      Err(e) => return Err(e),
                  }
              }
+             "phi" | "phi2" => {
+                 let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
+                 ModelType::Phi(model)
+             }
+             "phi3" => {
+                 // QuantizedPhi3Model requires an additional `approx` parameter
+                 // Setting to false to avoid performance issues without flash-attn
+                 let approx = false;
+                 let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
+                 ModelType::Phi3(model)
+             }
              _ => {
                  return Err(candle_core::Error::Msg(format!(
-                     "Unsupported architecture: {}. Supported: llama, mistral, gemma",
+                     "Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3",
                      architecture
                  )));
              }
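The loader choice above hinges entirely on which `*.attention.head_count` metadata key a GGUF file carries. A free-standing sketch of that detection step, for illustration only (the `sniff_loader` helper and its return labels are not part of the crate):

use candle_core::quantized::gguf_file;

/// Hypothetical helper: open a GGUF file and report which quantized loader
/// its metadata keys suggest, mirroring the match arms above.
fn sniff_loader(path: &std::path::Path) -> candle_core::Result<&'static str> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    Ok(if content.metadata.contains_key("llama.attention.head_count") {
        "quantized_llama"
    } else if content.metadata.contains_key("qwen2.attention.head_count") {
        "quantized_qwen2"
    } else if content.metadata.contains_key("qwen3.attention.head_count") {
        "qwen3 (unsupported until candle-transformers ships quantized_qwen3)"
    } else {
        "quantized_llama (fallback)"
    })
}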
@@ -149,6 +192,14 @@ impl QuantizedGGUF {
              Ok("mistral".to_string())
          } else if model_lower.contains("gemma") {
              Ok("gemma".to_string())
+         } else if model_lower.contains("qwen") {
+             Ok("qwen".to_string())
+         } else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
+             Ok("phi3".to_string())
+         } else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
+             Ok("phi2".to_string())
+         } else if model_lower.contains("phi") {
+             Ok("phi".to_string())
          } else {
              Err(candle_core::Error::Msg(
                  "Could not determine model architecture from metadata or name".to_string()
@@ -235,6 +286,20 @@ impl QuantizedGGUF {
                      .copied()
                      .unwrap_or(1)
              }
+             "qwen" | "qwen2" | "qwen3" => {
+                 vocab.get("<|endoftext|>")
+                     .or_else(|| vocab.get("<|im_end|>"))
+                     .or_else(|| vocab.get("</s>"))
+                     .copied()
+                     .unwrap_or(151643) // Default Qwen3 EOS token
+             }
+             "phi" | "phi2" | "phi3" => {
+                 vocab.get("<|endoftext|>")
+                     .or_else(|| vocab.get("<|end|>"))
+                     .or_else(|| vocab.get("</s>"))
+                     .copied()
+                     .unwrap_or(50256) // Default GPT-2 style EOS token
+             }
              _ => 2, // Default
          }
      }
@@ -256,6 +321,10 @@ impl QuantizedGGUF {
          } else if model_lower.contains("gemma") {
              // Always use Gemma template for Gemma models, regardless of loader used
              self.apply_gemma_template(messages)
+         } else if model_lower.contains("qwen") {
+             self.apply_qwen_template(messages)
+         } else if model_lower.contains("phi") {
+             self.apply_phi_template(messages)
          } else {
              match self.architecture.as_str() {
                  "llama" => {
@@ -268,6 +337,12 @@ impl QuantizedGGUF {
                  "gemma" => {
                      self.apply_gemma_template(messages)
                  }
+                 "qwen" | "qwen2" | "qwen3" => {
+                     self.apply_qwen_template(messages)
+                 }
+                 "phi" | "phi2" | "phi3" => {
+                     self.apply_phi_template(messages)
+                 }
                  _ => Ok(self.apply_generic_template(messages))
              }
          }
@@ -366,6 +441,77 @@ impl QuantizedGGUF {
          Ok(prompt)
      }

+     fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+                 }
+                 "user" => {
+                     prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+                 }
+                 "assistant" => {
+                     prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         // Add generation prompt
+         prompt.push_str("<|im_start|>assistant\n");
+         Ok(prompt)
+     }
+
+     fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         // Check if it's Phi-3 (newer format) or Phi-2/Phi (simpler format)
+         let is_phi3 = self.model_id.contains("phi-3") || self.model_id.contains("Phi-3") || self.architecture == "phi3";
+
+         if is_phi3 {
+             // Phi-3 format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => {
+                         prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+                     }
+                     "user" => {
+                         prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+                     }
+                     "assistant" => {
+                         prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+                     }
+                     _ => {}
+                 }
+             }
+             prompt.push_str("<|assistant|>\n");
+         } else {
+             // Phi-2 format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => prompt.push_str(&format!("System: {}\n", content)),
+                     "user" => prompt.push_str(&format!("User: {}\n", content)),
+                     "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+                     _ => {}
+                 }
+             }
+             prompt.push_str("Assistant: ");
+         }
+
+         Ok(prompt)
+     }
+
      fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
          let mut prompt = String::new();

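For reference, the new Qwen template emits the ChatML layout. A minimal, free-standing illustration of the prompt shape it produces (the `chatml_prompt` helper is hypothetical, written only to show the expected output):

/// Hypothetical helper reproducing the ChatML layout used by apply_qwen_template.
fn chatml_prompt(messages: &[(&str, &str)]) -> String {
    let mut prompt = String::new();
    for (role, content) in messages {
        prompt.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, content));
    }
    // Trailing assistant header asks the model to continue from here
    prompt.push_str("<|im_start|>assistant\n");
    prompt
}

// chatml_prompt(&[("system", "You are terse."), ("user", "Hi")]) yields:
//   <|im_start|>system\nYou are terse.<|im_end|>\n
//   <|im_start|>user\nHi<|im_end|>\n
//   <|im_start|>assistant\n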
@@ -381,7 +527,9 @@ impl QuantizedGGUF {

      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
-         // Quantized models manage cache internally
+         // Quantized models don't expose cache clearing methods
+         // Phi3 GGUF models have a known issue where the KV cache
+         // cannot be cleared, leading to errors on subsequent generations
      }

      fn generate_tokens(
@@ -408,6 +556,9 @@ impl QuantizedGGUF {
              let logits = match &mut self.model {
                  ModelType::Llama(model) => model.forward(&input, start_pos)?,
                  ModelType::Gemma(model) => model.forward(&input, start_pos)?,
+                 ModelType::Qwen(model) => model.forward(&input, start_pos)?,
+                 ModelType::Phi(model) => model.forward(&input, start_pos)?,
+                 ModelType::Phi3(model) => model.forward(&input, start_pos)?,
              };

              let logits = logits.squeeze(0)?;
@@ -0,0 +1,229 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_transformers::models::qwen2::{Config, Model as QwenModel};
+ use hf_hub::api::tokio::Api;
+ use tokenizers::Tokenizer;
+
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ /// Qwen model wrapper for text generation
+ #[derive(Debug)]
+ pub struct Qwen {
+     model: QwenModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ impl Qwen {
+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         self.model.clear_kv_cache();
+     }
+
+     /// Load a Qwen model from HuggingFace
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.model(model_id.to_string());
+
+         // Download configuration
+         let config_filename = repo.get("config.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+         let config_str = std::fs::read_to_string(config_filename)?;
+         let config: Config = serde_json::from_str(&config_str)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+         // Download tokenizer
+         let tokenizer_filename = repo.get("tokenizer.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Determine EOS token
+         let vocab = tokenizer.get_vocab(true);
+         let eos_token_id = vocab.get("<|endoftext|>")
+             .or_else(|| vocab.get("<|im_end|>"))
+             .or_else(|| vocab.get("</s>"))
+             .copied()
+             .unwrap_or(151643); // Default Qwen3 EOS token
+
+         // Download model weights
+         let mut filenames = vec![];
+         let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
+             else if model_id.contains("14b") || model_id.contains("14B") { 3 }
+             else { 1 };
+
+         if num_shards == 1 {
+             // Single file model
+             let filename = repo.get("model.safetensors").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
+             filenames.push(filename);
+         } else {
+             // Sharded model
+             for shard_idx in 1..=num_shards {
+                 let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
+                     .map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
+                 filenames.push(filename);
+             }
+         }
+
+         // Load the model
+         let vb = unsafe {
+             candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
+         };
+
+         let model = QwenModel::new(&config, vb)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Apply Qwen chat template to messages
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+                 }
+                 "user" => {
+                     prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+                 }
+                 "assistant" => {
+                     prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         // Add generation prompt
+         prompt.push_str("<|im_start|>assistant\n");
+
+         Ok(prompt)
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let logits = self.model.forward(&input, start_pos, None)?;
+             let logits = logits.squeeze(0)?;
+
+             // Handle different output shapes
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 if config.debug_tokens {
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Qwen {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
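A rough usage sketch of the new non-quantized Qwen path, under stated assumptions: it presumes an async context, that `GenerationConfig` implements `Default`, and that the module paths shown match the crate layout; the model id is only an example.

use candle_core::Device;
// Assumed paths into the crate; adjust to the actual module layout.
use crate::llm::{GenerationConfig, TextGenerator, qwen::Qwen};

async fn demo() -> candle_core::Result<String> {
    // Example model id; any safetensors Qwen2 checkpoint should follow the same flow.
    let mut qwen = Qwen::from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", Device::Cpu).await?;

    let messages = vec![serde_json::json!({ "role": "user", "content": "Say hello." })];
    let prompt = qwen.apply_chat_template(&messages)?;

    // Assumption: GenerationConfig derives Default; otherwise fill the fields explicitly.
    let config = GenerationConfig::default();
    qwen.generate(&prompt, &config)
}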
@@ -1,13 +1,17 @@
  use candle_core::{Result as CandleResult, Tensor};
  use candle_transformers::generation::LogitsProcessor;
+ use std::sync::Arc;

  use super::GenerationConfig;
+ use crate::structured::Index;

  /// Helper struct for text generation process
  pub struct TextGeneration {
      logits_processor: LogitsProcessor,
      tokens: Vec<u32>,
      eos_token_id: Option<u32>,
+     constraint: Option<Arc<Index>>,
+     constraint_state: Option<u32>,
  }

  impl TextGeneration {
@@ -25,18 +29,27 @@ impl TextGeneration {
              logits_processor,
              tokens: Vec::new(),
              eos_token_id: None,
+             constraint: None,
+             constraint_state: None,
          }
      }

      pub fn from_config(config: &GenerationConfig) -> Self {
-         Self::new(
+         let mut text_gen = Self::new(
              config.seed,
              Some(config.temperature),
              config.top_p,
              config.top_k,
              config.repetition_penalty,
              config.repetition_penalty_last_n,
-         )
+         );
+
+         // Set constraint if provided
+         if let Some(ref constraint) = config.constraint {
+             text_gen.set_constraint(Arc::clone(constraint));
+         }
+
+         text_gen
      }

      pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
@@ -55,6 +68,36 @@ impl TextGeneration {
          self.tokens.push(token);
      }

+     pub fn set_constraint(&mut self, constraint: Arc<Index>) {
+         // Initialize with the first state
+         self.constraint_state = Some(constraint.initial_state());
+         self.constraint = Some(constraint);
+     }
+
+     /// Apply constraints to logits by masking disallowed tokens
+     fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             let device = logits.device();
+             let vocab_size = logits.dims1()?;
+
+             // Get allowed tokens from the constraint index for current state
+             if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
+                 // Create a mask where allowed tokens have value 0 and others have -inf
+                 let mut mask = vec![f32::NEG_INFINITY; vocab_size];
+                 for &token_id in &allowed_tokens {
+                     if (token_id as usize) < vocab_size {
+                         mask[token_id as usize] = 0.0;
+                     }
+                 }
+
+                 // Apply mask to logits
+                 let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
+                 *logits = logits.add(&mask_tensor)?;
+             }
+         }
+         Ok(())
+     }
+
      /// Apply repetition penalty to logits
      pub fn apply_repetition_penalty(
          &self,
@@ -103,10 +146,18 @@ impl TextGeneration {
              self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
          }

+         // Apply constraints if active
+         self.apply_constraints(&mut logits)?;
+
          // Sample token
          let next_token = self.logits_processor.sample(&logits)?;
          self.tokens.push(next_token);

+         // Update constraint state if active
+         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
+             self.constraint_state = constraint_index.next_state(&current_state, &next_token);
+         }
+
          Ok(next_token)
      }

@@ -122,6 +173,19 @@ impl TextGeneration {
              }
          }

+         // Check if we've reached a final state in constraint
+         // A state is considered final if it has no allowed tokens
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                 if allowed.is_empty() {
+                     return true;
+                 }
+             } else {
+                 // None means no tokens allowed - we're done
+                 return true;
+             }
+         }
+
          false
      }

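The constraint logic above boils down to additive masking: tokens outside the allowed set get negative infinity added to their logits so the sampler can never pick them. A self-contained sketch of that step, independent of the `Index` type (the function name is illustrative, not part of the crate):

use candle_core::{Result, Tensor};

/// Illustrative masking step: keep `allowed` token ids samplable, push all
/// other vocabulary entries to -inf before sampling.
fn mask_logits(logits: &Tensor, allowed: &[u32]) -> Result<Tensor> {
    let vocab_size = logits.dims1()?; // expects a 1-D logits tensor
    let mut mask = vec![f32::NEG_INFINITY; vocab_size];
    for &id in allowed {
        if (id as usize) < vocab_size {
            mask[id as usize] = 0.0;
        }
    }
    let mask = Tensor::from_vec(mask, vocab_size, logits.device())?;
    logits.add(&mask) // -inf + anything stays -inf; allowed logits are unchanged
}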
@@ -162,6 +162,10 @@ impl Device {
      pub fn __str__(&self) -> String {
          self.__repr__()
      }
+
+     pub fn __eq__(&self, other: &Device) -> bool {
+         self == other
+     }
  }

  impl magnus::TryConvert for Device {
@@ -193,5 +197,6 @@ pub fn init(rb_candle: RModule) -> Result<()> {
      rb_device.define_singleton_method("default", function!(default_device, 0))?;
      rb_device.define_method("to_s", method!(Device::__str__, 0))?;
      rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
+     rb_device.define_method("==", method!(Device::__eq__, 1))?;
      Ok(())
  }