red-candle 1.0.1 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,123 @@
+ #[cfg(test)]
+ mod constrained_generation_tests {
+     use super::super::*;
+     use crate::structured::{VocabularyAdapter, SchemaProcessor};
+     use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
+
+     #[tokio::test]
+     async fn test_constrained_vs_unconstrained_generation() {
+         // This test demonstrates the difference between constrained and unconstrained generation
+
+         // Load a tokenizer for testing
+         if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
+             let wrapper = TokenizerWrapper::new(tokenizer);
+
+             // Create vocabulary adapter
+             let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
+                 .expect("Should create vocabulary");
+
+             // Create schema processor
+             let processor = SchemaProcessor::new();
+
+             // Define a simple JSON schema for a yes/no response
+             let schema = r#"{
+                 "type": "object",
+                 "properties": {
+                     "answer": {
+                         "type": "string",
+                         "enum": ["yes", "no"]
+                     }
+                 },
+                 "required": ["answer"]
+             }"#;
+
+             // Process schema into Index
+             let index = processor.process_schema(schema, &vocabulary)
+                 .expect("Should process schema");
+
+             // Test configuration with constraint
+             let mut config_with_constraint = GenerationConfig::default();
+             config_with_constraint.constraint = Some(index.clone());
+             config_with_constraint.max_length = 50;
+
+             // Test configuration without constraint
+             let config_without_constraint = GenerationConfig::default();
+
+             // Create text generation instances
+             let mut gen_constrained = TextGeneration::from_config(&config_with_constraint);
+             let mut gen_unconstrained = TextGeneration::from_config(&config_without_constraint);
+
+             // Set EOS token
+             gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
+             gen_unconstrained.set_eos_token_id(102);
+
+             // Constraints are set internally - we can't directly verify them
+             // but we can test their effects in actual generation
+         }
+     }
+
+     #[test]
+     fn test_constraint_configuration() {
+         // Test that we can create a TextGeneration with constraints
+         let config = GenerationConfig::default();
+         let _text_gen = TextGeneration::from_config(&config);
+
+         // Test that we can create a TextGeneration from config
+         // Constraints are private implementation details
+     }
+
+     #[test]
+     fn test_repetition_penalty() {
+         use candle_core::{Tensor, Device};
+
+         let device = Device::Cpu;
+         let vocab_size = 10;
+
+         // Create logits with some positive and negative values
+         let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
+         let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
+
+         // Create text generation with some tokens
+         let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+         text_gen.push_token(0); // Token that had logit 1.0
+         text_gen.push_token(2); // Token that had logit 2.0
+         text_gen.push_token(5); // Token that had logit 3.0
+
+         // Apply repetition penalty
+         text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();
+
+         let penalized = logits.to_vec1::<f32>().unwrap();
+
+         // Check that tokens in context were penalized
+         assert!(penalized[0] < logits_vec[0], "Positive logit should be reduced");
+         assert!(penalized[2] < logits_vec[2], "Positive logit should be reduced");
+         assert!(penalized[5] < logits_vec[5], "Positive logit should be reduced");
+
+         // Check that other tokens remain unchanged
+         assert_eq!(penalized[1], logits_vec[1], "Unsampled token should be unchanged");
+         assert_eq!(penalized[3], logits_vec[3], "Unsampled token should be unchanged");
+     }
+
+     #[test]
+     fn test_stop_conditions() {
+         let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
+         text_gen.set_eos_token_id(50256); // Common EOS token
+
+         // Test max length stop
+         for i in 0..10 {
+             text_gen.push_token(i);
+         }
+         assert!(text_gen.should_stop(100, 10), "Should stop at max length");
+         assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");
+
+         // Test EOS token stop
+         assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
+         assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");
+
+         // Test stop sequences
+         let stop_seqs = vec!["STOP".to_string(), "END".to_string()];
+         assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
+         assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
+         assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
+     }
+ }
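
Note on test_repetition_penalty: the assertions above only check the direction of the change. For reference, a common formulation of repetition penalty (an illustrative assumption here, not necessarily red-candle's exact implementation) divides positive logits and multiplies negative ones by the penalty, which is consistent with those assertions:

    // Common repetition-penalty formulation (illustrative assumption, not taken
    // from red-candle's source): tokens already in the context are made less
    // likely by shrinking positive logits and pushing negative logits lower.
    fn penalize(logits: &mut [f32], context: &[u32], penalty: f32) {
        for &token in context {
            let l = &mut logits[token as usize];
            *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
        }
    }

Under that formulation, a penalty of 1.5 maps the logit 3.0 at index 5 to 2.0, which satisfies the penalized[5] < logits_vec[5] assertion above.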
@@ -1,4 +1,6 @@
  use std::time::{SystemTime, UNIX_EPOCH};
+ use std::sync::Arc;
+ use crate::structured::Index;

  /// Configuration for text generation
  #[derive(Debug, Clone)]
@@ -23,6 +25,8 @@ pub struct GenerationConfig {
      pub include_prompt: bool,
      /// Whether to show raw tokens during generation (for debugging)
      pub debug_tokens: bool,
+     /// Optional constraint index for structured generation
+     pub constraint: Option<Arc<Index>>,
  }

  /// Generate a random seed based on current time
@@ -46,6 +50,7 @@ impl Default for GenerationConfig {
              stop_sequences: vec![],
              include_prompt: false,
              debug_tokens: false,
+             constraint: None,
          }
      }
  }
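
The new constraint field is an Option<Arc<Index>>, so a schema compiled once can be shared across configurations without re-processing. A minimal sketch of wiring it up, assuming the VocabularyAdapter / SchemaProcessor API exercised in the test file above (tokenizer_wrapper and schema_json are placeholders):

    // Sketch only: tokenizer_wrapper is an already-built TokenizerWrapper and
    // schema_json is a JSON Schema string like the yes/no schema in the test above.
    let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer_wrapper)
        .expect("Should create vocabulary");
    let processor = SchemaProcessor::new();
    let index = processor.process_schema(schema_json, &vocabulary)
        .expect("Should process schema");

    let mut config = GenerationConfig::default();
    config.constraint = Some(index); // Arc<Index>: cloning it for other configs is cheap
    let _text_gen = TextGeneration::from_config(&config);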
@@ -3,6 +3,8 @@ use candle_core::{Device, Result as CandleResult};
  pub mod mistral;
  pub mod llama;
  pub mod gemma;
+ pub mod qwen;
+ pub mod phi;
  pub mod generation_config;
  pub mod text_generation;
  pub mod quantized_gguf;
@@ -12,6 +14,9 @@ pub use text_generation::TextGeneration;
  pub use quantized_gguf::QuantizedGGUF;
  pub use crate::tokenizer::TokenizerWrapper;

+ #[cfg(test)]
+ mod constrained_generation_test;
+

  /// Trait for text generation models
  pub trait TextGenerator: Send + Sync {
@@ -0,0 +1,285 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_transformers::models::phi::{Config, Model as PhiModel};
+ use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
+ use hf_hub::api::tokio::Api;
+ use tokenizers::Tokenizer;
+
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ /// Phi model wrapper for text generation
+ pub struct Phi {
+     model: PhiVariant,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ enum PhiVariant {
+     Phi2(PhiModel),
+     Phi3(Phi3Model),
+ }
+
+ impl Phi {
+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         match &mut self.model {
+             PhiVariant::Phi2(model) => model.clear_kv_cache(),
+             PhiVariant::Phi3(model) => model.clear_kv_cache(),
+         }
+     }
+
+     /// Load a Phi model from HuggingFace
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.model(model_id.to_string());
+
+         // Download configuration
+         let config_filename = repo.get("config.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+         let config_str = std::fs::read_to_string(config_filename)?;
+
+         // Download tokenizer
+         let tokenizer_filename = repo.get("tokenizer.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Determine EOS token
+         let vocab = tokenizer.get_vocab(true);
+         let eos_token_id = vocab.get("<|endoftext|>")
+             .or_else(|| vocab.get("<|end|>"))
+             .or_else(|| vocab.get("</s>"))
+             .copied()
+             .unwrap_or(50256); // Default GPT-2 style EOS token
+
+         // Determine model variant based on model_id or config
+         let is_phi3 = model_id.contains("phi-3") || model_id.contains("Phi-3");
+
+         // Download model weights (handle both single and sharded files)
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else {
+             // Try to find sharded model files
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 return Err(candle_core::Error::Msg(
+                     "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
+                 ));
+             }
+             sharded_files
+         };
+
+         let model = if is_phi3 {
+             // Load Phi3 model
+             let config: Phi3Config = serde_json::from_str(&config_str)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
+
+             let vb = unsafe {
+                 candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+             };
+
+             let model = Phi3Model::new(&config, vb)?;
+             PhiVariant::Phi3(model)
+         } else {
+             // Load Phi2 model
+             let config: Config = serde_json::from_str(&config_str)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi config: {}", e)))?;
+
+             let vb = unsafe {
+                 candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+             };
+
+             let model = PhiModel::new(&config, vb)?;
+             PhiVariant::Phi2(model)
+         };
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Apply Phi chat template to messages
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         // Phi-3 uses a specific format
+         if matches!(self.model, PhiVariant::Phi3(_)) {
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => {
+                         prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+                     }
+                     "user" => {
+                         prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+                     }
+                     "assistant" => {
+                         prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+                     }
+                     _ => {}
+                 }
+             }
+             prompt.push_str("<|assistant|>\n");
+         } else {
+             // Phi-2 uses a simpler format
+             for message in messages {
+                 let role = message["role"].as_str().unwrap_or("");
+                 let content = message["content"].as_str().unwrap_or("");
+
+                 match role {
+                     "system" => prompt.push_str(&format!("System: {}\n", content)),
+                     "user" => prompt.push_str(&format!("User: {}\n", content)),
+                     "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+                     _ => {}
+                 }
+             }
+             prompt.push_str("Assistant: ");
+         }
+
+         Ok(prompt)
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let logits = match &mut self.model {
+                 PhiVariant::Phi2(model) => model.forward(&input)?,
+                 PhiVariant::Phi3(model) => model.forward(&input, start_pos)?,
+             };
+             let logits = logits.squeeze(0)?;
+
+             // Handle different output shapes
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 if config.debug_tokens {
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Phi {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
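
To show how the pieces of the new Phi wrapper fit together, here is a hedged usage sketch written as if from inside the crate (the model id, prompt, and async runtime setup are illustrative assumptions, not something this diff prescribes):

    // Sketch only: model id, prompt content and runtime setup are assumptions.
    use candle_core::Device;
    use serde_json::json;
    use crate::llm::phi::Phi;
    use crate::llm::{GenerationConfig, TextGenerator};

    async fn demo() -> candle_core::Result<()> {
        // from_pretrained downloads config.json, tokenizer.json and the
        // safetensors weights from the Hugging Face Hub.
        let mut phi = Phi::from_pretrained("microsoft/phi-2", Device::Cpu).await?;

        // Model ids without "phi-3"/"Phi-3" get the "System:/User:/Assistant:" format;
        // Phi-3 ids get the <|system|>/<|user|>/<|end|> format instead.
        let messages = vec![
            json!({"role": "system", "content": "You are a helpful assistant."}),
            json!({"role": "user", "content": "Name three prime numbers."}),
        ];
        let prompt = phi.apply_chat_template(&messages)?;

        // Generation goes through the TextGenerator trait defined in mod.rs.
        let config = GenerationConfig::default();
        let answer = phi.generate(&prompt, &config)?;
        println!("{answer}");
        Ok(())
    }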