red-candle 1.0.0.pre.1 → 1.0.0.pre.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,325 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_nn::VarBuilder;
+ use candle_transformers::models::mistral::{Config, Model as MistralModel};
+ use hf_hub::{api::tokio::Api, Repo};
+ use tokenizers::Tokenizer;
+
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ #[derive(Debug)]
+ pub struct Mistral {
+     model: MistralModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ impl Mistral {
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         self.model.clear_kv_cache();
+     }
+
+     /// Load a Mistral model from HuggingFace Hub
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.repo(Repo::model(model_id.to_string()));
+
+         // Download model files
+         let config_filename = repo
+             .get("config.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+         let tokenizer_filename = repo
+             .get("tokenizer.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+         // Try different file patterns for model weights
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
+             // Some Mistral models use consolidated.safetensors
+             vec![consolidated_file]
+         } else {
+             // Try to find sharded model files
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 // Try single pytorch_model.bin as last resort (though we prefer safetensors)
+                 if let Ok(_pytorch_file) = repo.get("pytorch_model.bin").await {
+                     return Err(candle_core::Error::Msg(
+                         "Only safetensors format is supported. This model uses pytorch_model.bin format.".to_string()
+                     ));
+                 } else {
+                     return Err(candle_core::Error::Msg(
+                         "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
+                     ));
+                 }
+             }
+             sharded_files
+         };
+
+         // Load config
+         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+         // Load tokenizer
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         let eos_token_id = tokenizer
+             .get_vocab(true)
+             .get("</s>")
+             .copied()
+             .unwrap_or(2);
+
+         // Load model weights
+         let vb = unsafe {
+             VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+         };
+
+         let model = MistralModel::new(&config, vb)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Create from existing components (useful for testing)
+     pub fn new(
+         model: MistralModel,
+         tokenizer: Tokenizer,
+         device: Device,
+         model_id: String,
+     ) -> Self {
+         let eos_token_id = tokenizer
+             .get_vocab(true)
+             .get("</s>")
+             .copied()
+             .unwrap_or(2);
+
+         Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id,
+             eos_token_id,
+         }
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             // Ensure input tensor is contiguous for Metal backend
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos)?;
+
+             // The model returns logits of shape [batch_size, seq_len, vocab_size]
+             // We need to get the logits for the last token only
+             let logits = logits.squeeze(0)?; // Remove batch dimension
+             let logits = if logits.dims().len() == 2 {
+                 // If we still have [seq_len, vocab_size], take the last token
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 // Already [vocab_size]
+                 logits
+             };
+
+             // Convert to F32 for sampling if needed
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 let token_text = self.tokenizer.token_to_piece(next_token)?;
+                 cb(&token_text);
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+
+     fn generate_tokens_decoded(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         // For incremental decoding
+         let mut previously_decoded = String::new();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             // Ensure input tensor is contiguous for Metal backend
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos)?;
+
+             // The model returns logits of shape [batch_size, seq_len, vocab_size]
+             // We need to get the logits for the last token only
+             let logits = logits.squeeze(0)?; // Remove batch dimension
+             let logits = if logits.dims().len() == 2 {
+                 // If we still have [seq_len, vocab_size], take the last token
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 // Already [vocab_size]
+                 logits
+             };
+
+             // Convert to F32 for sampling if needed
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback with incremental decoding
+             if let Some(ref mut cb) = callback {
+                 // Decode all generated tokens so far
+                 let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                 // Only emit the new text since last callback
+                 if current_decoded.len() > previously_decoded.len() {
+                     let new_text = &current_decoded[previously_decoded.len()..];
+                     cb(new_text);
+                     previously_decoded = current_decoded;
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = if callback.is_some() {
+                 previously_decoded.clone()
+             } else {
+                 self.tokenizer.decode(&all_tokens[start_gen..], true)?
+             };
+
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Mistral {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
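
For orientation, here is a minimal sketch of how the API added in this hunk might be driven from Rust. It is illustrative only: the exact GenerationConfig fields and a Default impl for it are assumptions (that type lives in the generation_config module, which is not part of this diff), the items are assumed to be in scope, and from_pretrained needs an async runtime such as tokio.

// Illustrative sketch only (not part of the diff). Assumes: Mistral,
// GenerationConfig, and the TextGenerator trait are in scope, GenerationConfig
// implements Default, and this runs inside a tokio runtime.
use candle_core::Device;

async fn demo() -> candle_core::Result<()> {
    // Downloads config.json, tokenizer.json, and safetensors weights from the Hub.
    let mut model = Mistral::from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", Device::Cpu).await?;

    let config = GenerationConfig::default(); // assumed Default impl

    // One-shot generation through the TextGenerator trait.
    let text = model.generate("Write a haiku about Rust.", &config)?;
    println!("{text}");

    // Streaming: the callback receives incrementally decoded text.
    model.generate_stream("Now write one about Ruby.", &config, |piece| print!("{piece}"))?;

    // Reset the KV cache before an unrelated prompt.
    model.clear_cache();
    Ok(())
}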
@@ -0,0 +1,68 @@
+ use candle_core::{Device, Result as CandleResult};
+ use tokenizers::Tokenizer;
+
+ pub mod mistral;
+ pub mod generation_config;
+ pub mod text_generation;
+
+ pub use generation_config::GenerationConfig;
+ pub use text_generation::TextGeneration;
+
+ /// Trait for text generation models
+ pub trait TextGenerator: Send + Sync {
+     /// Generate text from a prompt
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String>;
+
+     /// Generate text with streaming callback
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         callback: impl FnMut(&str),
+     ) -> CandleResult<String>;
+
+     /// Get the model's name
+     fn model_name(&self) -> &str;
+
+     /// Get the device the model is running on
+     fn device(&self) -> &Device;
+
+     /// Clear any cached state (like KV cache)
+     fn clear_cache(&mut self);
+ }
+
+ /// Common structure for managing tokenizer
+ #[derive(Debug)]
+ pub struct TokenizerWrapper {
+     tokenizer: Tokenizer,
+ }
+
+ impl TokenizerWrapper {
+     pub fn new(tokenizer: Tokenizer) -> Self {
+         Self { tokenizer }
+     }
+
+     pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
+         let encoding = self.tokenizer
+             .encode(text, add_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
+         Ok(encoding.get_ids().to_vec())
+     }
+
+     pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
+         self.tokenizer
+             .decode(tokens, skip_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
+     }
+
+     pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
+         self.tokenizer
+             .id_to_token(token)
+             .map(|s| s.to_string())
+             .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
+     }
+ }
@@ -0,0 +1,141 @@
+ use candle_core::{Result as CandleResult, Tensor};
+ use candle_transformers::generation::LogitsProcessor;
+ use rand::{rngs::StdRng, SeedableRng};
+
+ use super::GenerationConfig;
+
+ /// Helper struct for text generation process
+ pub struct TextGeneration {
+     #[allow(dead_code)]
+     rng: StdRng,
+     logits_processor: LogitsProcessor,
+     tokens: Vec<u32>,
+     eos_token_id: Option<u32>,
+ }
+
+ impl TextGeneration {
+     pub fn new(
+         seed: u64,
+         temperature: Option<f64>,
+         top_p: Option<f64>,
+         _top_k: Option<usize>,
+         _repetition_penalty: f32,
+         _repetition_penalty_last_n: usize,
+     ) -> Self {
+         let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+
+         Self {
+             rng: StdRng::seed_from_u64(seed),
+             logits_processor,
+             tokens: Vec::new(),
+             eos_token_id: None,
+         }
+     }
+
+     pub fn from_config(config: &GenerationConfig) -> Self {
+         Self::new(
+             config.seed,
+             Some(config.temperature),
+             config.top_p,
+             config.top_k,
+             config.repetition_penalty,
+             config.repetition_penalty_last_n,
+         )
+     }
+
+     pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
+         self.eos_token_id = Some(eos_token_id);
+     }
+
+     pub fn set_tokens(&mut self, tokens: Vec<u32>) {
+         self.tokens = tokens;
+     }
+
+     pub fn get_tokens(&self) -> &[u32] {
+         &self.tokens
+     }
+
+     pub fn push_token(&mut self, token: u32) {
+         self.tokens.push(token);
+     }
+
+     /// Apply repetition penalty to logits
+     pub fn apply_repetition_penalty(
+         &self,
+         logits: &mut Tensor,
+         penalty: f32,
+         context_size: usize,
+     ) -> CandleResult<()> {
+         if penalty == 1.0 {
+             return Ok(());
+         }
+
+         let device = logits.device();
+         let vocab_size = logits.dims1()?;
+
+         // Get the context tokens to apply penalty to
+         let start = self.tokens.len().saturating_sub(context_size);
+         let context_tokens = &self.tokens[start..];
+
+         // Apply penalty to tokens that appear in the context
+         let mut logits_vec = logits.to_vec1::<f32>()?;
+         for &token in context_tokens {
+             if (token as usize) < vocab_size {
+                 let idx = token as usize;
+                 if logits_vec[idx] > 0.0 {
+                     logits_vec[idx] /= penalty;
+                 } else {
+                     logits_vec[idx] *= penalty;
+                 }
+             }
+         }
+
+         *logits = Tensor::from_vec(logits_vec, vocab_size, device)?;
+         Ok(())
+     }
+
+     /// Sample next token from logits
+     pub fn sample_next_token(
+         &mut self,
+         logits: &Tensor,
+         repetition_penalty: Option<(f32, usize)>,
+     ) -> CandleResult<u32> {
+         let mut logits = logits.clone();
+
+         // Apply repetition penalty if specified
+         if let Some((penalty, last_n)) = repetition_penalty {
+             self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+         }
+
+         // Sample token
+         let next_token = self.logits_processor.sample(&logits)?;
+         self.tokens.push(next_token);
+
+         Ok(next_token)
+     }
+
+     /// Check if we should stop generation
+     pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
+         if self.tokens.len() >= max_length {
+             return true;
+         }
+
+         if let Some(eos) = self.eos_token_id {
+             if token == eos {
+                 return true;
+             }
+         }
+
+         false
+     }
+
+     /// Check if the generated text ends with any stop sequence
+     pub fn check_stop_sequences(&self, text: &str, stop_sequences: &[String]) -> bool {
+         for seq in stop_sequences {
+             if text.ends_with(seq) {
+                 return true;
+             }
+         }
+         false
+     }
+ }
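
To make the penalty rule above concrete: positive logits of recently seen tokens are divided by the penalty and negative ones are multiplied by it, so both move toward lower probability. A small self-contained check, illustrative only and assuming a CPU device:

// Illustrative check only (not part of the diff). Exercises the constructor and
// apply_repetition_penalty as defined above on a tiny 4-token vocabulary.
use candle_core::{Device, Tensor};

fn penalty_demo() -> candle_core::Result<()> {
    // Temperature/top_p disabled; the unused top_k and penalty constructor args are ignored.
    let mut gen = TextGeneration::new(42, None, None, None, 1.1, 64);
    gen.set_tokens(vec![0, 2]); // pretend tokens 0 and 2 were already generated

    let mut logits = Tensor::from_vec(vec![2.0f32, 1.0, -1.0, 0.5], 4usize, &Device::Cpu)?;
    gen.apply_repetition_penalty(&mut logits, 1.1, 2)?;

    let v = logits.to_vec1::<f32>()?;
    assert!((v[0] - 2.0 / 1.1).abs() < 1e-6); // positive logit divided by the penalty
    assert!((v[2] + 1.1).abs() < 1e-6);       // negative logit multiplied by it
    assert_eq!(v[1], 1.0);                    // token 1 is not in the context, left untouched
    Ok(())
}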