red-candle 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 5a2b4fac15c2c261d8fea90d34973605209043e2b2a222f82414bf94bc5c47e8
- data.tar.gz: daff4398f34170e20744ab2ee8b1abb5b977ffec15a129d678efced7c9649495
+ metadata.gz: cd566594df4f0d3ec8ecd592b1d71610ef0ba2091cd91ea53549405a8c5c9b18
+ data.tar.gz: cfaab403935e927371fbd2f605ea60da3f44effd64950bb132651a5e6437e30c
  SHA512:
- metadata.gz: 4391a7fb4072d9ac174bcecdb366c975c5e487f43fcc4b0db75533aa44c94822d38103a633eba0b098ac26cf31313e9dd7da77255ed83692584e6936cca86271
- data.tar.gz: da7a15a86fea349069079537b8c6f6696079842d205168aa908644a59528d3ecf5922548d8b3d153494f0beb3fe6221b30110368f1f1289f0ac1ff4856f6b243
+ metadata.gz: b18f384766db5a3de2f794764a8a47d3d69261ae4bf5c93992df7228fe9a5817eb1862b516a9ee3e63b5d349936f46534d50f1378b7932c7ac482550270a8bea
+ data.tar.gz: ce74834cde884bf03e3d163f6d081e87f95db0fba8db24300b8c6aa5b0f1f0162542c836bebb414010534778e394ca36faaec807ff1cc652902c3b318503def5
data/README.md CHANGED
@@ -58,7 +58,8 @@ end
  - **EmbeddingModel**: Generate embeddings for text
  - **Reranker**: Rerank documents based on relevance
  - **NER**: Named Entity Recognition directly from Ruby
- - **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma)
+ - **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma, Qwen, Phi)
+ - **Structured Generation**: Generate JSON from a schema or match a regular expression

  ## Model Storage

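The Structured Generation feature added in this hunk constrains token sampling so the output must match a JSON schema or regular expression. Conceptually, each decoding step masks out tokens the constraint automaton would reject; a purely illustrative sketch of that masking step follows (not red-candle's actual API, whose constraint type is the `Index` seen in the `GenerationConfig` hunk further down):

```rust
use std::collections::HashSet;

/// Illustrative only: zero out the probability of every token the
/// constraint does not currently allow, so sampling can never pick one.
fn mask_disallowed(logits: &mut [f32], allowed: &HashSet<u32>) {
    for (token_id, logit) in logits.iter_mut().enumerate() {
        if !allowed.contains(&(token_id as u32)) {
            *logit = f32::NEG_INFINITY;
        }
    }
}
```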
@@ -887,6 +888,69 @@ bundle exec rake compile

  Pull requests are welcome.

+ ## Testing
+
+ Red Candle has comprehensive tests at both the Ruby and Rust levels:
+
+ ### Ruby Tests
+ ```bash
+ # Run all Ruby tests
+ bundle exec rake test
+
+ # Run specific test suites
+ bundle exec rake test:device      # Device compatibility tests
+ bundle exec rake test:benchmark   # Benchmark tests
+ bundle exec rake test:llm:mistral # Model-specific tests
+ ```
+
+ ### Rust Tests
+ ```bash
+ # Run Rust unit and integration tests
+ cd ext/candle && cargo test
+
+ # Or use the Rake task
+ bundle exec rake rust:test
+ ```
+
+ The Rust tests include:
+ - Unit tests within source files (using `#[cfg(test)]` modules)
+ - Integration tests for external dependencies (candle_core operations)
+ - Tests for structured generation, tokenization, and text generation
+
+ ### Code Coverage
+
+ #### Rust Code Coverage
+ Red Candle uses `cargo-llvm-cov` for Rust code coverage analysis:
+
+ ```bash
+ # Generate HTML coverage report (opens in target/llvm-cov/html/index.html)
+ bundle exec rake rust:coverage:html
+
+ # Show coverage summary in terminal
+ bundle exec rake rust:coverage:summary
+
+ # Generate detailed coverage report
+ bundle exec rake rust:coverage:report
+
+ # Generate LCOV format for CI integration
+ bundle exec rake rust:coverage:lcov
+
+ # Clean coverage data
+ bundle exec rake rust:coverage:clean
+ ```
+
+ **Note**: Overall Rust coverage shows ~17% because most code consists of Ruby FFI bindings that are tested through Ruby tests. The testable Rust components have high coverage:
+ - Constrained generation: 99.59%
+ - Schema processing: 90.99%
+ - Integration tests: 97.12%
+
+ #### Ruby Code Coverage
+ Ruby test coverage is generated automatically when running tests:
+ ```bash
+ bundle exec rake test
+ # Coverage report generated in coverage/index.html
+ ```
+
  ## Release

  1. Update version number in `lib/candle/version.rb` and commit.
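For readers unfamiliar with the `#[cfg(test)]` convention the new Testing section mentions, here is a minimal generic example of a Rust inline unit-test module (illustrative only, not code from this gem): the tests live in the same file as the code they cover and are compiled only when running `cargo test`.

```rust
pub fn add(a: u32, b: u32) -> u32 {
    a + b
}

// This module is compiled only for `cargo test` builds.
#[cfg(test)]
mod tests {
    use super::add;

    #[test]
    fn adds_two_numbers() {
        assert_eq!(add(2, 3), 5);
    }
}
```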
data/Rakefile CHANGED
@@ -134,3 +134,43 @@ namespace :doc do
  end

  task doc: "doc:default"
+
+ namespace :rust do
+   desc "Run Rust tests with code coverage"
+   namespace :coverage do
+     desc "Generate HTML coverage report"
+     task :html do
+       sh "cd ext/candle && cargo llvm-cov --html"
+       puts "Coverage report generated in target/llvm-cov/html/index.html"
+     end
+
+     desc "Generate coverage report in terminal"
+     task :report do
+       sh "cd ext/candle && cargo llvm-cov"
+     end
+
+     desc "Show coverage summary"
+     task :summary do
+       sh "cd ext/candle && cargo llvm-cov --summary-only"
+     end
+
+     desc "Generate lcov format coverage report"
+     task :lcov do
+       sh "cd ext/candle && cargo llvm-cov --lcov --output-path ../../coverage/lcov.info"
+       puts "LCOV report generated in coverage/lcov.info"
+     end
+
+     desc "Clean coverage data"
+     task :clean do
+       sh "cd ext/candle && cargo llvm-cov clean"
+     end
+   end
+
+   desc "Run Rust tests"
+   task :test do
+     sh "cd ext/candle && cargo test"
+   end
+ end
+
+ desc "Run Rust tests with coverage (alias)"
+ task "coverage:rust" => "rust:coverage:html"
@@ -44,8 +44,8 @@ mod constrained_generation_tests {
          let config_without_constraint = GenerationConfig::default();

          // Create text generation instances
-         let mut gen_constrained = TextGeneration::from_config(&config_with_constraint);
-         let mut gen_unconstrained = TextGeneration::from_config(&config_without_constraint);
+         let mut gen_constrained = TextGeneration::new(&config_with_constraint);
+         let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);

          // Set EOS token
          gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
@@ -60,9 +60,9 @@ mod constrained_generation_tests {
      fn test_constraint_configuration() {
          // Test that we can create a TextGeneration with constraints
          let config = GenerationConfig::default();
-         let _text_gen = TextGeneration::from_config(&config);
+         let _text_gen = TextGeneration::new(&config);

-         // Test that we can create a TextGeneration from config
+         // Test that we can create a TextGeneration with config
          // Constraints are private implementation details
      }

@@ -78,7 +78,10 @@ mod constrained_generation_tests {
          let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

          // Create text generation with some tokens
-         let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+         let mut config = GenerationConfig::default();
+         config.seed = 42;
+         config.temperature = 1.0;
+         let mut text_gen = TextGeneration::new(&config);
          text_gen.push_token(0); // Token that had logit 1.0
          text_gen.push_token(2); // Token that had logit 2.0
          text_gen.push_token(5); // Token that had logit 3.0
@@ -100,7 +103,10 @@ mod constrained_generation_tests {

      #[test]
      fn test_stop_conditions() {
-         let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+         let mut config = GenerationConfig::default();
+         config.seed = 42;
+         config.temperature = 1.0;
+         let mut text_gen = TextGeneration::new(&config);
          text_gen.set_eos_token_id(50256); // Common EOS token

          // Test max length stop
@@ -120,4 +126,191 @@ mod constrained_generation_tests {
          assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
          assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
      }
+
+     #[test]
+     fn test_sample_next_token_uses_repetition_penalty() {
+         use candle_core::{Tensor, Device};
+
+         let device = Device::Cpu;
+         let vocab_size = 10;
+
+         // Create initial logits
+         let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+         let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+         // Test 1: Create TextGeneration with repetition penalty and add some tokens to history
+         let mut config_with_penalty = GenerationConfig::default();
+         config_with_penalty.seed = 42;
+         config_with_penalty.temperature = 0.1;
+         config_with_penalty.repetition_penalty = 1.5;
+         config_with_penalty.repetition_penalty_last_n = 10;
+         let mut text_gen = TextGeneration::new(&config_with_penalty);
+         text_gen.push_token(2); // Token with logit 3.0
+         text_gen.push_token(5); // Token with logit 6.0
+         text_gen.push_token(9); // Token with logit 10.0
+
+         // Sample with repetition penalty (now uses stored penalty)
+         let _token_with_penalty = text_gen.sample_next_token(&logits).unwrap();
+
+         // Test 2: Same setup but without penalty
+         let mut config_no_penalty = GenerationConfig::default();
+         config_no_penalty.seed = 42;
+         config_no_penalty.temperature = 0.1;
+         config_no_penalty.repetition_penalty = 1.0; // No penalty
+         let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
+         text_gen_no_penalty.push_token(2);
+         text_gen_no_penalty.push_token(5);
+         text_gen_no_penalty.push_token(9);
+
+         let _token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();
+
+         // With low temperature and penalty, should avoid previously used high-logit tokens
+         // Without penalty, should prefer high-logit tokens
+         // This is probabilistic, but with temp=0.1 it should be fairly deterministic
+
+         // Test 3: Verify penalty is applied correctly by checking modified logits
+         let mut config_verify = GenerationConfig::default();
+         config_verify.seed = 42;
+         config_verify.temperature = 0.1;
+         config_verify.repetition_penalty = 2.0;
+         config_verify.repetition_penalty_last_n = 10;
+         let mut text_gen_verify = TextGeneration::new(&config_verify);
+         text_gen_verify.push_token(9); // Highest logit token
+
+         // Clone logits to check modification
+         let mut logits_for_penalty = logits.clone();
+         text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();
+
+         let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
+         assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
+         assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
+     }
+
+     #[test]
+     fn test_text_generation_from_config_parameters() {
+
+         // Create a config with specific values
+         let mut config = GenerationConfig::default();
+         config.seed = 12345;
+         config.temperature = 0.5;
+         config.top_p = Some(0.9);
+         config.top_k = Some(40); // Currently unused but should be accepted
+         config.repetition_penalty = 1.2;
+         config.repetition_penalty_last_n = 50;
+
+         // Create TextGeneration from config
+         let text_gen = TextGeneration::new(&config);
+
+         // We can't directly inspect private fields, but we can test behavior
+         // Test that it creates successfully (no panic)
+         assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");
+
+         // Test with constraint
+         let config_with_constraint = GenerationConfig::default();
+         // In real usage, this would be a real constraint
+         // For testing, we just verify it accepts the config
+         let text_gen_constrained = TextGeneration::new(&config_with_constraint);
+         assert!(text_gen_constrained.get_tokens().is_empty(), "Should start with no tokens");
+     }
+
+     #[test]
+     fn test_generation_with_different_penalties() {
+         use candle_core::{Tensor, Device, DType};
+
+         let device = Device::Cpu;
+         let vocab_size = 50;
+
+         // Create logits with clear preferences
+         let mut logits_vec = vec![0.0; vocab_size];
+         logits_vec[10] = 10.0; // Strong preference
+         logits_vec[20] = 8.0; // Second preference
+         logits_vec[30] = 6.0; // Third preference
+
+         // Test different penalty configurations
+         let configs = vec![
+             (1.0, 64), // No penalty (1.0 = neutral)
+             (1.5, 64), // Moderate penalty
+             (2.0, 64), // Strong penalty
+             (1.2, 10), // Penalty with limited range
+         ];
+
+         for (penalty, last_n) in configs {
+             let mut config = GenerationConfig::default();
+             config.seed = 42; // Fixed seed for reproducibility
+             config.temperature = 0.1; // Low temperature for more deterministic behavior
+             config.repetition_penalty = penalty;
+             config.repetition_penalty_last_n = last_n;
+
+             let mut text_gen = TextGeneration::new(&config);
+
+             // Generate a sequence of tokens
+             let mut generated = Vec::new();
+             for _i in 0..5 {
+                 let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+                 let token = text_gen.sample_next_token(&logits).unwrap();
+
+                 generated.push(token);
+
+                 // Verify the token is in valid range
+                 assert!(token < vocab_size as u32, "Token should be within vocabulary");
+             }
+
+             // With higher penalties, we should see more diversity (less repetition)
+             let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
+             if penalty > 1.5 {
+                 assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
+             }
+         }
+     }
+
+     #[test]
+     fn test_sample_next_token_integration() {
+         use candle_core::{Tensor, Device, DType};
+
+         let device = Device::Cpu;
+
+         // Test the full integration of sample_next_token
+         let mut config = GenerationConfig::default();
+         config.seed = 999;
+         config.temperature = 0.7;
+         config.max_length = 10;
+         config.repetition_penalty = 1.3;
+         config.repetition_penalty_last_n = 5;
+
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.set_eos_token_id(50256);
+
+         // Simulate a generation loop
+         let vocab_size = 100;
+         let mut all_tokens = Vec::new();
+
+         for step in 0..8 {
+             // Create varying logits to simulate model output
+             let mut logits_vec = vec![0.0; vocab_size];
+             // Make different tokens attractive at different steps
+             let preferred_token = (step * 13) % vocab_size;
+             logits_vec[preferred_token] = 5.0;
+             logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
+             logits_vec[(preferred_token + 20) % vocab_size] = 3.0;
+
+             let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+             let token = text_gen.sample_next_token(&logits).unwrap();
+
+             all_tokens.push(token);
+
+             // Check if we should stop
+             if text_gen.should_stop(token, config.max_length) {
+                 break;
+             }
+         }
+
+         // Verify generation worked
+         assert!(!all_tokens.is_empty(), "Should generate some tokens");
+         assert!(all_tokens.len() <= config.max_length, "Should respect max length");
+
+         // Verify tokens are being tracked
+         assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
+     }
  }
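The tests above call `apply_repetition_penalty` without showing its definition. For reference, here is a minimal standalone sketch of the standard CTRL-style penalty that these assertions imply (an assumption; red-candle's actual implementation operates on candle tensors and may differ in detail):

```rust
use std::collections::HashSet;

// CTRL-style repetition penalty: shrink the logits of tokens seen in the
// last `last_n` positions toward zero. With penalty 2.0 a logit of 10.0
// becomes 5.0, matching `penalized[9] < logits_vec[9]` above, while
// tokens absent from the history stay untouched.
fn apply_repetition_penalty(logits: &mut [f32], history: &[u32], penalty: f32, last_n: usize) {
    let start = history.len().saturating_sub(last_n);
    let recent: HashSet<u32> = history[start..].iter().copied().collect();
    for &token in &recent {
        let logit = &mut logits[token as usize];
        *logit = if *logit >= 0.0 { *logit / penalty } else { *logit * penalty };
    }
}
```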
@@ -16,6 +16,10 @@ pub struct Gemma {
  }

  impl Gemma {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          self.model.clear_kv_cache();
@@ -49,6 +53,9 @@ impl Gemma {
              vec![single_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤12 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
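The NOTE above (repeated in the Llama, Mistral, and Phi loaders below) names the more robust alternative: read the shard list from `model.safetensors.index.json` instead of probing shard counts. A minimal sketch of that approach, assuming the standard Hugging Face index layout (`"weight_map"` maps each tensor name to its shard file) and a `serde_json` dependency; this is not code from the gem:

```rust
use std::collections::BTreeSet;

/// Resolve the exact shard file list from model.safetensors.index.json.
fn shard_files_from_index(index_json: &str) -> Result<Vec<String>, serde_json::Error> {
    let index: serde_json::Value = serde_json::from_str(index_json)?;
    let mut files = BTreeSet::new(); // dedupes shards and keeps a stable order
    if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
        for shard in weight_map.values().filter_map(|v| v.as_str()) {
            files.insert(shard.to_string());
        }
    }
    Ok(files.into_iter().collect())
}
```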
@@ -139,7 +146,7 @@ impl Gemma {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -165,10 +172,7 @@

              let logits = logits.to_dtype(DType::F32)?;

-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
+             let next_token = text_gen.sample_next_token(&logits)?;

              all_tokens.push(next_token);

@@ -190,6 +194,18 @@
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -27,6 +27,10 @@ pub struct GenerationConfig {
      pub debug_tokens: bool,
      /// Optional constraint index for structured generation
      pub constraint: Option<Arc<Index>>,
+     /// Stop immediately when constraint is satisfied
+     pub stop_on_constraint_satisfaction: bool,
+     /// Whether to stop immediately when pattern is matched (vs allowing continuation)
+     pub stop_on_match: bool,
  }

  /// Generate a random seed based on current time
@@ -51,6 +55,8 @@ impl Default for GenerationConfig {
              include_prompt: false,
              debug_tokens: false,
              constraint: None,
+             stop_on_constraint_satisfaction: true,
+             stop_on_match: true,
          }
      }
  }
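Both new flags default to `true`, so constrained generations stop as soon as the pattern is complete. A caller that wants generation to continue past the first match could override them with struct update syntax; a usage sketch based only on the fields shown above:

```rust
let config = GenerationConfig {
    stop_on_constraint_satisfaction: true,
    stop_on_match: false, // keep generating after the pattern first matches
    ..GenerationConfig::default()
};
```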
@@ -18,6 +18,10 @@ pub struct Llama {
  }

  impl Llama {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          // Since Cache doesn't expose a reset method and kvs is private,
@@ -58,6 +62,9 @@ impl Llama {
              vec![consolidated_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤30 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
@@ -175,7 +182,7 @@ impl Llama {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -201,10 +208,7 @@

              let logits = logits.to_dtype(DType::F32)?;

-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
+             let next_token = text_gen.sample_next_token(&logits)?;

              all_tokens.push(next_token);

@@ -226,6 +230,18 @@
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -16,6 +16,10 @@ pub struct Mistral {
  }

  impl Mistral {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          self.model.clear_kv_cache();
@@ -52,6 +56,9 @@ impl Mistral {
              vec![consolidated_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤8 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
@@ -144,7 +151,7 @@ impl Mistral {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -176,10 +183,7 @@ impl Mistral {
              // Convert to F32 for sampling if needed
              let logits = logits.to_dtype(DType::F32)?;

-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
+             let next_token = text_gen.sample_next_token(&logits)?;

              all_tokens.push(next_token);

@@ -201,6 +205,18 @@
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -21,6 +21,10 @@ enum PhiVariant {
  }

  impl Phi {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Get the tokenizer
      pub fn tokenizer(&self) -> &TokenizerWrapper {
          &self.tokenizer
@@ -68,6 +72,9 @@ impl Phi {
              vec![single_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤30 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
@@ -177,7 +184,7 @@ impl Phi {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -206,10 +213,7 @@

              let logits = logits.to_dtype(DType::F32)?;

-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
+             let next_token = text_gen.sample_next_token(&logits)?;

              all_tokens.push(next_token);

@@ -229,6 +233,18 @@
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {