red-candle 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,301 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_transformers::models::phi::{Config, Model as PhiModel};
+ use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
+ use hf_hub::api::tokio::Api;
+ use tokenizers::Tokenizer;
+
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ /// Phi model wrapper for text generation
+ pub struct Phi {
+     model: PhiVariant,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ enum PhiVariant {
+     Phi2(PhiModel),
+     Phi3(Phi3Model),
+ }
+
+ impl Phi {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         match &mut self.model {
+             PhiVariant::Phi2(model) => model.clear_kv_cache(),
+             PhiVariant::Phi3(model) => model.clear_kv_cache(),
+         }
+     }
+
+     /// Load a Phi model from HuggingFace
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.model(model_id.to_string());
+
+         // Download configuration
+         let config_filename = repo.get("config.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+         let config_str = std::fs::read_to_string(config_filename)?;
+
+         // Download tokenizer
+         let tokenizer_filename = repo.get("tokenizer.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Determine EOS token
+         let vocab = tokenizer.get_vocab(true);
+         let eos_token_id = vocab.get("<|endoftext|>")
+             .or_else(|| vocab.get("<|end|>"))
+             .or_else(|| vocab.get("</s>"))
+             .copied()
+             .unwrap_or(50256); // Default GPT-2 style EOS token
+
+         // Determine model variant based on model_id or config
+         let is_phi3 = model_id.contains("phi-3") || model_id.contains("Phi-3");
+
+         // Download model weights (handle both single and sharded files)
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else {
+             // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤30 shards).
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 return Err(candle_core::Error::Msg(
+                     "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
+                 ));
+             }
+             sharded_files
+         };
+
+         let model = if is_phi3 {
+             // Load Phi3 model
+             let config: Phi3Config = serde_json::from_str(&config_str)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
+
+             let vb = unsafe {
+                 candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+             };
+
+             let model = Phi3Model::new(&config, vb)?;
+             PhiVariant::Phi3(model)
+         } else {
+             // Load Phi2 model
+             let config: Config = serde_json::from_str(&config_str)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi config: {}", e)))?;
+
+             let vb = unsafe {
+                 candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+             };
+
+             let model = PhiModel::new(&config, vb)?;
+             PhiVariant::Phi2(model)
+         };
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Apply Phi chat template to messages
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         // Phi-3 uses a specific format
+         if matches!(self.model, PhiVariant::Phi3(_)) {
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => {
+                         prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+                     }
+                     "user" => {
+                         prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+                     }
+                     "assistant" => {
+                         prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+                     }
+                     _ => {}
+                 }
+             }
+             prompt.push_str("<|assistant|>\n");
+         } else {
+             // Phi-2 uses a simpler format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => prompt.push_str(&format!("System: {}\n", content)),
+                     "user" => prompt.push_str(&format!("User: {}\n", content)),
+                     "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+                     _ => {}
+                 }
+             }
+             prompt.push_str("Assistant: ");
+         }
+
+         Ok(prompt)
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::new(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let logits = match &mut self.model {
+                 PhiVariant::Phi2(model) => model.forward(&input)?,
+                 PhiVariant::Phi3(model) => model.forward(&input, start_pos)?,
+             };
+             let logits = logits.squeeze(0)?;
+
+             // Handle different output shapes
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(&logits)?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 if config.debug_tokens {
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Phi {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
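
Note: the from_pretrained code above points out that probing model-XXXXX-of-YYYYY.safetensors names is a brute-force fallback and that reading model.safetensors.index.json would give the exact shard list. The sketch below illustrates that alternative. It is not part of the released gem; the helper name is made up, and it assumes the standard safetensors index layout (a "weight_map" object mapping tensor names to shard filenames) and the same hf_hub ApiRepo handle used above.

    use std::collections::HashSet;

    use candle_core::Result as CandleResult;
    use hf_hub::api::tokio::ApiRepo;

    // Hypothetical helper: resolve shard files via model.safetensors.index.json
    // instead of probing model-XXXXX-of-YYYYY.safetensors names.
    async fn sharded_weight_files(repo: &ApiRepo) -> CandleResult<Vec<std::path::PathBuf>> {
        // Download the index file; it lists every shard under "weight_map".
        let index_path = repo.get("model.safetensors.index.json").await
            .map_err(|e| candle_core::Error::Msg(format!("Failed to download index: {}", e)))?;
        let index: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(index_path)?)
            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse index: {}", e)))?;

        // Collect the unique shard filenames referenced by the weight map.
        let mut shard_names: HashSet<String> = HashSet::new();
        if let Some(map) = index["weight_map"].as_object() {
            for file in map.values().filter_map(|v| v.as_str()) {
                shard_names.insert(file.to_string());
            }
        }

        // Download each shard exactly once.
        let mut files = Vec::new();
        for name in shard_names {
            let file = repo.get(&name).await
                .map_err(|e| candle_core::Error::Msg(format!("Failed to download {}: {}", name, e)))?;
            files.push(file);
        }
        Ok(files)
    }

With the index in hand there is no dependence on guessing the shard count, so the ≤30-shard limitation of the loop above disappears.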
@@ -2,6 +2,9 @@ use candle_core::{DType, Device, Result as CandleResult, Tensor};
  use candle_core::quantized::gguf_file;
  use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
  use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
+ use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
+ use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
+ use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
  use hf_hub::api::tokio::{Api, ApiRepo};
  use tokenizers::Tokenizer;
  use std::io::Seek;
@@ -9,7 +12,6 @@ use std::io::Seek;
  use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};

  /// Unified GGUF model that can load any GGUF file and detect the architecture
- #[derive(Debug)]
  pub struct QuantizedGGUF {
      model: ModelType,
      tokenizer: TokenizerWrapper,
@@ -20,14 +22,20 @@ pub struct QuantizedGGUF {
      _chat_template: Option<String>,
  }

- #[derive(Debug)]
  enum ModelType {
      Llama(QuantizedLlamaModel),
      Gemma(QuantizedGemmaModel),
+     Qwen(QuantizedQwenModel),
+     Phi(QuantizedPhiModel),
+     Phi3(QuantizedPhi3Model),
      // Mistral uses Llama loader due to tensor naming compatibility
  }

  impl QuantizedGGUF {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Get the tokenizer
      pub fn tokenizer(&self) -> &TokenizerWrapper {
          &self.tokenizer
@@ -97,6 +105,34 @@ impl QuantizedGGUF {
                  let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
                  ModelType::Llama(model)
              }
+             "qwen" | "qwen2" | "qwen3" => {
+                 // Try different loaders based on what metadata is available
+                 if content.metadata.contains_key("llama.attention.head_count") {
+                     let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Llama(model)
+                 } else if content.metadata.contains_key("qwen2.attention.head_count") {
+                     let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Qwen(model)
+                 } else if content.metadata.contains_key("qwen3.attention.head_count") {
+                     // Qwen3 GGUF files use a different metadata format
+                     // The quantized_qwen3 module is not yet in the released version of candle-transformers
+                     return Err(candle_core::Error::Msg(format!(
+                         "Qwen3 GGUF format detected but not yet fully supported.\n\n\
+                          The file contains qwen3.* metadata keys which require candle-transformers > 0.9.1.\n\n\
+                          Current alternatives:\n\
+                          1. Use Qwen2.5 GGUF models which work well:\n\
+                             - Qwen/Qwen2.5-7B-Instruct-GGUF (recommended)\n\
+                             - Qwen/Qwen2.5-32B-Instruct-GGUF\n\
+                          2. Use non-quantized Qwen models with safetensors\n\
+                          3. Wait for candle-transformers update with quantized_qwen3 support\n\n\
+                          Note: Qwen2.5 models have similar capabilities to Qwen3."
+                     )));
+                 } else {
+                     // Last resort: try llama loader anyway, as it's the most common
+                     let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+                     ModelType::Llama(model)
+                 }
+             }
              "gemma" | "gemma2" | "gemma3" => {
                  // Try Gemma-specific loader first, fall back to Llama if it fails
                  match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
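
Note: the qwen arm above picks a loader by probing metadata keys (llama.attention.head_count, qwen2.attention.head_count, qwen3.attention.head_count). When a GGUF file lands in an unexpected branch, listing its metadata keys shows why. The helper below is an illustrative sketch using the same gguf_file reader the gem relies on; it is not code from the package.

    use std::fs::File;

    use candle_core::quantized::gguf_file;
    use candle_core::Result as CandleResult;

    // Illustrative helper: list the metadata keys of a GGUF file so you can see
    // which `<arch>.attention.head_count`-style keys the dispatch above will find.
    fn dump_gguf_metadata_keys(path: &str) -> CandleResult<()> {
        let mut file = File::open(path)?;
        let content = gguf_file::Content::read(&mut file)?;

        let mut keys: Vec<&String> = content.metadata.keys().collect();
        keys.sort();
        for key in keys {
            println!("{}", key);
        }

        // The loader selection for "qwen" architectures hinges on these probes:
        for probe in ["llama.attention.head_count", "qwen2.attention.head_count", "qwen3.attention.head_count"] {
            println!("{}: {}", probe, content.metadata.contains_key(probe));
        }
        Ok(())
    }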
@@ -112,9 +148,20 @@
                      Err(e) => return Err(e),
                  }
              }
+             "phi" | "phi2" => {
+                 let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
+                 ModelType::Phi(model)
+             }
+             "phi3" => {
+                 // QuantizedPhi3Model requires an additional `approx` parameter
+                 // Setting to false to avoid performance issues without flash-attn
+                 let approx = false;
+                 let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
+                 ModelType::Phi3(model)
+             }
              _ => {
                  return Err(candle_core::Error::Msg(format!(
-                     "Unsupported architecture: {}. Supported: llama, mistral, gemma",
+                     "Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3",
                      architecture
                  )));
              }
@@ -149,6 +196,14 @@
              Ok("mistral".to_string())
          } else if model_lower.contains("gemma") {
              Ok("gemma".to_string())
+         } else if model_lower.contains("qwen") {
+             Ok("qwen".to_string())
+         } else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
+             Ok("phi3".to_string())
+         } else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
+             Ok("phi2".to_string())
+         } else if model_lower.contains("phi") {
+             Ok("phi".to_string())
          } else {
              Err(candle_core::Error::Msg(
                  "Could not determine model architecture from metadata or name".to_string()
@@ -235,6 +290,20 @@
                      .copied()
                      .unwrap_or(1)
              }
+             "qwen" | "qwen2" | "qwen3" => {
+                 vocab.get("<|endoftext|>")
+                     .or_else(|| vocab.get("<|im_end|>"))
+                     .or_else(|| vocab.get("</s>"))
+                     .copied()
+                     .unwrap_or(151643) // Default Qwen3 EOS token
+             }
+             "phi" | "phi2" | "phi3" => {
+                 vocab.get("<|endoftext|>")
+                     .or_else(|| vocab.get("<|end|>"))
+                     .or_else(|| vocab.get("</s>"))
+                     .copied()
+                     .unwrap_or(50256) // Default GPT-2 style EOS token
+             }
              _ => 2, // Default
          }
      }
@@ -256,6 +325,10 @@
          } else if model_lower.contains("gemma") {
              // Always use Gemma template for Gemma models, regardless of loader used
              self.apply_gemma_template(messages)
+         } else if model_lower.contains("qwen") {
+             self.apply_qwen_template(messages)
+         } else if model_lower.contains("phi") {
+             self.apply_phi_template(messages)
          } else {
              match self.architecture.as_str() {
                  "llama" => {
@@ -268,6 +341,12 @@
                  "gemma" => {
                      self.apply_gemma_template(messages)
                  }
+                 "qwen" | "qwen2" | "qwen3" => {
+                     self.apply_qwen_template(messages)
+                 }
+                 "phi" | "phi2" | "phi3" => {
+                     self.apply_phi_template(messages)
+                 }
                  _ => Ok(self.apply_generic_template(messages))
              }
          }
@@ -366,6 +445,77 @@
          Ok(prompt)
      }

+     fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+                 }
+                 "user" => {
+                     prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+                 }
+                 "assistant" => {
+                     prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         // Add generation prompt
+         prompt.push_str("<|im_start|>assistant\n");
+         Ok(prompt)
+     }
+
+     fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         // Check if it's Phi-3 (newer format) or Phi-2/Phi (simpler format)
+         let is_phi3 = self.model_id.contains("phi-3") || self.model_id.contains("Phi-3") || self.architecture == "phi3";
+
+         if is_phi3 {
+             // Phi-3 format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => {
+                         prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+                     }
+                     "user" => {
+                         prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+                     }
+                     "assistant" => {
+                         prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+                     }
+                     _ => {}
+                 }
+             }
+             prompt.push_str("<|assistant|>\n");
+         } else {
+             // Phi-2 format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => prompt.push_str(&format!("System: {}\n", content)),
+                     "user" => prompt.push_str(&format!("User: {}\n", content)),
+                     "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+                     _ => {}
+                 }
+             }
+             prompt.push_str("Assistant: ");
+         }
+
+         Ok(prompt)
+     }
+
      fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
          let mut prompt = String::new();

@@ -381,7 +531,9 @@

      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
-         // Quantized models manage cache internally
+         // Quantized models don't expose cache clearing methods
+         // Phi3 GGUF models have a known issue where the KV cache
+         // cannot be cleared, leading to errors on subsequent generations
      }

      fn generate_tokens(
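
Note: the prompt produced by the new apply_qwen_template above follows the ChatML layout. The snippet below is a standalone mirror of that private method (hypothetical function name, serde_json only), so the rendered prompt can be inspected without constructing a QuantizedGGUF; it is illustrative and not part of the gem.

    use serde_json::json;

    // Standalone mirror of apply_qwen_template from the diff above.
    fn qwen_prompt(messages: &[serde_json::Value]) -> String {
        let mut prompt = String::new();
        for message in messages {
            let role = message["role"].as_str().unwrap_or("");
            let content = message["content"].as_str().unwrap_or("");
            match role {
                "system" => prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content)),
                "user" => prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content)),
                "assistant" => prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content)),
                _ => {}
            }
        }
        prompt.push_str("<|im_start|>assistant\n");
        prompt
    }

    fn main() {
        let messages = vec![
            json!({"role": "system", "content": "You are a helpful assistant."}),
            json!({"role": "user", "content": "Hello!"}),
        ];
        // Expected output: system and user turns wrapped in <|im_start|>/<|im_end|>,
        // followed by an open assistant turn as the generation prompt.
        println!("{}", qwen_prompt(&messages));
    }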
@@ -390,7 +542,7 @@
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -408,6 +560,9 @@
              let logits = match &mut self.model {
                  ModelType::Llama(model) => model.forward(&input, start_pos)?,
                  ModelType::Gemma(model) => model.forward(&input, start_pos)?,
+                 ModelType::Qwen(model) => model.forward(&input, start_pos)?,
+                 ModelType::Phi(model) => model.forward(&input, start_pos)?,
+                 ModelType::Phi3(model) => model.forward(&input, start_pos)?,
              };

              let logits = logits.squeeze(0)?;
@@ -420,10 +575,7 @@

              let logits = logits.to_dtype(DType::F32)?;

-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
+             let next_token = text_gen.sample_next_token(&logits)?;

              all_tokens.push(next_token);


@@ -445,6 +597,18 @@
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
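
Note: taken together, 1.1.1 adds a safetensors Phi wrapper, quantized Qwen/Phi/Phi3 GGUF loading, and constraint-aware early stopping in both generation loops. The sketch below shows how the Rust side might be driven end to end. It assumes a crate::llm::phi module path and a Default impl for GenerationConfig, neither of which is shown in this diff, and uses only fields and trait methods that appear above; the model id is just an example.

    use candle_core::Device;

    use crate::llm::{GenerationConfig, TextGenerator};
    use crate::llm::phi::Phi; // module path assumed; the diff does not show where Phi is mounted

    // Minimal sketch: load a Phi model and stream generated text to stdout.
    async fn demo() -> candle_core::Result<()> {
        let mut model = Phi::from_pretrained("microsoft/phi-2", Device::Cpu).await?;

        // Only fields that appear in this diff are set; the rest is assumed to
        // come from a Default impl, which this diff does not show.
        let config = GenerationConfig {
            max_length: 128,
            stop_sequences: vec!["\nUser:".to_string()],
            include_prompt: false,
            ..Default::default()
        };

        let prompt = model.apply_chat_template(&[
            serde_json::json!({"role": "user", "content": "Write a haiku about Rust."}),
        ])?;

        let output = model.generate_stream(&prompt, &config, |piece| print!("{}", piece))?;
        println!("\n---\n{}", output);
        Ok(())
    }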