red-candle 1.8.0.pre2-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5193 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +33 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,309 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_nn::VarBuilder;
3
+ use candle_transformers::models::mistral::{Config, Model as MistralModel};
4
+ use hf_hub::{api::tokio::Api, Repo};
5
+ use tokenizers::Tokenizer;
6
+
7
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
8
+
9
+ #[derive(Debug)]
10
+ pub struct Mistral {
11
+ model: MistralModel,
12
+ tokenizer: TokenizerWrapper,
13
+ device: Device,
14
+ model_id: String,
15
+ eos_token_id: u32,
16
+ }
17
+
18
+ impl Mistral {
19
+ pub fn eos_token_id(&self) -> u32 {
20
+ self.eos_token_id
21
+ }
22
+
23
+ /// Clear the KV cache between generations
24
+ pub fn clear_kv_cache(&mut self) {
25
+ self.model.clear_kv_cache();
26
+ }
27
+
28
+ /// Get the tokenizer
29
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
30
+ &self.tokenizer
31
+ }
32
+
33
+ /// Load a Mistral model from HuggingFace Hub with optional custom tokenizer
34
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
35
+ let api = Api::new()
36
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
37
+
38
+ let repo = api.repo(Repo::model(model_id.to_string()));
39
+
40
+ // Download model files
41
+ let config_filename = repo
42
+ .get("config.json")
43
+ .await
44
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
45
+
46
+ // Download tokenizer from custom source if provided, otherwise from model repo
47
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
48
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
49
+ let tokenizer_filename = tokenizer_repo
50
+ .get("tokenizer.json")
51
+ .await
52
+ .map_err(|e| {
53
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
54
+ format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
55
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
56
+ format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
57
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
58
+ format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
59
+ } else {
60
+ format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
61
+ };
62
+ candle_core::Error::Msg(error_msg)
63
+ })?;
64
+ Tokenizer::from_file(tokenizer_filename)
65
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
66
+ } else {
67
+ let tokenizer_filename = repo
68
+ .get("tokenizer.json")
69
+ .await
70
+ .map_err(|e| {
71
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
72
+ format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'mistralai/Mistral-7B-Instruct-v0.2')", model_id, model_id)
73
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
74
+ format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
75
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
76
+ format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
77
+ } else {
78
+ format!("Failed to download tokenizer: {}", e)
79
+ };
80
+ candle_core::Error::Msg(error_msg)
81
+ })?;
82
+ Tokenizer::from_file(tokenizer_filename)
83
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
84
+ };
85
+
86
+ // Try different file patterns for model weights
87
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
88
+ vec![single_file]
89
+ } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
90
+ // Some Mistral models use consolidated.safetensors
91
+ vec![consolidated_file]
92
+ } else {
93
+ // Try to find sharded model files
94
+ // NOTE: This uses a brute-force approach, trying common shard counts.
95
+ // A better approach would be to read model.safetensors.index.json which
96
+ // contains the exact file list, but this works for most models (≤8 shards).
97
+ let mut sharded_files = Vec::new();
98
+ let mut index = 1;
99
+ loop {
100
+ // Try common shard counts
101
+ let mut found = false;
102
+ for total in [2, 3, 4, 5, 6, 7, 8] {
103
+ let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
104
+ if let Ok(file) = repo.get(&filename).await {
105
+ sharded_files.push(file);
106
+ found = true;
107
+ break;
108
+ }
109
+ }
110
+ if !found {
111
+ break;
112
+ }
113
+ index += 1;
114
+ }
115
+
116
+ if sharded_files.is_empty() {
117
+ // Try single pytorch_model.bin as last resort (though we prefer safetensors)
118
+ if let Ok(_pytorch_file) = repo.get("pytorch_model.bin").await {
119
+ return Err(candle_core::Error::Msg(
120
+ "Only safetensors format is supported. This model uses pytorch_model.bin format.".to_string()
121
+ ));
122
+ } else {
123
+ return Err(candle_core::Error::Msg(
124
+ "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
125
+ ));
126
+ }
127
+ }
128
+ sharded_files
129
+ };
130
+
131
+ // Load config
132
+ let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
133
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
134
+
135
+ let eos_token_id = tokenizer
136
+ .get_vocab(true)
137
+ .get("</s>")
138
+ .copied()
139
+ .unwrap_or(2);
140
+
141
+ // Load model weights
142
+ let vb = unsafe {
143
+ VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
144
+ };
145
+
146
+ let model = MistralModel::new(&config, vb)?;
147
+
148
+ Ok(Self {
149
+ model,
150
+ tokenizer: TokenizerWrapper::new(tokenizer),
151
+ device,
152
+ model_id: model_id.to_string(),
153
+ eos_token_id,
154
+ })
155
+ }
156
+
157
+ /// Load a Mistral model from HuggingFace Hub (backwards compatibility)
158
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
159
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
160
+ }
161
+
162
+ /// Create from existing components (useful for testing)
163
+ pub fn new(
164
+ model: MistralModel,
165
+ tokenizer: Tokenizer,
166
+ device: Device,
167
+ model_id: String,
168
+ ) -> Self {
169
+ let eos_token_id = tokenizer
170
+ .get_vocab(true)
171
+ .get("</s>")
172
+ .copied()
173
+ .unwrap_or(2);
174
+
175
+ Self {
176
+ model,
177
+ tokenizer: TokenizerWrapper::new(tokenizer),
178
+ device,
179
+ model_id,
180
+ eos_token_id,
181
+ }
182
+ }
183
+
184
+ fn generate_tokens(
185
+ &mut self,
186
+ prompt_tokens: Vec<u32>,
187
+ config: &GenerationConfig,
188
+ mut callback: Option<impl FnMut(&str)>,
189
+ ) -> CandleResult<Vec<u32>> {
190
+ let mut text_gen = TextGeneration::new(config);
191
+ text_gen.set_eos_token_id(self.eos_token_id);
192
+ text_gen.set_tokens(prompt_tokens.clone());
193
+
194
+ let mut all_tokens = prompt_tokens.clone();
195
+ let start_gen = all_tokens.len();
196
+
197
+ for index in 0..config.max_length {
198
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
199
+ let start_pos = all_tokens.len().saturating_sub(context_size);
200
+ let ctxt = &all_tokens[start_pos..];
201
+
202
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
203
+ // Ensure input tensor is contiguous for Metal backend
204
+ let input = input.contiguous()?;
205
+ let logits = self.model.forward(&input, start_pos)?;
206
+
207
+ // The model returns logits of shape [batch_size, seq_len, vocab_size]
208
+ // We need to get the logits for the last token only
209
+ let logits = logits.squeeze(0)?; // Remove batch dimension
210
+ let logits = if logits.dims().len() == 2 {
211
+ // If we still have [seq_len, vocab_size], take the last token
212
+ let seq_len = logits.dim(0)?;
213
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
214
+ } else {
215
+ // Already [vocab_size]
216
+ logits
217
+ };
218
+
219
+ // Convert to F32 for sampling if needed
220
+ let logits = logits.to_dtype(DType::F32)?;
221
+
222
+ let next_token = text_gen.sample_next_token(&logits)?;
223
+
224
+ all_tokens.push(next_token);
225
+
226
+ // Stream callback
227
+ if let Some(ref mut cb) = callback {
228
+ if config.debug_tokens {
229
+ // In debug mode, only show debug tokens
230
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
231
+ cb(&format!("[{}:{}]", next_token, token_piece));
232
+ } else {
233
+ // Normal mode: use incremental decoding for proper text
234
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
235
+ cb(&decoded_text);
236
+ }
237
+ }
238
+
239
+ // Check stop conditions
240
+ if text_gen.should_stop(next_token, config.max_length) {
241
+ break;
242
+ }
243
+
244
+ // Check if constraint is satisfied (early stopping)
245
+ if config.stop_on_constraint_satisfaction {
246
+ let satisfied = if config.stop_on_match {
247
+ text_gen.is_constraint_satisfied_stop_on_match()
248
+ } else {
249
+ text_gen.is_constraint_satisfied()
250
+ };
251
+ if satisfied {
252
+ break;
253
+ }
254
+ }
255
+
256
+ // Check stop sequences
257
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
258
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
259
+ break;
260
+ }
261
+ }
262
+
263
+ Ok(if config.include_prompt {
264
+ all_tokens
265
+ } else {
266
+ all_tokens[start_gen..].to_vec()
267
+ })
268
+ }
269
+ }
270
+
271
+ impl TextGenerator for Mistral {
272
+ fn generate(
273
+ &mut self,
274
+ prompt: &str,
275
+ config: &GenerationConfig,
276
+ ) -> CandleResult<String> {
277
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
278
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
279
+
280
+ if config.debug_tokens {
281
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
282
+ } else {
283
+ self.tokenizer.decode(&output_tokens, true)
284
+ }
285
+ }
286
+
287
+ fn generate_stream(
288
+ &mut self,
289
+ prompt: &str,
290
+ config: &GenerationConfig,
291
+ mut callback: impl FnMut(&str),
292
+ ) -> CandleResult<String> {
293
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
294
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
295
+ self.tokenizer.decode(&output_tokens, true)
296
+ }
297
+
298
+ fn model_name(&self) -> &str {
299
+ &self.model_id
300
+ }
301
+
302
+ fn device(&self) -> &Device {
303
+ &self.device
304
+ }
305
+
306
+ fn clear_cache(&mut self) {
307
+ self.clear_kv_cache();
308
+ }
309
+ }
@@ -0,0 +1,49 @@
1
+ use candle_core::{Device, Result as CandleResult};
2
+
3
+ pub mod mistral;
4
+ pub mod llama;
5
+ pub mod gemma;
6
+ pub mod qwen;
7
+ pub mod qwen3;
8
+ pub mod phi;
9
+ pub mod granite;
10
+ pub mod granitemoehybrid;
11
+ pub mod glm4;
12
+ pub mod generation_config;
13
+ pub mod text_generation;
14
+ pub mod quantized_gguf;
15
+
16
+ pub use generation_config::GenerationConfig;
17
+ pub use text_generation::TextGeneration;
18
+ pub use quantized_gguf::QuantizedGGUF;
19
+ pub use crate::tokenizer::TokenizerWrapper;
20
+
21
+ #[cfg(test)]
22
+ mod constrained_generation_test;
23
+
24
+ /// Trait for text generation models
25
+ pub trait TextGenerator: Send + Sync {
26
+ /// Generate text from a prompt
27
+ fn generate(
28
+ &mut self,
29
+ prompt: &str,
30
+ config: &GenerationConfig,
31
+ ) -> CandleResult<String>;
32
+
33
+ /// Generate text with streaming callback
34
+ fn generate_stream(
35
+ &mut self,
36
+ prompt: &str,
37
+ config: &GenerationConfig,
38
+ callback: impl FnMut(&str),
39
+ ) -> CandleResult<String>;
40
+
41
+ /// Get the model's name
42
+ fn model_name(&self) -> &str;
43
+
44
+ /// Get the device the model is running on
45
+ fn device(&self) -> &Device;
46
+
47
+ /// Clear any cached state (like KV cache)
48
+ fn clear_cache(&mut self);
49
+ }