red-candle 1.8.0.pre2-x86_64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5193 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +33 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
#[cfg(test)]
|
|
2
|
+
mod constrained_generation_tests {
|
|
3
|
+
use super::super::*;
|
|
4
|
+
use crate::structured::{VocabularyAdapter, SchemaProcessor};
|
|
5
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
6
|
+
|
|
7
|
+
#[tokio::test]
|
|
8
|
+
async fn test_constrained_vs_unconstrained_generation() {
|
|
9
|
+
// This test demonstrates the difference between constrained and unconstrained generation
|
|
10
|
+
|
|
11
|
+
// Load a tokenizer for testing
|
|
12
|
+
if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
|
|
13
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
|
14
|
+
|
|
15
|
+
// Create vocabulary adapter
|
|
16
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
|
17
|
+
.expect("Should create vocabulary");
|
|
18
|
+
|
|
19
|
+
// Create schema processor
|
|
20
|
+
let processor = SchemaProcessor::new();
|
|
21
|
+
|
|
22
|
+
// Define a simple JSON schema for a yes/no response
|
|
23
|
+
let schema = r#"{
|
|
24
|
+
"type": "object",
|
|
25
|
+
"properties": {
|
|
26
|
+
"answer": {
|
|
27
|
+
"type": "string",
|
|
28
|
+
"enum": ["yes", "no"]
|
|
29
|
+
}
|
|
30
|
+
},
|
|
31
|
+
"required": ["answer"]
|
|
32
|
+
}"#;
|
|
33
|
+
|
|
34
|
+
// Process schema into Index
|
|
35
|
+
let index = processor.process_schema(schema, &vocabulary)
|
|
36
|
+
.expect("Should process schema");
|
|
37
|
+
|
|
38
|
+
// Test configuration with constraint
|
|
39
|
+
let mut config_with_constraint = GenerationConfig::default();
|
|
40
|
+
config_with_constraint.constraint = Some(index.clone());
|
|
41
|
+
config_with_constraint.max_length = 50;
|
|
42
|
+
|
|
43
|
+
// Test configuration without constraint
|
|
44
|
+
let config_without_constraint = GenerationConfig::default();
|
|
45
|
+
|
|
46
|
+
// Create text generation instances
|
|
47
|
+
let mut gen_constrained = TextGeneration::new(&config_with_constraint);
|
|
48
|
+
let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);
|
|
49
|
+
|
|
50
|
+
// Set EOS token
|
|
51
|
+
gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
|
|
52
|
+
gen_unconstrained.set_eos_token_id(102);
|
|
53
|
+
|
|
54
|
+
// Constraints are set internally - we can't directly verify them
|
|
55
|
+
// but we can test their effects in actual generation
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
#[test]
fn test_constraint_configuration() {
    // A TextGeneration built from the default config (which carries no
    // constraint) should construct cleanly and start with an empty token
    // history. Constraint state itself is a private implementation detail.
    let config = GenerationConfig::default();
    let text_gen = TextGeneration::new(&config);

    assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");
}
|
|
68
|
+
|
|
69
|
+
#[test]
fn test_repetition_penalty() {
    use candle_core::{Tensor, Device};

    let device = Device::Cpu;
    let vocab_size: usize = 10;

    // Mixed positive and negative logits; the context tokens pushed below all
    // have positive logits.
    let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
    let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

    let mut config = GenerationConfig::default();
    config.seed = 42;
    config.temperature = 1.0;
    let mut text_gen = TextGeneration::new(&config);
    text_gen.push_token(0); // logit 1.0
    text_gen.push_token(2); // logit 2.0
    text_gen.push_token(5); // logit 3.0
    // Indices of the tokens pushed above; only these logits may change.
    let context: [usize; 3] = [0, 2, 5];

    // A window of 10 covers the whole three-token history.
    text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();

    let penalized = logits.to_vec1::<f32>().unwrap();

    for &idx in &context {
        assert!(
            penalized[idx] < logits_vec[idx],
            "Positive logit of context token {} should be reduced",
            idx
        );
    }

    // Every token outside the context must be left untouched.
    for idx in (0..vocab_size).filter(|idx| !context.contains(idx)) {
        assert_eq!(
            penalized[idx], logits_vec[idx],
            "Token {} is not in the context and should be unchanged",
            idx
        );
    }
}
|
|
103
|
+
|
|
104
|
+
#[test]
fn test_stop_conditions() {
    // Covers the three ways generation can end: reaching max_length,
    // sampling the EOS token, and emitting a configured stop sequence.
    let mut config = GenerationConfig::default();
    config.temperature = 1.0;
    config.seed = 42;

    let mut generator = TextGeneration::new(&config);
    generator.set_eos_token_id(50256); // Common EOS token

    // --- length limit: ten tokens of history ---
    (0..10).for_each(|token| {
        generator.push_token(token);
    });
    assert!(generator.should_stop(100, 10), "Should stop at max length");
    assert!(!generator.should_stop(100, 20), "Should not stop before max length");

    // --- EOS: a generous limit isolates the token check ---
    assert!(generator.should_stop(50256, 100), "Should stop at EOS token");
    assert!(!generator.should_stop(123, 100), "Should not stop at non-EOS token");

    // --- stop sequences: matched anywhere in the decoded text ---
    let stop_sequences: Vec<String> = ["STOP", "END"].iter().map(|s| s.to_string()).collect();
    assert!(
        generator.check_stop_sequences("This is the STOP", &stop_sequences),
        "Should detect stop sequence"
    );
    assert!(
        generator.check_stop_sequences("The END", &stop_sequences),
        "Should detect stop sequence"
    );
    assert!(
        !generator.check_stop_sequences("Continue", &stop_sequences),
        "Should not detect stop sequence"
    );
}
|
|
129
|
+
|
|
130
|
+
#[test]
fn test_sample_next_token_uses_repetition_penalty() {
    use candle_core::{Tensor, Device};

    let device = Device::Cpu;
    let vocab_size = 10;

    // Logits rise with the token id: token 9 (10.0) is the unpenalized
    // favourite and token 8 (9.0) the runner-up.
    let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

    // Case 1: penalty 1.5 over a history containing tokens 2, 5 and 9.
    let mut config_with_penalty = GenerationConfig::default();
    config_with_penalty.seed = 42;
    config_with_penalty.temperature = 0.1;
    config_with_penalty.repetition_penalty = 1.5;
    config_with_penalty.repetition_penalty_last_n = 10;
    let mut text_gen = TextGeneration::new(&config_with_penalty);
    text_gen.push_token(2); // logit 3.0
    text_gen.push_token(5); // logit 6.0
    text_gen.push_token(9); // logit 10.0

    // sample_next_token is expected to apply the penalty stored from the config.
    let token_with_penalty = text_gen.sample_next_token(&logits).unwrap();

    // Case 2: same history and sampling settings, neutral penalty.
    let mut config_no_penalty = GenerationConfig::default();
    config_no_penalty.seed = 42;
    config_no_penalty.temperature = 0.1;
    config_no_penalty.repetition_penalty = 1.0; // No penalty
    let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
    text_gen_no_penalty.push_token(2);
    text_gen_no_penalty.push_token(5);
    text_gen_no_penalty.push_token(9);

    let token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();

    // Compare the two samples so the test fails if the penalty is ignored.
    // At temperature 0.1 sampling is close to greedy with the fixed seed.
    // Without a penalty, token 9 leads token 8 by a full logit, so it should win.
    // With penalty 1.5, token 9's logit is expected to drop below token 8's
    // (assuming positive logits are divided by the penalty, 10.0 -> ~6.7).
    assert_eq!(
        token_without_penalty, 9,
        "Without a penalty the highest-logit token should be sampled"
    );
    assert_ne!(
        token_with_penalty, 9,
        "A repetition penalty should steer sampling away from the repeated top token"
    );

    // Case 3: check the penalty directly on a copy of the logits.
    let mut config_verify = GenerationConfig::default();
    config_verify.seed = 42;
    config_verify.temperature = 0.1;
    config_verify.repetition_penalty = 2.0;
    config_verify.repetition_penalty_last_n = 10;
    let mut text_gen_verify = TextGeneration::new(&config_verify);
    text_gen_verify.push_token(9); // Highest logit token

    let mut logits_for_penalty = logits.clone();
    text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();

    let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
    assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
    assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
}
|
|
188
|
+
|
|
189
|
+
#[test]
fn test_text_generation_from_config_parameters() {
    // TextGeneration::new must accept a fully customised sampling config.
    // Private fields cannot be inspected, so success is measured by clean
    // construction and an empty initial token history.
    let mut custom = GenerationConfig::default();
    custom.seed = 12345;
    custom.temperature = 0.5;
    custom.top_p = Some(0.9);
    custom.top_k = Some(40); // noted upstream as currently unused, but it must be accepted
    custom.repetition_penalty = 1.2;
    custom.repetition_penalty_last_n = 50;

    let generator = TextGeneration::new(&custom);
    assert!(generator.get_tokens().is_empty(), "Should start with no tokens");

    // The plain default config: it carries no constraint. A real schema
    // constraint needs a tokenizer and is exercised by the async tests.
    let plain = GenerationConfig::default();
    let plain_generator = TextGeneration::new(&plain);
    assert!(plain_generator.get_tokens().is_empty(), "Should start with no tokens");
}
|
|
215
|
+
|
|
216
|
+
#[test]
fn test_generation_with_different_penalties() {
    use candle_core::{Tensor, Device};

    let device = Device::Cpu;
    let vocab_size: usize = 50;

    // Logits with a clear ranking: token 10 > token 20 > token 30 > rest (0.0).
    let mut logits_vec = vec![0.0f32; vocab_size];
    logits_vec[10] = 10.0; // Strong preference
    logits_vec[20] = 8.0; // Second preference
    logits_vec[30] = 6.0; // Third preference

    // The simulated model output is identical at every step. Build the f32
    // tensor once instead of cloning the vec and casting on every iteration;
    // sample_next_token only borrows it.
    let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap();

    // (repetition_penalty, repetition_penalty_last_n)
    let configs = [
        (1.0, 64), // No penalty (1.0 = neutral)
        (1.5, 64), // Moderate penalty
        (2.0, 64), // Strong penalty
        (1.2, 10), // Penalty with limited range
    ];

    for (penalty, last_n) in configs {
        let mut config = GenerationConfig::default();
        config.seed = 42; // Fixed seed for reproducibility
        config.temperature = 0.1; // Low temperature for near-deterministic sampling
        config.repetition_penalty = penalty;
        config.repetition_penalty_last_n = last_n;

        let mut text_gen = TextGeneration::new(&config);

        let mut generated = Vec::with_capacity(5);
        for _ in 0..5 {
            let token = text_gen.sample_next_token(&logits).unwrap();
            assert!(token < vocab_size as u32, "Token should be within vocabulary");
            generated.push(token);
        }

        // Higher penalties should push sampling off already-used tokens.
        let unique_tokens = generated
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        if penalty > 1.5 {
            assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
        }
    }
}
|
|
266
|
+
|
|
267
|
+
#[test]
fn test_sample_next_token_integration() {
    use candle_core::{Tensor, Device, DType};

    let device = Device::Cpu;

    // End-to-end sampling loop: temperature, repetition penalty and stop
    // checks working together.
    let mut config = GenerationConfig::default();
    config.seed = 999;
    config.temperature = 0.7;
    config.max_length = 10;
    config.repetition_penalty = 1.3;
    config.repetition_penalty_last_n = 5;

    let mut generator = TextGeneration::new(&config);
    generator.set_eos_token_id(50256);

    let vocab_size = 100;
    let mut sampled = Vec::new();

    // Up to eight steps; the fake model favours a different token each step.
    for step in 0..8 {
        let favoured = (step * 13) % vocab_size;

        let mut scores = vec![0.0_f64; vocab_size];
        for (offset, score) in [(0, 5.0), (10, 4.0), (20, 3.0)] {
            scores[(favoured + offset) % vocab_size] = score;
        }

        let logits = Tensor::from_vec(scores, vocab_size, &device)
            .unwrap()
            .to_dtype(DType::F32)
            .unwrap();

        let token = generator.sample_next_token(&logits).unwrap();
        sampled.push(token);

        if generator.should_stop(token, config.max_length) {
            break;
        }
    }

    assert!(!sampled.is_empty(), "Should generate some tokens");
    assert!(sampled.len() <= config.max_length, "Should respect max length");

    // sample_next_token records every token it returns.
    assert_eq!(
        generator.get_tokens().len(),
        sampled.len(),
        "Internal tokens should match generated"
    );
}
|
|
316
|
+
|
|
317
|
+
#[test]
fn test_constraint_satisfied_not_triggered_by_large_allowed_set() {
    // Regression guard for a former bug in is_constraint_satisfied_stop_on_match.
    // It used to short-circuit with `if allowed.len() > 1000 { return true; }`,
    // which ended generation early whenever many tokens were valid, such as
    // inside a JSON string.
    //
    // NOTE(review): only the unconstrained case is exercised here. No
    // constraint with a large allowed set is built, so this test alone cannot
    // catch a return of that bug.
    let cfg = GenerationConfig::default();
    let mut generator = TextGeneration::new(&cfg);
    generator.set_eos_token_id(50256);

    // With no constraint attached, neither satisfaction check may succeed.
    assert!(
        !generator.is_constraint_satisfied(),
        "Without constraint, should not be satisfied"
    );
    assert!(
        !generator.is_constraint_satisfied_stop_on_match(),
        "Without constraint, stop_on_match should not be satisfied"
    );
}
|
|
334
|
+
|
|
335
|
+
#[test]
fn test_constraint_satisfied_only_when_empty_or_eos_only() {
    // Intended contract: a constraint counts as satisfied only when its
    // allowed-token set is empty or contains just the EOS token, never merely
    // because many tokens are allowed (as inside a JSON string).
    //
    // NOTE(review): this test only checks the unconstrained baseline. The
    // empty-set and EOS-only cases need a real constraint and are not built here.
    let cfg = GenerationConfig::default();
    let mut text_gen = TextGeneration::new(&cfg);
    text_gen.set_eos_token_id(100);

    // Baseline: no constraint means nothing can be satisfied.
    assert!(!text_gen.is_constraint_satisfied());
    assert!(!text_gen.is_constraint_satisfied_stop_on_match());
}
|
|
354
|
+
|
|
355
|
+
#[tokio::test]
|
|
356
|
+
async fn test_constraint_with_json_schema_not_early_termination() {
|
|
357
|
+
// Integration test: Create a real JSON schema constraint and verify
|
|
358
|
+
// that being inside a string (many allowed tokens) doesn't trigger completion.
|
|
359
|
+
|
|
360
|
+
if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
|
|
361
|
+
let wrapper = TokenizerWrapper::new(tokenizer);
|
|
362
|
+
let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
|
|
363
|
+
.expect("Should create vocabulary");
|
|
364
|
+
|
|
365
|
+
let processor = SchemaProcessor::new();
|
|
366
|
+
|
|
367
|
+
// Schema with a string field - when generating content inside the string,
|
|
368
|
+
// many characters are valid, but the constraint is NOT complete
|
|
369
|
+
let schema = r#"{
|
|
370
|
+
"type": "object",
|
|
371
|
+
"properties": {
|
|
372
|
+
"name": { "type": "string" }
|
|
373
|
+
},
|
|
374
|
+
"required": ["name"]
|
|
375
|
+
}"#;
|
|
376
|
+
|
|
377
|
+
let index = processor.process_schema(schema, &vocabulary)
|
|
378
|
+
.expect("Should process schema");
|
|
379
|
+
|
|
380
|
+
let mut config = GenerationConfig::default();
|
|
381
|
+
config.constraint = Some(index);
|
|
382
|
+
config.max_length = 100;
|
|
383
|
+
|
|
384
|
+
let mut text_gen = TextGeneration::new(&config);
|
|
385
|
+
text_gen.set_eos_token_id(102); // BERT's [SEP]
|
|
386
|
+
|
|
387
|
+
// At the initial state, the constraint should NOT be satisfied
|
|
388
|
+
// (we haven't generated a complete JSON object yet)
|
|
389
|
+
assert!(!text_gen.is_constraint_satisfied(),
|
|
390
|
+
"Initial state should not be satisfied - JSON not yet generated");
|
|
391
|
+
assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
|
|
392
|
+
"Initial state should not trigger stop_on_match");
|
|
393
|
+
}
|
|
394
|
+
}
|
|
395
|
+
}
|