red-candle 1.1.0 → 1.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +65 -1
- data/Rakefile +40 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +199 -6
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +6 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/phi.rs +21 -5
- data/ext/candle/src/llm/quantized_gguf.rs +35 -6
- data/ext/candle/src/llm/qwen.rs +21 -5
- data/ext/candle/src/llm/text_generation.rs +121 -28
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +2 -1
- data/ext/candle/src/ruby/dtype.rs +1 -0
- data/ext/candle/src/ruby/errors.rs +1 -0
- data/ext/candle/src/ruby/llm.rs +81 -55
- data/ext/candle/src/ruby/tensor.rs +2 -1
- data/ext/candle/src/tokenizer/mod.rs +2 -1
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/llm.rb +129 -34
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: cd566594df4f0d3ec8ecd592b1d71610ef0ba2091cd91ea53549405a8c5c9b18
+  data.tar.gz: cfaab403935e927371fbd2f605ea60da3f44effd64950bb132651a5e6437e30c
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: b18f384766db5a3de2f794764a8a47d3d69261ae4bf5c93992df7228fe9a5817eb1862b516a9ee3e63b5d349936f46534d50f1378b7932c7ac482550270a8bea
+  data.tar.gz: ce74834cde884bf03e3d163f6d081e87f95db0fba8db24300b8c6aa5b0f1f0162542c836bebb414010534778e394ca36faaec807ff1cc652902c3b318503def5
data/README.md
CHANGED

@@ -58,7 +58,8 @@ end
 - **EmbeddingModel**: Generate embeddings for text
 - **Reranker**: Rerank documents based on relevance
 - **NER**: Named Entity Recognition directly from Ruby
-- **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma)
+- **LLM**: Chat with Large Language Models (e.g., Llama, Mistral, Gemma, Qwen, Phi)
+- **Structured Generation**: Generate JSON from a schema or match a regular expression

 ## Model Storage

@@ -887,6 +888,69 @@ bundle exec rake compile

 Pull requests are welcome.

+## Testing
+
+Red Candle has comprehensive tests at both the Ruby and Rust levels:
+
+### Ruby Tests
+```bash
+# Run all Ruby tests
+bundle exec rake test
+
+# Run specific test suites
+bundle exec rake test:device      # Device compatibility tests
+bundle exec rake test:benchmark   # Benchmark tests
+bundle exec rake test:llm:mistral # Model-specific tests
+```
+
+### Rust Tests
+```bash
+# Run Rust unit and integration tests
+cd ext/candle && cargo test
+
+# Or use the Rake task
+bundle exec rake rust:test
+```
+
+The Rust tests include:
+- Unit tests within source files (using `#[cfg(test)]` modules)
+- Integration tests for external dependencies (candle_core operations)
+- Tests for structured generation, tokenization, and text generation
+
+### Code Coverage
+
+#### Rust Code Coverage
+Red Candle uses `cargo-llvm-cov` for Rust code coverage analysis:
+
+```bash
+# Generate HTML coverage report (opens in target/llvm-cov/html/index.html)
+bundle exec rake rust:coverage:html
+
+# Show coverage summary in terminal
+bundle exec rake rust:coverage:summary
+
+# Generate detailed coverage report
+bundle exec rake rust:coverage:report
+
+# Generate LCOV format for CI integration
+bundle exec rake rust:coverage:lcov
+
+# Clean coverage data
+bundle exec rake rust:coverage:clean
+```
+
+**Note**: Overall Rust coverage shows ~17% because most code consists of Ruby FFI bindings that are tested through Ruby tests. The testable Rust components have high coverage:
+- Constrained generation: 99.59%
+- Schema processing: 90.99%
+- Integration tests: 97.12%
+
+#### Ruby Code Coverage
+Ruby test coverage is generated automatically when running tests:
+```bash
+bundle exec rake test
+# Coverage report generated in coverage/index.html
+```
+
 ## Release

 1. Update version number in `lib/candle/version.rb` and commit.
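The `#[cfg(test)]` convention the Rust test list refers to is the standard in-file unit-test layout; a minimal illustrative sketch (the `add` function and its test are placeholders, not code from the crate):

```rust
// Regular library code lives at the top of the file.
pub fn add(a: u32, b: u32) -> u32 {
    a + b
}

// The test module is compiled only for `cargo test` builds.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adds_two_numbers() {
        assert_eq!(add(2, 2), 4);
    }
}
```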
data/Rakefile
CHANGED

@@ -134,3 +134,43 @@ namespace :doc do
 end

 task doc: "doc:default"
+
+namespace :rust do
+  desc "Run Rust tests with code coverage"
+  namespace :coverage do
+    desc "Generate HTML coverage report"
+    task :html do
+      sh "cd ext/candle && cargo llvm-cov --html"
+      puts "Coverage report generated in target/llvm-cov/html/index.html"
+    end
+
+    desc "Generate coverage report in terminal"
+    task :report do
+      sh "cd ext/candle && cargo llvm-cov"
+    end
+
+    desc "Show coverage summary"
+    task :summary do
+      sh "cd ext/candle && cargo llvm-cov --summary-only"
+    end
+
+    desc "Generate lcov format coverage report"
+    task :lcov do
+      sh "cd ext/candle && cargo llvm-cov --lcov --output-path ../../coverage/lcov.info"
+      puts "LCOV report generated in coverage/lcov.info"
+    end
+
+    desc "Clean coverage data"
+    task :clean do
+      sh "cd ext/candle && cargo llvm-cov clean"
+    end
+  end
+
+  desc "Run Rust tests"
+  task :test do
+    sh "cd ext/candle && cargo test"
+  end
+end
+
+desc "Run Rust tests with coverage (alias)"
+task "coverage:rust" => "rust:coverage:html"
data/ext/candle/src/llm/constrained_generation_test.rs
CHANGED

@@ -44,8 +44,8 @@ mod constrained_generation_tests {
         let config_without_constraint = GenerationConfig::default();

         // Create text generation instances
-        let mut gen_constrained = TextGeneration::
-        let mut gen_unconstrained = TextGeneration::
+        let mut gen_constrained = TextGeneration::new(&config_with_constraint);
+        let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);

         // Set EOS token
         gen_constrained.set_eos_token_id(102); // BERT's [SEP] token

@@ -60,9 +60,9 @@ mod constrained_generation_tests {
     fn test_constraint_configuration() {
         // Test that we can create a TextGeneration with constraints
         let config = GenerationConfig::default();
-        let _text_gen = TextGeneration::
+        let _text_gen = TextGeneration::new(&config);

-        // Test that we can create a TextGeneration
+        // Test that we can create a TextGeneration with config
         // Constraints are private implementation details
     }

@@ -78,7 +78,10 @@ mod constrained_generation_tests {
         let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

         // Create text generation with some tokens
-        let mut
+        let mut config = GenerationConfig::default();
+        config.seed = 42;
+        config.temperature = 1.0;
+        let mut text_gen = TextGeneration::new(&config);
         text_gen.push_token(0); // Token that had logit 1.0
         text_gen.push_token(2); // Token that had logit 2.0
         text_gen.push_token(5); // Token that had logit 3.0

@@ -100,7 +103,10 @@ mod constrained_generation_tests {

     #[test]
     fn test_stop_conditions() {
-        let mut
+        let mut config = GenerationConfig::default();
+        config.seed = 42;
+        config.temperature = 1.0;
+        let mut text_gen = TextGeneration::new(&config);
         text_gen.set_eos_token_id(50256); // Common EOS token

         // Test max length stop

@@ -120,4 +126,191 @@ mod constrained_generation_tests {
         assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
         assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
     }
+
+    #[test]
+    fn test_sample_next_token_uses_repetition_penalty() {
+        use candle_core::{Tensor, Device};
+
+        let device = Device::Cpu;
+        let vocab_size = 10;
+
+        // Create initial logits
+        let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+        let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+        // Test 1: Create TextGeneration with repetition penalty and add some tokens to history
+        let mut config_with_penalty = GenerationConfig::default();
+        config_with_penalty.seed = 42;
+        config_with_penalty.temperature = 0.1;
+        config_with_penalty.repetition_penalty = 1.5;
+        config_with_penalty.repetition_penalty_last_n = 10;
+        let mut text_gen = TextGeneration::new(&config_with_penalty);
+        text_gen.push_token(2); // Token with logit 3.0
+        text_gen.push_token(5); // Token with logit 6.0
+        text_gen.push_token(9); // Token with logit 10.0
+
+        // Sample with repetition penalty (now uses stored penalty)
+        let _token_with_penalty = text_gen.sample_next_token(&logits).unwrap();
+
+        // Test 2: Same setup but without penalty
+        let mut config_no_penalty = GenerationConfig::default();
+        config_no_penalty.seed = 42;
+        config_no_penalty.temperature = 0.1;
+        config_no_penalty.repetition_penalty = 1.0; // No penalty
+        let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
+        text_gen_no_penalty.push_token(2);
+        text_gen_no_penalty.push_token(5);
+        text_gen_no_penalty.push_token(9);
+
+        let _token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();
+
+        // With low temperature and penalty, should avoid previously used high-logit tokens
+        // Without penalty, should prefer high-logit tokens
+        // This is probabilistic, but with temp=0.1 it should be fairly deterministic
+
+        // Test 3: Verify penalty is applied correctly by checking modified logits
+        let mut config_verify = GenerationConfig::default();
+        config_verify.seed = 42;
+        config_verify.temperature = 0.1;
+        config_verify.repetition_penalty = 2.0;
+        config_verify.repetition_penalty_last_n = 10;
+        let mut text_gen_verify = TextGeneration::new(&config_verify);
+        text_gen_verify.push_token(9); // Highest logit token
+
+        // Clone logits to check modification
+        let mut logits_for_penalty = logits.clone();
+        text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();
+
+        let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
+        assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
+        assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
+    }
+
+    #[test]
+    fn test_text_generation_from_config_parameters() {
+
+        // Create a config with specific values
+        let mut config = GenerationConfig::default();
+        config.seed = 12345;
+        config.temperature = 0.5;
+        config.top_p = Some(0.9);
+        config.top_k = Some(40); // Currently unused but should be accepted
+        config.repetition_penalty = 1.2;
+        config.repetition_penalty_last_n = 50;
+
+        // Create TextGeneration from config
+        let text_gen = TextGeneration::new(&config);
+
+        // We can't directly inspect private fields, but we can test behavior
+        // Test that it creates successfully (no panic)
+        assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");
+
+        // Test with constraint
+        let config_with_constraint = GenerationConfig::default();
+        // In real usage, this would be a real constraint
+        // For testing, we just verify it accepts the config
+        let text_gen_constrained = TextGeneration::new(&config_with_constraint);
+        assert!(text_gen_constrained.get_tokens().is_empty(), "Should start with no tokens");
+    }
+
+    #[test]
+    fn test_generation_with_different_penalties() {
+        use candle_core::{Tensor, Device, DType};
+
+        let device = Device::Cpu;
+        let vocab_size = 50;
+
+        // Create logits with clear preferences
+        let mut logits_vec = vec![0.0; vocab_size];
+        logits_vec[10] = 10.0; // Strong preference
+        logits_vec[20] = 8.0;  // Second preference
+        logits_vec[30] = 6.0;  // Third preference
+
+        // Test different penalty configurations
+        let configs = vec![
+            (1.0, 64), // No penalty (1.0 = neutral)
+            (1.5, 64), // Moderate penalty
+            (2.0, 64), // Strong penalty
+            (1.2, 10), // Penalty with limited range
+        ];
+
+        for (penalty, last_n) in configs {
+            let mut config = GenerationConfig::default();
+            config.seed = 42; // Fixed seed for reproducibility
+            config.temperature = 0.1; // Low temperature for more deterministic behavior
+            config.repetition_penalty = penalty;
+            config.repetition_penalty_last_n = last_n;
+
+            let mut text_gen = TextGeneration::new(&config);
+
+            // Generate a sequence of tokens
+            let mut generated = Vec::new();
+            for _i in 0..5 {
+                let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+                let token = text_gen.sample_next_token(&logits).unwrap();
+
+                generated.push(token);
+
+                // Verify the token is in valid range
+                assert!(token < vocab_size as u32, "Token should be within vocabulary");
+            }
+
+            // With higher penalties, we should see more diversity (less repetition)
+            let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
+            if penalty > 1.5 {
+                assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
+            }
+        }
+    }
+
+    #[test]
+    fn test_sample_next_token_integration() {
+        use candle_core::{Tensor, Device, DType};
+
+        let device = Device::Cpu;
+
+        // Test the full integration of sample_next_token
+        let mut config = GenerationConfig::default();
+        config.seed = 999;
+        config.temperature = 0.7;
+        config.max_length = 10;
+        config.repetition_penalty = 1.3;
+        config.repetition_penalty_last_n = 5;
+
+        let mut text_gen = TextGeneration::new(&config);
+        text_gen.set_eos_token_id(50256);
+
+        // Simulate a generation loop
+        let vocab_size = 100;
+        let mut all_tokens = Vec::new();
+
+        for step in 0..8 {
+            // Create varying logits to simulate model output
+            let mut logits_vec = vec![0.0; vocab_size];
+            // Make different tokens attractive at different steps
+            let preferred_token = (step * 13) % vocab_size;
+            logits_vec[preferred_token] = 5.0;
+            logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
+            logits_vec[(preferred_token + 20) % vocab_size] = 3.0;
+
+            let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap().to_dtype(DType::F32).unwrap();
+
+            let token = text_gen.sample_next_token(&logits).unwrap();
+
+            all_tokens.push(token);
+
+            // Check if we should stop
+            if text_gen.should_stop(token, config.max_length) {
+                break;
+            }
+        }
+
+        // Verify generation worked
+        assert!(!all_tokens.is_empty(), "Should generate some tokens");
+        assert!(all_tokens.len() <= config.max_length, "Should respect max length");
+
+        // Verify tokens are being tracked
+        assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
+    }
 }
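For readers unfamiliar with the penalty these new tests exercise: below is a minimal stand-alone sketch of the usual CTRL-style repetition penalty. It is illustrative only, not the crate's actual `apply_repetition_penalty`.

```rust
/// Shrink the scores of recently generated tokens so sampling is less likely
/// to repeat them: positive logits are divided by the penalty, negative ones
/// multiplied, pushing both toward "less probable".
fn penalize(logits: &mut [f32], recent_tokens: &[u32], penalty: f32) {
    for &tok in recent_tokens {
        let l = &mut logits[tok as usize];
        *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
    }
}
```

With `penalty = 2.0`, a previously seen token with logit 10.0 drops to 5.0, which is consistent with the `penalized[9] < logits_vec[9]` assertion in the test above.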
data/ext/candle/src/llm/gemma.rs
CHANGED

@@ -16,6 +16,10 @@ pub struct Gemma {
 }

 impl Gemma {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         self.model.clear_kv_cache();

@@ -49,6 +53,9 @@ impl Gemma {
             vec![single_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤12 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {

@@ -139,7 +146,7 @@ impl Gemma {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());

@@ -165,10 +172,7 @@ impl Gemma {

         let logits = logits.to_dtype(DType::F32)?;

-        let next_token = text_gen.sample_next_token(
-            &logits,
-            Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-        )?;
+        let next_token = text_gen.sample_next_token(&logits)?;

         all_tokens.push(next_token);

@@ -190,6 +194,18 @@ impl Gemma {
             break;
         }

+        // Check if constraint is satisfied (early stopping)
+        if config.stop_on_constraint_satisfaction {
+            let satisfied = if config.stop_on_match {
+                text_gen.is_constraint_satisfied_stop_on_match()
+            } else {
+                text_gen.is_constraint_satisfied()
+            };
+            if satisfied {
+                break;
+            }
+        }
+
         // Check stop sequences
         let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
         if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
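The NOTE added above points at an index-file alternative to shard guessing. A hedged sketch of what reading `model.safetensors.index.json` could look like — a hypothetical helper using `serde_json`, not part of this diff (the index format maps tensor names to the shard file holding them, so deduplicating the values yields the complete file list):

```rust
use std::collections::BTreeSet;

/// Collect the exact shard file names from a model.safetensors.index.json payload.
fn shard_files_from_index(index_json: &str) -> Result<Vec<String>, serde_json::Error> {
    let index: serde_json::Value = serde_json::from_str(index_json)?;
    let mut files = BTreeSet::new(); // sorted and deduplicated
    if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
        for shard in weight_map.values().filter_map(|v| v.as_str()) {
            files.insert(shard.to_string());
        }
    }
    Ok(files.into_iter().collect())
}
```

This would remove the shard-count ceiling entirely, at the cost of one extra file download per repo.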
data/ext/candle/src/llm/generation_config.rs
CHANGED

@@ -27,6 +27,10 @@ pub struct GenerationConfig {
     pub debug_tokens: bool,
     /// Optional constraint index for structured generation
     pub constraint: Option<Arc<Index>>,
+    /// Stop immediately when constraint is satisfied
+    pub stop_on_constraint_satisfaction: bool,
+    /// Whether to stop immediately when pattern is matched (vs allowing continuation)
+    pub stop_on_match: bool,
 }

 /// Generate a random seed based on current time

@@ -51,6 +55,8 @@ impl Default for GenerationConfig {
             include_prompt: false,
             debug_tokens: false,
             constraint: None,
+            stop_on_constraint_satisfaction: true,
+            stop_on_match: true,
         }
     }
 }
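Since both new flags default to `true`, callers wanting the old keep-generating behaviour must opt out explicitly. A minimal sketch using struct-update syntax — the field names come from this diff, everything else is left at its default:

```rust
let config = GenerationConfig {
    // Keep sampling even after the constraint's pattern is fully matched.
    stop_on_constraint_satisfaction: false,
    stop_on_match: false,
    ..GenerationConfig::default()
};
```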
data/ext/candle/src/llm/llama.rs
CHANGED

@@ -18,6 +18,10 @@ pub struct Llama {
 }

 impl Llama {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         // Since Cache doesn't expose a reset method and kvs is private,

@@ -58,6 +62,9 @@ impl Llama {
             vec![consolidated_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤30 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {

@@ -175,7 +182,7 @@ impl Llama {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());

@@ -201,10 +208,7 @@ impl Llama {

         let logits = logits.to_dtype(DType::F32)?;

-        let next_token = text_gen.sample_next_token(
-            &logits,
-            Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-        )?;
+        let next_token = text_gen.sample_next_token(&logits)?;

         all_tokens.push(next_token);

@@ -226,6 +230,18 @@ impl Llama {
             break;
         }

+        // Check if constraint is satisfied (early stopping)
+        if config.stop_on_constraint_satisfaction {
+            let satisfied = if config.stop_on_match {
+                text_gen.is_constraint_satisfied_stop_on_match()
+            } else {
+                text_gen.is_constraint_satisfied()
+            };
+            if satisfied {
+                break;
+            }
+        }
+
         // Check stop sequences
         let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
         if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
data/ext/candle/src/llm/mistral.rs
CHANGED

@@ -16,6 +16,10 @@ pub struct Mistral {
 }

 impl Mistral {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Clear the KV cache between generations
     pub fn clear_kv_cache(&mut self) {
         self.model.clear_kv_cache();

@@ -52,6 +56,9 @@ impl Mistral {
             vec![consolidated_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤8 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {

@@ -144,7 +151,7 @@ impl Mistral {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());

@@ -176,10 +183,7 @@ impl Mistral {
         // Convert to F32 for sampling if needed
         let logits = logits.to_dtype(DType::F32)?;

-        let next_token = text_gen.sample_next_token(
-            &logits,
-            Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-        )?;
+        let next_token = text_gen.sample_next_token(&logits)?;

         all_tokens.push(next_token);

@@ -201,6 +205,18 @@ impl Mistral {
             break;
         }

+        // Check if constraint is satisfied (early stopping)
+        if config.stop_on_constraint_satisfaction {
+            let satisfied = if config.stop_on_match {
+                text_gen.is_constraint_satisfied_stop_on_match()
+            } else {
+                text_gen.is_constraint_satisfied()
+            };
+            if satisfied {
+                break;
+            }
+        }
+
         // Check stop sequences
         let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
         if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
data/ext/candle/src/llm/phi.rs
CHANGED

@@ -21,6 +21,10 @@ enum PhiVariant {
 }

 impl Phi {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> &TokenizerWrapper {
         &self.tokenizer

@@ -68,6 +72,9 @@ impl Phi {
             vec![single_file]
         } else {
             // Try to find sharded model files
+            // NOTE: This uses a brute-force approach, trying common shard counts.
+            // A better approach would be to read model.safetensors.index.json which
+            // contains the exact file list, but this works for most models (≤30 shards).
             let mut sharded_files = Vec::new();
             let mut index = 1;
             loop {

@@ -177,7 +184,7 @@ impl Phi {
         config: &GenerationConfig,
         mut callback: Option<impl FnMut(&str)>,
     ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::
+        let mut text_gen = TextGeneration::new(config);
         text_gen.set_eos_token_id(self.eos_token_id);
         text_gen.set_tokens(prompt_tokens.clone());

@@ -206,10 +213,7 @@ impl Phi {

         let logits = logits.to_dtype(DType::F32)?;

-        let next_token = text_gen.sample_next_token(
-            &logits,
-            Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-        )?;
+        let next_token = text_gen.sample_next_token(&logits)?;

         all_tokens.push(next_token);

@@ -229,6 +233,18 @@ impl Phi {
             break;
         }

+        // Check if constraint is satisfied (early stopping)
+        if config.stop_on_constraint_satisfaction {
+            let satisfied = if config.stop_on_match {
+                text_gen.is_constraint_satisfied_stop_on_match()
+            } else {
+                text_gen.is_constraint_satisfied()
+            };
+            if satisfied {
+                break;
+            }
+        }
+
         // Check stop sequences
         let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
         if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {