red-candle 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 5a2b4fac15c2c261d8fea90d34973605209043e2b2a222f82414bf94bc5c47e8
-  data.tar.gz: daff4398f34170e20744ab2ee8b1abb5b977ffec15a129d678efced7c9649495
+  metadata.gz: a3678037fbb196c621c8e9df6a213b0d3dffbdb1b8b3dfd73eee4a7ea2feafca
+  data.tar.gz: ada97ef81af854439622bdc12b796442be9e0f31e7c7d8a5df374c7bfb07ff2e
 SHA512:
-  metadata.gz: 4391a7fb4072d9ac174bcecdb366c975c5e487f43fcc4b0db75533aa44c94822d38103a633eba0b098ac26cf31313e9dd7da77255ed83692584e6936cca86271
-  data.tar.gz: da7a15a86fea349069079537b8c6f6696079842d205168aa908644a59528d3ecf5922548d8b3d153494f0beb3fe6221b30110368f1f1289f0ac1ff4856f6b243
+  metadata.gz: d353177318c4599fa30974a676350087a8e5fd070fe3d317344a4e1b3ae022cb69adf742d62063c2da09dbab7e971cbfae1e53a87527ce7f1c18afd1223797e8
+  data.tar.gz: df4b2f43f6fb1aa623053fd09d6e48eba0d8c2615f51dc2accdc4dc292fb3fb7d665553b04cae3747e001e04cd4b9cdbe5c022c3efd077daddf97e074a1e9e5c
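The new checksums can be used to verify a downloaded 1.1.1 artifact before installing it. A minimal verification sketch in Rust (not part of the gem), assuming the `sha2` and `hex` crates:

```rust
use sha2::{Digest, Sha256};

// Recompute the SHA256 of an artifact and compare it to the published value.
fn verify_sha256(path: &str, expected_hex: &str) -> std::io::Result<bool> {
    let bytes = std::fs::read(path)?;
    Ok(hex::encode(Sha256::digest(&bytes)) == expected_hex)
}

fn main() -> std::io::Result<()> {
    // Expected value: the 1.1.1 data.tar.gz SHA256 from the table above.
    let ok = verify_sha256(
        "data.tar.gz",
        "ada97ef81af854439622bdc12b796442be9e0f31e7c7d8a5df374c7bfb07ff2e",
    )?;
    println!("checksum match: {ok}");
    Ok(())
}
```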
data/README.md CHANGED
@@ -58,7 +58,8 @@ end
 - **EmbeddingModel**: Generate embeddings for text
 - **Reranker**: Rerank documents based on relevance
 - **NER**: Named Entity Recognition directly from Ruby
-- **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma)
+- **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma, Qwen, Phi)
+- **Structured Generation**: Generate JSON from a schema or match a regular expression
 
 ## Model Storage
 
@@ -44,8 +44,8 @@ mod constrained_generation_tests {
         let config_without_constraint = GenerationConfig::default();
 
         // Create text generation instances
-        let mut gen_constrained = TextGeneration::from_config(&config_with_constraint);
-        let mut gen_unconstrained = TextGeneration::from_config(&config_without_constraint);
+        let mut gen_constrained = TextGeneration::new(&config_with_constraint);
+        let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);
 
         // Set EOS token
         gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
@@ -60,9 +60,9 @@ mod constrained_generation_tests {
     fn test_constraint_configuration() {
         // Test that we can create a TextGeneration with constraints
         let config = GenerationConfig::default();
-        let _text_gen = TextGeneration::from_config(&config);
+        let _text_gen = TextGeneration::new(&config);
 
-        // Test that we can create a TextGeneration from config
+        // Test that we can create a TextGeneration with config
         // Constraints are private implementation details
     }
 
@@ -78,7 +78,10 @@ mod constrained_generation_tests {
         let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
 
         // Create text generation with some tokens
-        let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+        let mut config = GenerationConfig::default();
+        config.seed = 42;
+        config.temperature = 1.0;
+        let mut text_gen = TextGeneration::new(&config);
         text_gen.push_token(0); // Token that had logit 1.0
         text_gen.push_token(2); // Token that had logit 2.0
         text_gen.push_token(5); // Token that had logit 3.0
@@ -100,7 +103,10 @@ mod constrained_generation_tests {
 
     #[test]
     fn test_stop_conditions() {
-        let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+        let mut config = GenerationConfig::default();
+        config.seed = 42;
+        config.temperature = 1.0;
+        let mut text_gen = TextGeneration::new(&config);
         text_gen.set_eos_token_id(50256); // Common EOS token
 
         // Test max length stop
@@ -120,4 +126,191 @@ mod constrained_generation_tests {
         assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
         assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
     }
+
+    #[test]
+    fn test_sample_next_token_uses_repetition_penalty() {
+        use candle_core::{Tensor, Device};
+
+        let device = Device::Cpu;
+        let vocab_size = 10;
+
+        // Create initial logits
+        let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+        let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+        // Test 1: Create TextGeneration with repetition penalty and add some tokens to history
+        let mut config_with_penalty = GenerationConfig::default();
+        config_with_penalty.seed = 42;
+        config_with_penalty.temperature = 0.1;
+        config_with_penalty.repetition_penalty = 1.5;
+        config_with_penalty.repetition_penalty_last_n = 10;
+        let mut text_gen = TextGeneration::new(&config_with_penalty);
+        text_gen.push_token(2); // Token with logit 3.0
+        text_gen.push_token(5); // Token with logit 6.0
+        text_gen.push_token(9); // Token with logit 10.0
+
+        // Sample with repetition penalty (now uses stored penalty)
+        let _token_with_penalty = text_gen.sample_next_token(&logits).unwrap();
+
+        // Test 2: Same setup but without penalty
+        let mut config_no_penalty = GenerationConfig::default();
+        config_no_penalty.seed = 42;
+        config_no_penalty.temperature = 0.1;
+        config_no_penalty.repetition_penalty = 1.0; // No penalty
+        let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
+        text_gen_no_penalty.push_token(2);
+        text_gen_no_penalty.push_token(5);
+        text_gen_no_penalty.push_token(9);
+
+        let _token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();
+
+        // With low temperature and penalty, should avoid previously used high-logit tokens
+        // Without penalty, should prefer high-logit tokens
+        // This is probabilistic, but with temp=0.1 it should be fairly deterministic
+
+        // Test 3: Verify penalty is applied correctly by checking modified logits
+        let mut config_verify = GenerationConfig::default();
+        config_verify.seed = 42;
+        config_verify.temperature = 0.1;
+        config_verify.repetition_penalty = 2.0;
+        config_verify.repetition_penalty_last_n = 10;
+        let mut text_gen_verify = TextGeneration::new(&config_verify);
+        text_gen_verify.push_token(9); // Highest logit token
+
+        // Clone logits to check modification
+        let mut logits_for_penalty = logits.clone();
+        text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();
+
+        let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
+        assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
+        assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
+    }
+
+    #[test]
+    fn test_text_generation_from_config_parameters() {
+
+        // Create a config with specific values
+        let mut config = GenerationConfig::default();
+        config.seed = 12345;
+        config.temperature = 0.5;
+        config.top_p = Some(0.9);
+        config.top_k = Some(40); // Currently unused but should be accepted
+        config.repetition_penalty = 1.2;
+        config.repetition_penalty_last_n = 50;
+
+        // Create TextGeneration from config
+        let text_gen = TextGeneration::new(&config);
+
+        // We can't directly inspect private fields, but we can test behavior
+        // Test that it creates successfully (no panic)
+        assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");
+
+        // Test with constraint
+        let config_with_constraint = GenerationConfig::default();
+        // In real usage, this would be a real constraint
+        // For testing, we just verify it accepts the config
+        let text_gen_constrained = TextGeneration::new(&config_with_constraint);
+        assert!(text_gen_constrained.get_tokens().is_empty(), "Should start with no tokens");
+    }
+
+    #[test]
+    fn test_generation_with_different_penalties() {
+        use candle_core::{Tensor, Device, DType};
+
+        let device = Device::Cpu;
+        let vocab_size = 50;
+
+        // Create logits with clear preferences
+        let mut logits_vec = vec![0.0; vocab_size];
+        logits_vec[10] = 10.0; // Strong preference
+        logits_vec[20] = 8.0; // Second preference
+        logits_vec[30] = 6.0; // Third preference
+
+        // Test different penalty configurations
+        let configs = vec![
+            (1.0, 64), // No penalty (1.0 = neutral)
+            (1.5, 64), // Moderate penalty
+            (2.0, 64), // Strong penalty
+            (1.2, 10), // Penalty with limited range
+        ];
+
+        for (penalty, last_n) in configs {
+            let mut config = GenerationConfig::default();
+            config.seed = 42; // Fixed seed for reproducibility
+            config.temperature = 0.1; // Low temperature for more deterministic behavior
+            config.repetition_penalty = penalty;
+            config.repetition_penalty_last_n = last_n;
+
+            let mut text_gen = TextGeneration::new(&config);
+
+            // Generate a sequence of tokens
+            let mut generated = Vec::new();
+            for _i in 0..5 {
+                let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+                let token = text_gen.sample_next_token(&logits).unwrap();
+
+                generated.push(token);
+
+                // Verify the token is in valid range
+                assert!(token < vocab_size as u32, "Token should be within vocabulary");
+            }
+
+            // With higher penalties, we should see more diversity (less repetition)
+            let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
+            if penalty > 1.5 {
+                assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
+            }
+        }
+    }
+
+    #[test]
+    fn test_sample_next_token_integration() {
+        use candle_core::{Tensor, Device, DType};
+
+        let device = Device::Cpu;
+
+        // Test the full integration of sample_next_token
+        let mut config = GenerationConfig::default();
+        config.seed = 999;
+        config.temperature = 0.7;
+        config.max_length = 10;
+        config.repetition_penalty = 1.3;
+        config.repetition_penalty_last_n = 5;
+
+        let mut text_gen = TextGeneration::new(&config);
+        text_gen.set_eos_token_id(50256);
+
+        // Simulate a generation loop
+        let vocab_size = 100;
+        let mut all_tokens = Vec::new();
+
+        for step in 0..8 {
+            // Create varying logits to simulate model output
+            let mut logits_vec = vec![0.0; vocab_size];
+            // Make different tokens attractive at different steps
+            let preferred_token = (step * 13) % vocab_size;
+            logits_vec[preferred_token] = 5.0;
+            logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
+            logits_vec[(preferred_token + 20) % vocab_size] = 3.0;
+
+            let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+            let token = text_gen.sample_next_token(&logits).unwrap();
+
+            all_tokens.push(token);
+
+            // Check if we should stop
+            if text_gen.should_stop(token, config.max_length) {
+                break;
+            }
+        }
+
+        // Verify generation worked
+        assert!(!all_tokens.is_empty(), "Should generate some tokens");
+        assert!(all_tokens.len() <= config.max_length, "Should respect max length");
+
+        // Verify tokens are being tracked
+        assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
+    }
 }
@@ -16,6 +16,10 @@ pub struct Gemma {
 }
 
 impl Gemma {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         self.model.clear_kv_cache();
@@ -49,6 +53,9 @@ impl Gemma {
             vec![single_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤12 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {
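The NOTE above (repeated in the Llama, Mistral, and Phi loaders below) refers to the standard Hugging Face sharded-safetensors layout, where `model.safetensors.index.json` maps each tensor name to the shard file that contains it. A sketch of the index-based lookup the comment suggests, assuming `serde_json` is available; this is illustrative, not the gem's implementation:

```rust
use std::collections::BTreeSet;

// Resolve shard filenames from model.safetensors.index.json instead of
// probing shard counts. `index_json` is the raw contents of the index file.
fn shard_files(index_json: &str) -> Result<Vec<String>, serde_json::Error> {
    let index: serde_json::Value = serde_json::from_str(index_json)?;
    let mut files = BTreeSet::new(); // de-duplicates and sorts shard names
    if let Some(weight_map) = index["weight_map"].as_object() {
        for file in weight_map.values().filter_map(|v| v.as_str()) {
            files.insert(file.to_string());
        }
    }
    Ok(files.into_iter().collect())
}
```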
@@ -139,7 +146,7 @@ impl Gemma {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -165,10 +172,7 @@ impl Gemma {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -190,6 +194,18 @@ impl Gemma {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
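This early-stopping block is repeated verbatim in each model's generation loop in this release (Gemma, Llama, Mistral, Phi, QuantizedGGUF, and Qwen below). Distilled into a standalone predicate for readability (a sketch; the crate inlines the check rather than factoring it out like this):

```rust
// Sketch of the repeated check: stop once the structured-generation
// constraint is satisfied, honoring both new config flags.
fn constraint_says_stop(config: &GenerationConfig, text_gen: &TextGeneration) -> bool {
    config.stop_on_constraint_satisfaction
        && if config.stop_on_match {
            text_gen.is_constraint_satisfied_stop_on_match()
        } else {
            text_gen.is_constraint_satisfied()
        }
}
```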
@@ -27,6 +27,10 @@ pub struct GenerationConfig {
     pub debug_tokens: bool,
     /// Optional constraint index for structured generation
     pub constraint: Option<Arc<Index>>,
+    /// Stop immediately when constraint is satisfied
+    pub stop_on_constraint_satisfaction: bool,
+    /// Whether to stop immediately when pattern is matched (vs allowing continuation)
+    pub stop_on_match: bool,
 }
 
 /// Generate a random seed based on current time
@@ -51,6 +55,8 @@ impl Default for GenerationConfig {
             include_prompt: false,
             debug_tokens: false,
             constraint: None,
+            stop_on_constraint_satisfaction: true,
+            stop_on_match: true,
         }
     }
 }
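Both flags default to `true`, so constrained generation now stops as soon as the constraint (e.g., a JSON schema or regular expression) is fully matched. A caller who wants the pre-1.1.1 behavior of generating until another stop condition fires can override the defaults; a minimal sketch, where `index` stands in for a constraint `Arc<Index>` built elsewhere:

```rust
use std::sync::Arc;

// Configure structured generation against a pre-built constraint index,
// opting out of the new early-stopping defaults.
fn configure(index: Arc<Index>) -> GenerationConfig {
    let mut config = GenerationConfig::default();
    config.constraint = Some(index);
    config.stop_on_constraint_satisfaction = false; // keep generating past a full match
    config.stop_on_match = false;
    config
}
```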
@@ -18,6 +18,10 @@ pub struct Llama {
 }
 
 impl Llama {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         // Since Cache doesn't expose a reset method and kvs is private,
@@ -58,6 +62,9 @@ impl Llama {
             vec![consolidated_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤30 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {
@@ -175,7 +182,7 @@ impl Llama {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -201,10 +208,7 @@ impl Llama {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -226,6 +230,18 @@ impl Llama {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -16,6 +16,10 @@ pub struct Mistral {
 }
 
 impl Mistral {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         self.model.clear_kv_cache();
@@ -52,6 +56,9 @@ impl Mistral {
             vec![consolidated_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤8 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {
@@ -144,7 +151,7 @@ impl Mistral {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -176,10 +183,7 @@ impl Mistral {
             // Convert to F32 for sampling if needed
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -201,6 +205,18 @@ impl Mistral {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -21,6 +21,10 @@ enum PhiVariant {
 }
 
 impl Phi {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -68,6 +72,9 @@ impl Phi {
             vec![single_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤30 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {
@@ -177,7 +184,7 @@ impl Phi {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -206,10 +213,7 @@ impl Phi {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -229,6 +233,18 @@ impl Phi {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -32,6 +32,10 @@ enum ModelType {
 }
 
 impl QuantizedGGUF {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -538,7 +542,7 @@ impl QuantizedGGUF {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -571,10 +575,7 @@ impl QuantizedGGUF {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -596,6 +597,18 @@ impl QuantizedGGUF {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -16,6 +16,10 @@ pub struct Qwen {
 }
 
 impl Qwen {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer
@@ -55,6 +59,9 @@ impl Qwen {
             .unwrap_or(151643); // Default Qwen3 EOS token
 
         // Download model weights
+        // NOTE: Qwen uses hardcoded shard counts based on model size rather than
+        // reading model.safetensors.index.json. This works for official Qwen models
+        // but may fail for custom configurations with different shard counts.
         let mut filenames = vec![];
         let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
         else if model_id.contains("14b") || model_id.contains("14B") { 3 }
@@ -124,7 +131,7 @@ impl Qwen {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());
 
@@ -150,10 +157,7 @@ impl Qwen {
 
             let logits = logits.to_dtype(DType::F32)?;
 
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
+            let next_token = text_gen.sample_next_token(&logits)?;
 
             all_tokens.push(next_token);
 
@@ -173,6 +177,18 @@ impl Qwen {
                 break;
             }
 
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
             // Check stop sequences
             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {