red-candle 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,316 @@
+ #[cfg(test)]
+ mod constrained_generation_tests {
+     use super::super::*;
+     use crate::structured::{VocabularyAdapter, SchemaProcessor};
+     use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
+
+     #[tokio::test]
+     async fn test_constrained_vs_unconstrained_generation() {
+         // This test demonstrates the difference between constrained and unconstrained generation
+
+         // Load a tokenizer for testing
+         if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
+             let wrapper = TokenizerWrapper::new(tokenizer);
+
+             // Create vocabulary adapter
+             let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
+                 .expect("Should create vocabulary");
+
+             // Create schema processor
+             let processor = SchemaProcessor::new();
+
+             // Define a simple JSON schema for a yes/no response
+             let schema = r#"{
+                 "type": "object",
+                 "properties": {
+                     "answer": {
+                         "type": "string",
+                         "enum": ["yes", "no"]
+                     }
+                 },
+                 "required": ["answer"]
+             }"#;
+
+             // Process schema into Index
+             let index = processor.process_schema(schema, &vocabulary)
+                 .expect("Should process schema");
+
+             // Test configuration with constraint
+             let mut config_with_constraint = GenerationConfig::default();
+             config_with_constraint.constraint = Some(index.clone());
+             config_with_constraint.max_length = 50;
+
+             // Test configuration without constraint
+             let config_without_constraint = GenerationConfig::default();
+
+             // Create text generation instances
+             let mut gen_constrained = TextGeneration::new(&config_with_constraint);
+             let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);
+
+             // Set EOS token
+             gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
+             gen_unconstrained.set_eos_token_id(102);
+
+             // Constraints are set internally - we can't directly verify them
+             // but we can test their effects in actual generation
+         }
+     }
+
+     #[test]
+     fn test_constraint_configuration() {
+         // Test that we can create a TextGeneration with constraints
+         let config = GenerationConfig::default();
+         let _text_gen = TextGeneration::new(&config);
+
+         // Test that we can create a TextGeneration with config
+         // Constraints are private implementation details
+     }
+
+     #[test]
+     fn test_repetition_penalty() {
+         use candle_core::{Tensor, Device};
+
+         let device = Device::Cpu;
+         let vocab_size = 10;
+
+         // Create logits with some positive and negative values
+         let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
+         let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+         // Create text generation with some tokens
+         let mut config = GenerationConfig::default();
+         config.seed = 42;
+         config.temperature = 1.0;
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.push_token(0); // Token that had logit 1.0
+         text_gen.push_token(2); // Token that had logit 2.0
+         text_gen.push_token(5); // Token that had logit 3.0
+
+         // Apply repetition penalty
+         text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();
+
+         let penalized = logits.to_vec1::<f32>().unwrap();
+
+         // Check that tokens in context were penalized
+         assert!(penalized[0] < logits_vec[0], "Positive logit should be reduced");
+         assert!(penalized[2] < logits_vec[2], "Positive logit should be reduced");
+         assert!(penalized[5] < logits_vec[5], "Positive logit should be reduced");
+
+         // Check that other tokens remain unchanged
+         assert_eq!(penalized[1], logits_vec[1], "Unsampled token should be unchanged");
+         assert_eq!(penalized[3], logits_vec[3], "Unsampled token should be unchanged");
+     }
+
+     #[test]
+     fn test_stop_conditions() {
+         let mut config = GenerationConfig::default();
+         config.seed = 42;
+         config.temperature = 1.0;
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.set_eos_token_id(50256); // Common EOS token
+
+         // Test max length stop
+         for i in 0..10 {
+             text_gen.push_token(i);
+         }
+         assert!(text_gen.should_stop(100, 10), "Should stop at max length");
+         assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");
+
+         // Test EOS token stop
+         assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
+         assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");
+
+         // Test stop sequences
+         let stop_seqs = vec!["STOP".to_string(), "END".to_string()];
+         assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
+         assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
+         assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
+     }
+
+     #[test]
+     fn test_sample_next_token_uses_repetition_penalty() {
+         use candle_core::{Tensor, Device};
+
+         let device = Device::Cpu;
+         let vocab_size = 10;
+
+         // Create initial logits
+         let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+         let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+         // Test 1: Create TextGeneration with repetition penalty and add some tokens to history
+         let mut config_with_penalty = GenerationConfig::default();
+         config_with_penalty.seed = 42;
+         config_with_penalty.temperature = 0.1;
+         config_with_penalty.repetition_penalty = 1.5;
+         config_with_penalty.repetition_penalty_last_n = 10;
+         let mut text_gen = TextGeneration::new(&config_with_penalty);
+         text_gen.push_token(2); // Token with logit 3.0
+         text_gen.push_token(5); // Token with logit 6.0
+         text_gen.push_token(9); // Token with logit 10.0
+
+         // Sample with repetition penalty (now uses stored penalty)
+         let _token_with_penalty = text_gen.sample_next_token(&logits).unwrap();
+
+         // Test 2: Same setup but without penalty
+         let mut config_no_penalty = GenerationConfig::default();
+         config_no_penalty.seed = 42;
+         config_no_penalty.temperature = 0.1;
+         config_no_penalty.repetition_penalty = 1.0; // No penalty
+         let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
+         text_gen_no_penalty.push_token(2);
+         text_gen_no_penalty.push_token(5);
+         text_gen_no_penalty.push_token(9);
+
+         let _token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();
+
+         // With low temperature and penalty, should avoid previously used high-logit tokens
+         // Without penalty, should prefer high-logit tokens
+         // This is probabilistic, but with temp=0.1 it should be fairly deterministic
+
+         // Test 3: Verify penalty is applied correctly by checking modified logits
+         let mut config_verify = GenerationConfig::default();
+         config_verify.seed = 42;
+         config_verify.temperature = 0.1;
+         config_verify.repetition_penalty = 2.0;
+         config_verify.repetition_penalty_last_n = 10;
+         let mut text_gen_verify = TextGeneration::new(&config_verify);
+         text_gen_verify.push_token(9); // Highest logit token
+
+         // Clone logits to check modification
+         let mut logits_for_penalty = logits.clone();
+         text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();
+
+         let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
+         assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
+         assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
+     }
+
+     #[test]
+     fn test_text_generation_from_config_parameters() {
+
+         // Create a config with specific values
+         let mut config = GenerationConfig::default();
+         config.seed = 12345;
+         config.temperature = 0.5;
+         config.top_p = Some(0.9);
+         config.top_k = Some(40); // Currently unused but should be accepted
+         config.repetition_penalty = 1.2;
+         config.repetition_penalty_last_n = 50;
+
+         // Create TextGeneration from config
+         let text_gen = TextGeneration::new(&config);
+
+         // We can't directly inspect private fields, but we can test behavior
+         // Test that it creates successfully (no panic)
+         assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");
+
+         // Test with constraint
+         let config_with_constraint = GenerationConfig::default();
+         // In real usage, this would be a real constraint
+         // For testing, we just verify it accepts the config
+         let text_gen_constrained = TextGeneration::new(&config_with_constraint);
+         assert!(text_gen_constrained.get_tokens().is_empty(), "Should start with no tokens");
+     }
+
+     #[test]
+     fn test_generation_with_different_penalties() {
+         use candle_core::{Tensor, Device, DType};
+
+         let device = Device::Cpu;
+         let vocab_size = 50;
+
+         // Create logits with clear preferences
+         let mut logits_vec = vec![0.0; vocab_size];
+         logits_vec[10] = 10.0; // Strong preference
+         logits_vec[20] = 8.0; // Second preference
+         logits_vec[30] = 6.0; // Third preference
+
+         // Test different penalty configurations
+         let configs = vec![
+             (1.0, 64), // No penalty (1.0 = neutral)
+             (1.5, 64), // Moderate penalty
+             (2.0, 64), // Strong penalty
+             (1.2, 10), // Penalty with limited range
+         ];
+
+         for (penalty, last_n) in configs {
+             let mut config = GenerationConfig::default();
+             config.seed = 42; // Fixed seed for reproducibility
+             config.temperature = 0.1; // Low temperature for more deterministic behavior
+             config.repetition_penalty = penalty;
+             config.repetition_penalty_last_n = last_n;
+
+             let mut text_gen = TextGeneration::new(&config);
+
+             // Generate a sequence of tokens
+             let mut generated = Vec::new();
+             for _i in 0..5 {
+                 let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+                 let token = text_gen.sample_next_token(&logits).unwrap();
+
+                 generated.push(token);
+
+                 // Verify the token is in valid range
+                 assert!(token < vocab_size as u32, "Token should be within vocabulary");
+             }
+
+             // With higher penalties, we should see more diversity (less repetition)
+             let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
+             if penalty > 1.5 {
+                 assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
+             }
+         }
+     }
+
+     #[test]
+     fn test_sample_next_token_integration() {
+         use candle_core::{Tensor, Device, DType};
+
+         let device = Device::Cpu;
+
+         // Test the full integration of sample_next_token
+         let mut config = GenerationConfig::default();
+         config.seed = 999;
+         config.temperature = 0.7;
+         config.max_length = 10;
+         config.repetition_penalty = 1.3;
+         config.repetition_penalty_last_n = 5;
+
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.set_eos_token_id(50256);
+
+         // Simulate a generation loop
+         let vocab_size = 100;
+         let mut all_tokens = Vec::new();
+
+         for step in 0..8 {
+             // Create varying logits to simulate model output
+             let mut logits_vec = vec![0.0; vocab_size];
+             // Make different tokens attractive at different steps
+             let preferred_token = (step * 13) % vocab_size;
+             logits_vec[preferred_token] = 5.0;
+             logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
+             logits_vec[(preferred_token + 20) % vocab_size] = 3.0;
+
+             let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+             let token = text_gen.sample_next_token(&logits).unwrap();
+
+             all_tokens.push(token);
+
+             // Check if we should stop
+             if text_gen.should_stop(token, config.max_length) {
+                 break;
+             }
+         }
+
+         // Verify generation worked
+         assert!(!all_tokens.is_empty(), "Should generate some tokens");
+         assert!(all_tokens.len() <= config.max_length, "Should respect max length");
+
+         // Verify tokens are being tracked
+         assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
+     }
+ }
@@ -16,6 +16,10 @@ pub struct Gemma {
  }

  impl Gemma {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          self.model.clear_kv_cache();
@@ -49,6 +53,9 @@ impl Gemma {
              vec![single_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤12 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
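
The NOTE above (repeated for the Llama and Mistral loaders below) sketches an index-based alternative to brute-force shard probing. For illustration only, here is a minimal sketch of that idea, assuming a serde_json dependency and the standard model.safetensors.index.json layout in which "weight_map" maps tensor names to shard filenames; the helper name is hypothetical and not part of this release:

    use std::collections::BTreeSet;

    // Hypothetical helper (not in this diff): derive the exact shard list from
    // the contents of model.safetensors.index.json instead of probing counts.
    fn shard_files_from_index(index_json: &str) -> Result<Vec<String>, serde_json::Error> {
        let index: serde_json::Value = serde_json::from_str(index_json)?;
        // "weight_map" maps each tensor name to the shard file containing it;
        // collecting into a BTreeSet dedupes names and yields a stable order.
        let mut files = BTreeSet::new();
        if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
            for filename in weight_map.values().filter_map(|v| v.as_str()) {
                files.insert(filename.to_string());
            }
        }
        Ok(files.into_iter().collect())
    }
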
@@ -139,7 +146,7 @@ impl Gemma {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -165,10 +172,7 @@ impl Gemma {

          let logits = logits.to_dtype(DType::F32)?;

-         let next_token = text_gen.sample_next_token(
-             &logits,
-             Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-         )?;
+         let next_token = text_gen.sample_next_token(&logits)?;

          all_tokens.push(next_token);

@@ -190,6 +194,18 @@ impl Gemma {
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -1,4 +1,6 @@
  use std::time::{SystemTime, UNIX_EPOCH};
+ use std::sync::Arc;
+ use crate::structured::Index;

  /// Configuration for text generation
  #[derive(Debug, Clone)]
@@ -23,6 +25,12 @@ pub struct GenerationConfig {
      pub include_prompt: bool,
      /// Whether to show raw tokens during generation (for debugging)
      pub debug_tokens: bool,
+     /// Optional constraint index for structured generation
+     pub constraint: Option<Arc<Index>>,
+     /// Stop immediately when constraint is satisfied
+     pub stop_on_constraint_satisfaction: bool,
+     /// Whether to stop immediately when pattern is matched (vs allowing continuation)
+     pub stop_on_match: bool,
  }

  /// Generate a random seed based on current time
@@ -46,6 +54,9 @@ impl Default for GenerationConfig {
              stop_sequences: vec![],
              include_prompt: false,
              debug_tokens: false,
+             constraint: None,
+             stop_on_constraint_satisfaction: true,
+             stop_on_match: true,
          }
      }
  }
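
Combined with the test module at the top of this diff, the new fields compose as follows. This is a minimal sketch, assuming `index` is an `Arc<Index>` produced by `SchemaProcessor::process_schema` as in the tests; the function name is illustrative only:

    use std::sync::Arc;
    use crate::structured::Index;

    // Illustrative only: build a config for schema-constrained generation,
    // mirroring the usage in the constrained_generation_tests module above.
    fn constrained_config(index: Arc<Index>) -> GenerationConfig {
        let mut config = GenerationConfig::default();
        config.constraint = Some(index);               // guides sampling toward the schema
        config.stop_on_constraint_satisfaction = true; // stop once the constraint is satisfied
        config.stop_on_match = true;                   // ...immediately on the first full match
        config.max_length = 50;
        config
    }
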
@@ -18,6 +18,10 @@ pub struct Llama {
  }

  impl Llama {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          // Since Cache doesn't expose a reset method and kvs is private,
@@ -58,6 +62,9 @@ impl Llama {
              vec![consolidated_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤30 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
@@ -175,7 +182,7 @@ impl Llama {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -201,10 +208,7 @@ impl Llama {

          let logits = logits.to_dtype(DType::F32)?;

-         let next_token = text_gen.sample_next_token(
-             &logits,
-             Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-         )?;
+         let next_token = text_gen.sample_next_token(&logits)?;

          all_tokens.push(next_token);

@@ -226,6 +230,18 @@ impl Llama {
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -16,6 +16,10 @@ pub struct Mistral {
  }

  impl Mistral {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
      /// Clear the KV cache between generations
      pub fn clear_kv_cache(&mut self) {
          self.model.clear_kv_cache();
@@ -52,6 +56,9 @@ impl Mistral {
              vec![consolidated_file]
          } else {
              // Try to find sharded model files
+             // NOTE: This uses a brute-force approach, trying common shard counts.
+             // A better approach would be to read model.safetensors.index.json which
+             // contains the exact file list, but this works for most models (≤8 shards).
              let mut sharded_files = Vec::new();
              let mut index = 1;
              loop {
@@ -144,7 +151,7 @@ impl Mistral {
          config: &GenerationConfig,
          mut callback: Option<impl FnMut(&str)>,
      ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
+         let mut text_gen = TextGeneration::new(config);
          text_gen.set_eos_token_id(self.eos_token_id);
          text_gen.set_tokens(prompt_tokens.clone());

@@ -176,10 +183,7 @@ impl Mistral {
          // Convert to F32 for sampling if needed
          let logits = logits.to_dtype(DType::F32)?;

-         let next_token = text_gen.sample_next_token(
-             &logits,
-             Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-         )?;
+         let next_token = text_gen.sample_next_token(&logits)?;

          all_tokens.push(next_token);

@@ -201,6 +205,18 @@ impl Mistral {
                  break;
              }

+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
              // Check stop sequences
              let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
@@ -3,6 +3,8 @@ use candle_core::{Device, Result as CandleResult};
  pub mod mistral;
  pub mod llama;
  pub mod gemma;
+ pub mod qwen;
+ pub mod phi;
  pub mod generation_config;
  pub mod text_generation;
  pub mod quantized_gguf;
@@ -12,6 +14,9 @@ pub use text_generation::TextGeneration;
  pub use quantized_gguf::QuantizedGGUF;
  pub use crate::tokenizer::TokenizerWrapper;

+ #[cfg(test)]
+ mod constrained_generation_test;
+
  /// Trait for text generation models
  pub trait TextGenerator: Send + Sync {
      /// Generate text from a prompt