red-candle 1.0.2 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +38 -3
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +316 -0
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +11 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +301 -0
- data/ext/candle/src/llm/quantized_gguf.rs +173 -9
- data/ext/candle/src/llm/qwen.rs +245 -0
- data/ext/candle/src/llm/text_generation.rs +183 -26
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +119 -55
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/llm.rb +203 -2
- data/lib/candle/version.rb +1 -1
- metadata +14 -4
@@ -0,0 +1,316 @@
|
|
1
|
+
#[cfg(test)]
mod constrained_generation_tests {
    //! Tests for `TextGeneration`: repetition penalty, sampling, stop conditions
    //! and wiring a structured-generation constraint through `GenerationConfig`.

    use super::super::*;
    use crate::structured::{VocabularyAdapter, SchemaProcessor};
    use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};

    /// Builds a JSON-schema constraint from a real tokenizer and wires it into a
    /// `TextGeneration`.
    ///
    /// Needs the `bert-base-uncased` tokenizer from the Hugging Face Hub. When it
    /// cannot be loaded (e.g. offline CI) the test is skipped with a message
    /// instead of silently passing.
    #[tokio::test]
    async fn test_constrained_vs_unconstrained_generation() {
        let tokenizer = match TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
            Ok(tokenizer) => tokenizer,
            Err(_) => {
                eprintln!(
                    "skipping test_constrained_vs_unconstrained_generation: \
                     could not load the bert-base-uncased tokenizer (offline?)"
                );
                return;
            }
        };
        let wrapper = TokenizerWrapper::new(tokenizer);

        let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
            .expect("Should create vocabulary");
        let processor = SchemaProcessor::new();

        // A simple JSON schema for a yes/no response.
        let schema = r#"{
            "type": "object",
            "properties": {
                "answer": {
                    "type": "string",
                    "enum": ["yes", "no"]
                }
            },
            "required": ["answer"]
        }"#;

        // Compile the schema into a constraint index over this vocabulary.
        let index = processor.process_schema(schema, &vocabulary)
            .expect("Should process schema");

        let mut config_with_constraint = GenerationConfig::default();
        config_with_constraint.constraint = Some(index.clone());
        config_with_constraint.max_length = 50;

        let config_without_constraint = GenerationConfig::default();

        let mut gen_constrained = TextGeneration::new(&config_with_constraint);
        let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);

        // 102 is BERT's [SEP] token, used here as the EOS id.
        gen_constrained.set_eos_token_id(102);
        gen_unconstrained.set_eos_token_id(102);

        // The constraint is stored privately inside `TextGeneration`, so it can
        // only be verified through its effect on actual generation.
    }

    /// A `TextGeneration` can be constructed from a default config.
    #[test]
    fn test_constraint_configuration() {
        let config = GenerationConfig::default();
        let _text_gen = TextGeneration::new(&config);

        // Constraints are private implementation details; successful
        // construction is all that can be checked here.
    }

    /// `apply_repetition_penalty` lowers the positive logits of tokens in the
    /// context and leaves every other logit untouched.
    #[test]
    fn test_repetition_penalty() {
        use candle_core::{Tensor, Device};

        let device = Device::Cpu;
        let vocab_size = 10;

        // Logits with a mix of positive and negative values.
        let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
        let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

        let mut config = GenerationConfig::default();
        config.seed = 42;
        config.temperature = 1.0;
        let mut text_gen = TextGeneration::new(&config);
        text_gen.push_token(0); // logit 1.0
        text_gen.push_token(2); // logit 2.0
        text_gen.push_token(5); // logit 3.0
        let context = [0usize, 2, 5];

        text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();

        let penalized = logits.to_vec1::<f32>().unwrap();
        assert_eq!(penalized.len(), logits_vec.len(), "Penalty must not change the vocabulary size");

        // Tokens in context (all with positive logits) must be pushed down.
        for &i in &context {
            assert!(penalized[i] < logits_vec[i], "Positive logit of context token {} should be reduced", i);
        }

        // Every token outside the context must be left exactly as it was.
        for (i, (after, before)) in penalized.iter().zip(&logits_vec).enumerate() {
            if !context.contains(&i) {
                assert_eq!(after, before, "Token {} was not in context and should be unchanged", i);
            }
        }
    }

    /// Max-length, EOS and stop-sequence stop conditions.
    #[test]
    fn test_stop_conditions() {
        let mut config = GenerationConfig::default();
        config.seed = 42;
        config.temperature = 1.0;
        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(50256); // Common EOS token

        // Max length: 10 tokens are in the buffer.
        for i in 0..10 {
            text_gen.push_token(i);
        }
        assert!(text_gen.should_stop(100, 10), "Should stop at max length");
        assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");

        // EOS token.
        assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
        assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");

        // Stop sequences.
        let stop_seqs = vec!["STOP".to_string(), "END".to_string()];
        assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
        assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
        assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
    }

    /// `sample_next_token` applies the repetition penalty stored from the config.
    #[test]
    fn test_sample_next_token_uses_repetition_penalty() {
        use candle_core::{Tensor, Device};

        let device = Device::Cpu;
        let vocab_size = 10;

        // Logit of token i is i + 1, so token 9 is the unpenalized favourite.
        let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

        // 1: penalty configured on the generator, with tokens 2, 5 and 9 in history.
        let mut config_with_penalty = GenerationConfig::default();
        config_with_penalty.seed = 42;
        config_with_penalty.temperature = 0.1;
        config_with_penalty.repetition_penalty = 1.5;
        config_with_penalty.repetition_penalty_last_n = 10;
        let mut text_gen = TextGeneration::new(&config_with_penalty);
        text_gen.push_token(2); // logit 3.0
        text_gen.push_token(5); // logit 6.0
        text_gen.push_token(9); // logit 10.0

        let token_with_penalty = text_gen.sample_next_token(&logits).unwrap();
        // Penalized, token 9 falls below the unpenalized tokens 6, 7 and 8; at
        // temperature 0.1 its sampling probability is negligible.
        assert_ne!(token_with_penalty, 9, "Penalized token 9 should not be sampled");

        // 2: same history, neutral penalty.
        let mut config_no_penalty = GenerationConfig::default();
        config_no_penalty.seed = 42;
        config_no_penalty.temperature = 0.1;
        config_no_penalty.repetition_penalty = 1.0; // No penalty
        let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
        text_gen_no_penalty.push_token(2);
        text_gen_no_penalty.push_token(5);
        text_gen_no_penalty.push_token(9);

        let token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();
        // Without a penalty the highest logit dominates at temperature 0.1.
        assert_eq!(token_without_penalty, 9, "Unpenalized sampling should pick the top token");

        // 3: check the penalized logits directly.
        let mut config_verify = GenerationConfig::default();
        config_verify.seed = 42;
        config_verify.temperature = 0.1;
        config_verify.repetition_penalty = 2.0;
        config_verify.repetition_penalty_last_n = 10;
        let mut text_gen_verify = TextGeneration::new(&config_verify);
        text_gen_verify.push_token(9); // Highest logit token

        let mut logits_for_penalty = logits.clone();
        text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();

        let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
        assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
        assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
    }

    /// A `TextGeneration` built from a fully populated config starts empty.
    #[test]
    fn test_text_generation_from_config_parameters() {
        let mut config = GenerationConfig::default();
        config.seed = 12345;
        config.temperature = 0.5;
        config.top_p = Some(0.9);
        config.top_k = Some(40); // Currently unused but should be accepted
        config.repetition_penalty = 1.2;
        config.repetition_penalty_last_n = 50;

        // Private fields cannot be inspected, so check observable state only.
        let text_gen = TextGeneration::new(&config);
        assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");

        // A default (unconstrained) config; a real constraint needs a tokenizer,
        // which is covered by the async test above.
        let default_config = GenerationConfig::default();
        let text_gen_default = TextGeneration::new(&default_config);
        assert!(text_gen_default.get_tokens().is_empty(), "Should start with no tokens");
    }

    /// Repetition penalty strength controls how diverse greedy-ish sampling is.
    #[test]
    fn test_generation_with_different_penalties() {
        use candle_core::{Tensor, Device};

        let device = Device::Cpu;
        let vocab_size = 50;

        // Clear ranking: token 10 > token 20 > token 30 > everything else.
        let mut logits_vec = vec![0.0f32; vocab_size];
        logits_vec[10] = 10.0;
        logits_vec[20] = 8.0;
        logits_vec[30] = 6.0;
        // The simulated model output is identical at every step, so build it once.
        let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap();

        // (penalty, last_n)
        let configs = [
            (1.0, 64), // No penalty (1.0 = neutral)
            (1.5, 64), // Moderate penalty
            (2.0, 64), // Strong penalty
            (1.2, 10), // Penalty with limited range
        ];

        for (penalty, last_n) in configs {
            let mut config = GenerationConfig::default();
            config.seed = 42; // Fixed seed for reproducibility
            config.temperature = 0.1; // Low temperature for near-deterministic behavior
            config.repetition_penalty = penalty;
            config.repetition_penalty_last_n = last_n;

            let mut text_gen = TextGeneration::new(&config);

            let mut generated = Vec::new();
            for _ in 0..5 {
                let token = text_gen.sample_next_token(&logits).unwrap();
                assert!(token < vocab_size as u32, "Token should be within vocabulary");
                generated.push(token);
            }

            if penalty <= 1.0 {
                // Neutral penalty: the top token wins every step.
                assert!(
                    generated.iter().all(|&t| t == 10),
                    "Without penalty the top-logit token should be chosen every step"
                );
            }

            // Higher penalties should yield more diversity (less repetition).
            let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
            if penalty > 1.5 {
                assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
            }
        }
    }

    /// End-to-end sampling loop that must be terminated by `should_stop`.
    #[test]
    fn test_sample_next_token_integration() {
        use candle_core::{Tensor, Device};

        let device = Device::Cpu;

        let mut config = GenerationConfig::default();
        config.seed = 999;
        config.temperature = 0.7;
        config.max_length = 10;
        config.repetition_penalty = 1.3;
        config.repetition_penalty_last_n = 5;

        let mut text_gen = TextGeneration::new(&config);
        // EOS lies outside the 100-token vocabulary below, so only the
        // max-length condition can end the loop.
        text_gen.set_eos_token_id(50256);

        let vocab_size = 100;
        let mut all_tokens = Vec::new();
        let mut stopped = false;

        // Run for more steps than max_length so the length limit is actually exercised.
        for step in 0..config.max_length * 2 {
            // Make different tokens attractive at different steps.
            let mut logits_vec = vec![0.0f32; vocab_size];
            let preferred_token = (step * 13) % vocab_size;
            logits_vec[preferred_token] = 5.0;
            logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
            logits_vec[(preferred_token + 20) % vocab_size] = 3.0;

            let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap();

            let token = text_gen.sample_next_token(&logits).unwrap();
            all_tokens.push(token);

            if text_gen.should_stop(token, config.max_length) {
                stopped = true;
                break;
            }
        }

        assert!(stopped, "should_stop should end generation before the loop bound");
        assert!(!all_tokens.is_empty(), "Should generate some tokens");
        assert!(all_tokens.len() <= config.max_length, "Should respect max length");

        // sample_next_token records every sampled token internally.
        assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
    }
}
|
data/ext/candle/src/llm/gemma.rs
CHANGED
@@ -16,6 +16,10 @@ pub struct Gemma {
|
|
16
16
|
}
|
17
17
|
|
18
18
|
impl Gemma {
|
19
|
+
pub fn eos_token_id(&self) -> u32 {
|
20
|
+
self.eos_token_id
|
21
|
+
}
|
22
|
+
|
19
23
|
/// Clear the KV cache between generations
|
20
24
|
pub fn clear_kv_cache(&mut self) {
|
21
25
|
self.model.clear_kv_cache();
|
@@ -49,6 +53,9 @@ impl Gemma {
|
|
49
53
|
vec![single_file]
|
50
54
|
} else {
|
51
55
|
// Try to find sharded model files
|
56
|
+
// NOTE: This uses a brute-force approach, trying common shard counts.
|
57
|
+
// A better approach would be to read model.safetensors.index.json which
|
58
|
+
// contains the exact file list, but this works for most models (≤12 shards).
|
52
59
|
let mut sharded_files = Vec::new();
|
53
60
|
let mut index = 1;
|
54
61
|
loop {
|
@@ -139,7 +146,7 @@ impl Gemma {
|
|
139
146
|
config: &GenerationConfig,
|
140
147
|
mut callback: Option<impl FnMut(&str)>,
|
141
148
|
) -> CandleResult<Vec<u32>> {
|
142
|
-
let mut text_gen = TextGeneration::
|
149
|
+
let mut text_gen = TextGeneration::new(config);
|
143
150
|
text_gen.set_eos_token_id(self.eos_token_id);
|
144
151
|
text_gen.set_tokens(prompt_tokens.clone());
|
145
152
|
|
@@ -165,10 +172,7 @@ impl Gemma {
|
|
165
172
|
|
166
173
|
let logits = logits.to_dtype(DType::F32)?;
|
167
174
|
|
168
|
-
let next_token = text_gen.sample_next_token(
|
169
|
-
&logits,
|
170
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
171
|
-
)?;
|
175
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
172
176
|
|
173
177
|
all_tokens.push(next_token);
|
174
178
|
|
@@ -190,6 +194,18 @@ impl Gemma {
|
|
190
194
|
break;
|
191
195
|
}
|
192
196
|
|
197
|
+
// Check if constraint is satisfied (early stopping)
|
198
|
+
if config.stop_on_constraint_satisfaction {
|
199
|
+
let satisfied = if config.stop_on_match {
|
200
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
201
|
+
} else {
|
202
|
+
text_gen.is_constraint_satisfied()
|
203
|
+
};
|
204
|
+
if satisfied {
|
205
|
+
break;
|
206
|
+
}
|
207
|
+
}
|
208
|
+
|
193
209
|
// Check stop sequences
|
194
210
|
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
195
211
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
@@ -1,4 +1,6 @@
|
|
1
1
|
use std::time::{SystemTime, UNIX_EPOCH};
|
2
|
+
use std::sync::Arc;
|
3
|
+
use crate::structured::Index;
|
2
4
|
|
3
5
|
/// Configuration for text generation
|
4
6
|
#[derive(Debug, Clone)]
|
@@ -23,6 +25,12 @@ pub struct GenerationConfig {
|
|
23
25
|
pub include_prompt: bool,
|
24
26
|
/// Whether to show raw tokens during generation (for debugging)
|
25
27
|
pub debug_tokens: bool,
|
28
|
+
/// Optional constraint index for structured generation
|
29
|
+
pub constraint: Option<Arc<Index>>,
|
30
|
+
/// Stop immediately when constraint is satisfied
|
31
|
+
pub stop_on_constraint_satisfaction: bool,
|
32
|
+
/// Whether to stop immediately when pattern is matched (vs allowing continuation)
|
33
|
+
pub stop_on_match: bool,
|
26
34
|
}
|
27
35
|
|
28
36
|
/// Generate a random seed based on current time
|
@@ -46,6 +54,9 @@ impl Default for GenerationConfig {
|
|
46
54
|
stop_sequences: vec![],
|
47
55
|
include_prompt: false,
|
48
56
|
debug_tokens: false,
|
57
|
+
constraint: None,
|
58
|
+
stop_on_constraint_satisfaction: true,
|
59
|
+
stop_on_match: true,
|
49
60
|
}
|
50
61
|
}
|
51
62
|
}
|
data/ext/candle/src/llm/llama.rs
CHANGED
@@ -18,6 +18,10 @@ pub struct Llama {
|
|
18
18
|
}
|
19
19
|
|
20
20
|
impl Llama {
|
21
|
+
pub fn eos_token_id(&self) -> u32 {
|
22
|
+
self.eos_token_id
|
23
|
+
}
|
24
|
+
|
21
25
|
/// Clear the KV cache between generations
|
22
26
|
pub fn clear_kv_cache(&mut self) {
|
23
27
|
// Since Cache doesn't expose a reset method and kvs is private,
|
@@ -58,6 +62,9 @@ impl Llama {
|
|
58
62
|
vec![consolidated_file]
|
59
63
|
} else {
|
60
64
|
// Try to find sharded model files
|
65
|
+
// NOTE: This uses a brute-force approach, trying common shard counts.
|
66
|
+
// A better approach would be to read model.safetensors.index.json which
|
67
|
+
// contains the exact file list, but this works for most models (≤30 shards).
|
61
68
|
let mut sharded_files = Vec::new();
|
62
69
|
let mut index = 1;
|
63
70
|
loop {
|
@@ -175,7 +182,7 @@ impl Llama {
|
|
175
182
|
config: &GenerationConfig,
|
176
183
|
mut callback: Option<impl FnMut(&str)>,
|
177
184
|
) -> CandleResult<Vec<u32>> {
|
178
|
-
let mut text_gen = TextGeneration::
|
185
|
+
let mut text_gen = TextGeneration::new(config);
|
179
186
|
text_gen.set_eos_token_id(self.eos_token_id);
|
180
187
|
text_gen.set_tokens(prompt_tokens.clone());
|
181
188
|
|
@@ -201,10 +208,7 @@ impl Llama {
|
|
201
208
|
|
202
209
|
let logits = logits.to_dtype(DType::F32)?;
|
203
210
|
|
204
|
-
let next_token = text_gen.sample_next_token(
|
205
|
-
&logits,
|
206
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
207
|
-
)?;
|
211
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
208
212
|
|
209
213
|
all_tokens.push(next_token);
|
210
214
|
|
@@ -226,6 +230,18 @@ impl Llama {
|
|
226
230
|
break;
|
227
231
|
}
|
228
232
|
|
233
|
+
// Check if constraint is satisfied (early stopping)
|
234
|
+
if config.stop_on_constraint_satisfaction {
|
235
|
+
let satisfied = if config.stop_on_match {
|
236
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
237
|
+
} else {
|
238
|
+
text_gen.is_constraint_satisfied()
|
239
|
+
};
|
240
|
+
if satisfied {
|
241
|
+
break;
|
242
|
+
}
|
243
|
+
}
|
244
|
+
|
229
245
|
// Check stop sequences
|
230
246
|
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
231
247
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
@@ -16,6 +16,10 @@ pub struct Mistral {
|
|
16
16
|
}
|
17
17
|
|
18
18
|
impl Mistral {
|
19
|
+
pub fn eos_token_id(&self) -> u32 {
|
20
|
+
self.eos_token_id
|
21
|
+
}
|
22
|
+
|
19
23
|
/// Clear the KV cache between generations
|
20
24
|
pub fn clear_kv_cache(&mut self) {
|
21
25
|
self.model.clear_kv_cache();
|
@@ -52,6 +56,9 @@ impl Mistral {
|
|
52
56
|
vec![consolidated_file]
|
53
57
|
} else {
|
54
58
|
// Try to find sharded model files
|
59
|
+
// NOTE: This uses a brute-force approach, trying common shard counts.
|
60
|
+
// A better approach would be to read model.safetensors.index.json which
|
61
|
+
// contains the exact file list, but this works for most models (≤8 shards).
|
55
62
|
let mut sharded_files = Vec::new();
|
56
63
|
let mut index = 1;
|
57
64
|
loop {
|
@@ -144,7 +151,7 @@ impl Mistral {
|
|
144
151
|
config: &GenerationConfig,
|
145
152
|
mut callback: Option<impl FnMut(&str)>,
|
146
153
|
) -> CandleResult<Vec<u32>> {
|
147
|
-
let mut text_gen = TextGeneration::
|
154
|
+
let mut text_gen = TextGeneration::new(config);
|
148
155
|
text_gen.set_eos_token_id(self.eos_token_id);
|
149
156
|
text_gen.set_tokens(prompt_tokens.clone());
|
150
157
|
|
@@ -176,10 +183,7 @@ impl Mistral {
|
|
176
183
|
// Convert to F32 for sampling if needed
|
177
184
|
let logits = logits.to_dtype(DType::F32)?;
|
178
185
|
|
179
|
-
let next_token = text_gen.sample_next_token(
|
180
|
-
&logits,
|
181
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
182
|
-
)?;
|
186
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
183
187
|
|
184
188
|
all_tokens.push(next_token);
|
185
189
|
|
@@ -201,6 +205,18 @@ impl Mistral {
|
|
201
205
|
break;
|
202
206
|
}
|
203
207
|
|
208
|
+
// Check if constraint is satisfied (early stopping)
|
209
|
+
if config.stop_on_constraint_satisfaction {
|
210
|
+
let satisfied = if config.stop_on_match {
|
211
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
212
|
+
} else {
|
213
|
+
text_gen.is_constraint_satisfied()
|
214
|
+
};
|
215
|
+
if satisfied {
|
216
|
+
break;
|
217
|
+
}
|
218
|
+
}
|
219
|
+
|
204
220
|
// Check stop sequences
|
205
221
|
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
206
222
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
data/ext/candle/src/llm/mod.rs
CHANGED
@@ -3,6 +3,8 @@ use candle_core::{Device, Result as CandleResult};
|
|
3
3
|
pub mod mistral;
|
4
4
|
pub mod llama;
|
5
5
|
pub mod gemma;
|
6
|
+
pub mod qwen;
|
7
|
+
pub mod phi;
|
6
8
|
pub mod generation_config;
|
7
9
|
pub mod text_generation;
|
8
10
|
pub mod quantized_gguf;
|
@@ -12,6 +14,9 @@ pub use text_generation::TextGeneration;
|
|
12
14
|
pub use quantized_gguf::QuantizedGGUF;
|
13
15
|
pub use crate::tokenizer::TokenizerWrapper;
|
14
16
|
|
17
|
+
#[cfg(test)]
|
18
|
+
mod constrained_generation_test;
|
19
|
+
|
15
20
|
/// Trait for text generation models
|
16
21
|
pub trait TextGenerator: Send + Sync {
|
17
22
|
/// Generate text from a prompt
|