red-candle 1.8.0.pre3-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,395 @@
1
#[cfg(test)]
mod constrained_generation_tests {
    //! Tests for `TextGeneration` covering constrained (JSON-schema guided) and
    //! unconstrained generation, repetition penalty, stop conditions, sampling and
    //! constraint-satisfaction checks.
    //!
    //! The `#[tokio::test]` cases download the `bert-base-uncased` tokenizer from the
    //! Hugging Face Hub. When that download fails (e.g. offline CI) they print a note
    //! to stderr and return early instead of failing.

    use super::super::*;
    use crate::structured::{SchemaProcessor, VocabularyAdapter};
    use crate::tokenizer::{loader::TokenizerLoader, TokenizerWrapper};

    /// Builds a constrained and an unconstrained `TextGeneration` and checks that
    /// neither reports a satisfied constraint before any token has been generated.
    #[tokio::test]
    async fn test_constrained_vs_unconstrained_generation() {
        // Requires network access to the Hugging Face Hub; skip (not fail) when it is
        // unavailable, but leave a trace so a silently-skipped test is visible.
        let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await else {
            eprintln!(
                "skipping test_constrained_vs_unconstrained_generation: tokenizer download failed"
            );
            return;
        };
        let wrapper = TokenizerWrapper::new(tokenizer);

        let vocabulary =
            VocabularyAdapter::from_tokenizer(&wrapper).expect("Should create vocabulary");
        let processor = SchemaProcessor::new();

        // A simple JSON schema for a yes/no response.
        let schema = r#"{
            "type": "object",
            "properties": {
                "answer": {
                    "type": "string",
                    "enum": ["yes", "no"]
                }
            },
            "required": ["answer"]
        }"#;

        let index = processor
            .process_schema(schema, &vocabulary)
            .expect("Should process schema");

        let mut config_with_constraint = GenerationConfig::default();
        // `index` is not used again, so move it instead of cloning.
        config_with_constraint.constraint = Some(index);
        config_with_constraint.max_length = 50;

        let config_without_constraint = GenerationConfig::default();

        let mut gen_constrained = TextGeneration::new(&config_with_constraint);
        let mut gen_unconstrained = TextGeneration::new(&config_without_constraint);

        gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
        gen_unconstrained.set_eos_token_id(102);

        // Nothing has been generated yet, so the JSON object is incomplete: the
        // constrained generator must not report completion, and a generator without
        // a constraint never does.
        assert!(
            !gen_constrained.is_constraint_satisfied(),
            "Constrained generator should not be satisfied before generating anything"
        );
        assert!(
            !gen_unconstrained.is_constraint_satisfied(),
            "Unconstrained generator should never report a satisfied constraint"
        );
    }

    /// A `TextGeneration` can be built from a default config (no constraint).
    #[test]
    fn test_constraint_configuration() {
        let config = GenerationConfig::default();
        let _text_gen = TextGeneration::new(&config);
        // Constraints are private implementation details; construction succeeding is
        // the whole check here.
    }

    /// Positive logits of tokens in the history are reduced by the repetition penalty;
    /// tokens not in the history keep their logits.
    #[test]
    fn test_repetition_penalty() {
        use candle_core::{Device, Tensor};

        let device = Device::Cpu;
        let vocab_size = 10;

        // Logits with a mix of positive and negative values.
        let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
        let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

        let mut config = GenerationConfig::default();
        config.seed = 42;
        config.temperature = 1.0;
        let mut text_gen = TextGeneration::new(&config);
        text_gen.push_token(0); // Token that had logit 1.0
        text_gen.push_token(2); // Token that had logit 2.0
        text_gen.push_token(5); // Token that had logit 3.0

        text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();

        let penalized = logits.to_vec1::<f32>().unwrap();

        // Tokens in context were penalized.
        assert!(penalized[0] < logits_vec[0], "Positive logit should be reduced");
        assert!(penalized[2] < logits_vec[2], "Positive logit should be reduced");
        assert!(penalized[5] < logits_vec[5], "Positive logit should be reduced");

        // Tokens outside the context are untouched.
        assert_eq!(penalized[1], logits_vec[1], "Unsampled token should be unchanged");
        assert_eq!(penalized[3], logits_vec[3], "Unsampled token should be unchanged");
    }

    /// `should_stop` triggers on max length and on EOS; `check_stop_sequences`
    /// detects any configured stop string.
    #[test]
    fn test_stop_conditions() {
        let mut config = GenerationConfig::default();
        config.seed = 42;
        config.temperature = 1.0;
        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(50256); // Common EOS token

        // Max length stop: 10 tokens pushed.
        for i in 0..10 {
            text_gen.push_token(i);
        }
        assert!(text_gen.should_stop(100, 10), "Should stop at max length");
        assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");

        // EOS token stop.
        assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
        assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");

        // Stop sequences.
        let stop_seqs = vec!["STOP".to_string(), "END".to_string()];
        assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
        assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
        assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
    }

    /// `sample_next_token` applies the repetition penalty stored in the config.
    #[test]
    fn test_sample_next_token_uses_repetition_penalty() {
        use candle_core::{Device, Tensor};

        let device = Device::Cpu;
        let vocab_size = 10;

        // Logits increase with the token id, so token 9 is the unpenalized favourite.
        let logits_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();

        // 1: penalty 1.5 with tokens 2, 5 and 9 in the history.
        let mut config_with_penalty = GenerationConfig::default();
        config_with_penalty.seed = 42;
        config_with_penalty.temperature = 0.1;
        config_with_penalty.repetition_penalty = 1.5;
        config_with_penalty.repetition_penalty_last_n = 10;
        let mut text_gen = TextGeneration::new(&config_with_penalty);
        text_gen.push_token(2); // Token with logit 3.0
        text_gen.push_token(5); // Token with logit 6.0
        text_gen.push_token(9); // Token with logit 10.0

        let token_with_penalty = text_gen.sample_next_token(&logits).unwrap();

        // 2: same history, neutral penalty.
        let mut config_no_penalty = GenerationConfig::default();
        config_no_penalty.seed = 42;
        config_no_penalty.temperature = 0.1;
        config_no_penalty.repetition_penalty = 1.0; // No penalty
        let mut text_gen_no_penalty = TextGeneration::new(&config_no_penalty);
        text_gen_no_penalty.push_token(2);
        text_gen_no_penalty.push_token(5);
        text_gen_no_penalty.push_token(9);

        let token_without_penalty = text_gen_no_penalty.sample_next_token(&logits).unwrap();

        // At temperature 0.1 a logit gap of g becomes a 10*g gap before softmax.
        // With the penalty, token 9 drops to ~6.67 while token 8 stays at 9.0, so token
        // 9 is essentially impossible. Without it, token 9 leads token 8 by 1.0, so it
        // is chosen with probability ~0.99995.
        assert_ne!(token_with_penalty, 9, "Penalized history token should be avoided");
        assert_eq!(token_without_penalty, 9, "Without penalty the top-logit token should win");

        // 3: check the penalty directly on the logits.
        let mut config_verify = GenerationConfig::default();
        config_verify.seed = 42;
        config_verify.temperature = 0.1;
        config_verify.repetition_penalty = 2.0;
        config_verify.repetition_penalty_last_n = 10;
        let mut text_gen_verify = TextGeneration::new(&config_verify);
        text_gen_verify.push_token(9); // Highest logit token

        let mut logits_for_penalty = logits.clone();
        text_gen_verify.apply_repetition_penalty(&mut logits_for_penalty, 2.0, 10).unwrap();

        let penalized = logits_for_penalty.to_vec1::<f32>().unwrap();
        assert!(penalized[9] < logits_vec[9], "Token 9 should be penalized");
        assert_eq!(penalized[0], logits_vec[0], "Token 0 should not be penalized");
    }

    /// `TextGeneration::new` accepts a fully populated config and starts with an
    /// empty token history.
    #[test]
    fn test_text_generation_from_config_parameters() {
        let mut config = GenerationConfig::default();
        config.seed = 12345;
        config.temperature = 0.5;
        config.top_p = Some(0.9);
        config.top_k = Some(40); // Currently unused but should be accepted
        config.repetition_penalty = 1.2;
        config.repetition_penalty_last_n = 50;

        let text_gen = TextGeneration::new(&config);

        // Private fields are not inspectable; check observable state instead.
        assert!(text_gen.get_tokens().is_empty(), "Should start with no tokens");

        // A default config (no constraint attached) also starts empty. Building a real
        // constraint needs a tokenizer; see the `#[tokio::test]` cases for that.
        let default_config = GenerationConfig::default();
        let text_gen_default = TextGeneration::new(&default_config);
        assert!(text_gen_default.get_tokens().is_empty(), "Should start with no tokens");
    }

    /// Samples repeatedly under several penalty settings. Tokens must stay in
    /// range, and a strong penalty must spread choices over several tokens.
    #[test]
    fn test_generation_with_different_penalties() {
        use candle_core::{Device, Tensor};

        let device = Device::Cpu;
        let vocab_size = 50;

        // Logits with clear preferences, built directly as f32.
        let mut logits_vec = vec![0.0f32; vocab_size];
        logits_vec[10] = 10.0; // Strong preference
        logits_vec[20] = 8.0; // Second preference
        logits_vec[30] = 6.0; // Third preference

        // The logits are the same on every step, so build the tensor once.
        let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap();

        let configs = vec![
            (1.0, 64), // No penalty (1.0 = neutral)
            (1.5, 64), // Moderate penalty
            (2.0, 64), // Strong penalty
            (1.2, 10), // Penalty with limited range
        ];

        for (penalty, last_n) in configs {
            let mut config = GenerationConfig::default();
            config.seed = 42; // Fixed seed for reproducibility
            config.temperature = 0.1; // Low temperature for more deterministic behavior
            config.repetition_penalty = penalty;
            config.repetition_penalty_last_n = last_n;

            let mut text_gen = TextGeneration::new(&config);

            let mut generated = Vec::new();
            for _ in 0..5 {
                let token = text_gen.sample_next_token(&logits).unwrap();
                generated.push(token);

                assert!(token < vocab_size as u32, "Token should be within vocabulary");
            }

            // A strong penalty should push sampling off already-used tokens.
            let unique_tokens = generated.iter().collect::<std::collections::HashSet<_>>().len();
            if penalty > 1.5 {
                assert!(unique_tokens >= 3, "High penalty should produce diverse tokens");
            }
        }
    }

    /// Drives a short sampling loop and checks that generated tokens are recorded
    /// in the generator's history.
    #[test]
    fn test_sample_next_token_integration() {
        use candle_core::{Device, Tensor};

        let device = Device::Cpu;

        let mut config = GenerationConfig::default();
        config.seed = 999;
        config.temperature = 0.7;
        config.max_length = 10;
        config.repetition_penalty = 1.3;
        config.repetition_penalty_last_n = 5;

        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(50256);

        let vocab_size = 100;
        let mut all_tokens = Vec::new();

        for step in 0..8 {
            // Make a different token attractive at each step to mimic model output.
            let mut logits_vec = vec![0.0f32; vocab_size];
            let preferred_token = (step * 13) % vocab_size;
            logits_vec[preferred_token] = 5.0;
            logits_vec[(preferred_token + 10) % vocab_size] = 4.0;
            logits_vec[(preferred_token + 20) % vocab_size] = 3.0;

            let logits = Tensor::from_vec(logits_vec, vocab_size, &device).unwrap();

            let token = text_gen.sample_next_token(&logits).unwrap();
            all_tokens.push(token);

            if text_gen.should_stop(token, config.max_length) {
                break;
            }
        }

        assert!(!all_tokens.is_empty(), "Should generate some tokens");
        // The loop runs at most 8 steps (< max_length), so this is a sanity bound.
        assert!(all_tokens.len() <= config.max_length, "Should respect max length");

        // Every sampled token is also recorded internally.
        assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
    }

    /// Regression guard for `is_constraint_satisfied_stop_on_match`.
    ///
    /// An earlier version returned `true` whenever more than 1000 tokens were
    /// allowed, which ended generation early inside JSON strings. This test covers
    /// the baseline without a constraint. The large-allowed-set case with a real
    /// schema is exercised by `test_constraint_with_json_schema_not_early_termination`.
    #[test]
    fn test_constraint_satisfied_not_triggered_by_large_allowed_set() {
        let config = GenerationConfig::default();
        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(50256);

        assert!(!text_gen.is_constraint_satisfied(),
            "Without constraint, should not be satisfied");
        assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
            "Without constraint, stop_on_match should not be satisfied");
    }

    /// Intended rule: a constraint counts as satisfied only when no tokens, or only
    /// EOS, are allowed, never merely because many tokens are allowed. This test
    /// checks the no-constraint case with a non-default EOS id.
    #[test]
    fn test_constraint_satisfied_only_when_empty_or_eos_only() {
        let config = GenerationConfig::default();
        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(100);

        assert!(!text_gen.is_constraint_satisfied());
        assert!(!text_gen.is_constraint_satisfied_stop_on_match());
    }

    /// With a real JSON-schema constraint containing a string field, the initial
    /// state must not be reported as satisfied.
    #[tokio::test]
    async fn test_constraint_with_json_schema_not_early_termination() {
        // Requires network access to the Hugging Face Hub; skip (not fail) when it is
        // unavailable, but leave a trace so a silently-skipped test is visible.
        let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await else {
            eprintln!(
                "skipping test_constraint_with_json_schema_not_early_termination: tokenizer download failed"
            );
            return;
        };
        let wrapper = TokenizerWrapper::new(tokenizer);
        let vocabulary =
            VocabularyAdapter::from_tokenizer(&wrapper).expect("Should create vocabulary");

        let processor = SchemaProcessor::new();

        // While generating the string value many characters are valid, yet the
        // constraint is not complete.
        let schema = r#"{
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        }"#;

        let index = processor
            .process_schema(schema, &vocabulary)
            .expect("Should process schema");

        let mut config = GenerationConfig::default();
        config.constraint = Some(index);
        config.max_length = 100;

        let mut text_gen = TextGeneration::new(&config);
        text_gen.set_eos_token_id(102); // BERT's [SEP]

        // No JSON has been generated yet, so neither check may report completion.
        assert!(!text_gen.is_constraint_satisfied(),
            "Initial state should not be satisfied - JSON not yet generated");
        assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
            "Initial state should not trigger stop_on_match");
    }
}