red-candle 1.8.0.pre2-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5193 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +33 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,396 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_nn::VarBuilder;
3
+ use candle_transformers::models::llama::{Config, LlamaConfig, Llama as LlamaModel, Cache};
4
+ use hf_hub::{api::tokio::Api, Repo};
5
+ use tokenizers::Tokenizer;
6
+
7
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
8
+
9
/// A Llama-family causal language model loaded with candle, bundled with its tokenizer
/// and the state needed to generate text with it.
#[derive(Debug)]
pub struct Llama {
    /// The candle Llama transformer.
    model: LlamaModel,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: TokenizerWrapper,
    /// Device the model weights and the KV cache are placed on.
    device: Device,
    /// Model identifier (the Hub repository id when loaded via `from_pretrained*`); also
    /// consulted to choose between Llama 2 and Llama 3 conventions.
    model_id: String,
    /// Token id that terminates generation.
    eos_token_id: u32,
    /// KV cache passed to every forward call during generation.
    cache: Cache,
    /// Model configuration, kept so the KV cache can be rebuilt when it is cleared.
    config: Config,
}
19
+
20
+ impl Llama {
21
+ pub fn eos_token_id(&self) -> u32 {
22
+ self.eos_token_id
23
+ }
24
+
25
+ /// Clear the KV cache between generations
26
+ pub fn clear_kv_cache(&mut self) {
27
+ // Since Cache doesn't expose a reset method and kvs is private,
28
+ // we'll recreate the cache to clear it
29
+ // This is a workaround until candle provides a proper reset method
30
+ if let Ok(new_cache) = Cache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device) {
31
+ self.cache = new_cache;
32
+ }
33
+ }
34
+
35
+ /// Get the tokenizer
36
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
37
+ &self.tokenizer
38
+ }
39
+
40
+ /// Load a Llama model from HuggingFace Hub with optional custom tokenizer
41
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
42
+ let api = Api::new()
43
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
44
+
45
+ let repo = api.repo(Repo::model(model_id.to_string()));
46
+
47
+ // Download model files
48
+ let config_filename = repo
49
+ .get("config.json")
50
+ .await
51
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
52
+
53
+ // Download tokenizer from custom source if provided, otherwise from model repo
54
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
55
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
56
+ let tokenizer_filename = tokenizer_repo
57
+ .get("tokenizer.json")
58
+ .await
59
+ .map_err(|e| {
60
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
61
+ format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
62
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
63
+ format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
64
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
65
+ format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
66
+ } else {
67
+ format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
68
+ };
69
+ candle_core::Error::Msg(error_msg)
70
+ })?;
71
+ Tokenizer::from_file(tokenizer_filename)
72
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
73
+ } else {
74
+ let tokenizer_filename = repo
75
+ .get("tokenizer.json")
76
+ .await
77
+ .map_err(|e| {
78
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
79
+ format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'appropriate-tokenizer-repo')", model_id, model_id)
80
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
81
+ format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
82
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
83
+ format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
84
+ } else {
85
+ format!("Failed to download tokenizer: {}", e)
86
+ };
87
+ candle_core::Error::Msg(error_msg)
88
+ })?;
89
+ Tokenizer::from_file(tokenizer_filename)
90
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
91
+ };
92
+
93
+ // Try different file patterns for model weights
94
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
95
+ vec![single_file]
96
+ } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
97
+ vec![consolidated_file]
98
+ } else {
99
+ // Try to find sharded model files
100
+ // NOTE: This uses a brute-force approach, trying common shard counts.
101
+ // A better approach would be to read model.safetensors.index.json which
102
+ // contains the exact file list, but this works for most models (≤30 shards).
103
+ let mut sharded_files = Vec::new();
104
+ let mut index = 1;
105
+ loop {
106
+ // Try common shard counts for Llama models
107
+ let mut found = false;
108
+ for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
109
+ let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
110
+ if let Ok(file) = repo.get(&filename).await {
111
+ sharded_files.push(file);
112
+ found = true;
113
+ break;
114
+ }
115
+ }
116
+ if !found {
117
+ break;
118
+ }
119
+ index += 1;
120
+ }
121
+
122
+ if sharded_files.is_empty() {
123
+ return Err(candle_core::Error::Msg(
124
+ "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
125
+ ));
126
+ }
127
+ sharded_files
128
+ };
129
+
130
+ // Load config
131
+ let llama_config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config_filename)?)
132
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
133
+ let config = llama_config.into_config(false); // Don't use flash attention for now
134
+
135
+ // Determine EOS token ID based on model type
136
+ let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
137
+ // Llama 3 uses different special tokens
138
+ {
139
+ let vocab = tokenizer.get_vocab(true);
140
+ vocab.get("<|eot_id|>")
141
+ .or_else(|| vocab.get("<|end_of_text|>"))
142
+ .copied()
143
+ .unwrap_or(128009) // Default Llama 3 EOS
144
+ }
145
+ } else {
146
+ // Llama 2 and earlier
147
+ tokenizer
148
+ .get_vocab(true)
149
+ .get("</s>")
150
+ .copied()
151
+ .unwrap_or(2)
152
+ };
153
+
154
+ // Load model weights
155
+ let vb = unsafe {
156
+ VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
157
+ };
158
+
159
+ let model = LlamaModel::load(vb, &config)?;
160
+ let cache = Cache::new(true, DType::F32, &config, &device)?;
161
+
162
+ Ok(Self {
163
+ model,
164
+ tokenizer: TokenizerWrapper::new(tokenizer),
165
+ device,
166
+ model_id: model_id.to_string(),
167
+ eos_token_id,
168
+ cache,
169
+ config,
170
+ })
171
+ }
172
+
173
+ /// Load a Llama model from HuggingFace Hub (backwards compatibility)
174
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
175
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
176
+ }
177
+
178
+ /// Create from existing components (useful for testing)
179
+ pub fn new(
180
+ model: LlamaModel,
181
+ tokenizer: Tokenizer,
182
+ device: Device,
183
+ model_id: String,
184
+ config: &Config,
185
+ ) -> CandleResult<Self> {
186
+ let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
187
+ {
188
+ let vocab = tokenizer.get_vocab(true);
189
+ vocab.get("<|eot_id|>")
190
+ .or_else(|| vocab.get("<|end_of_text|>"))
191
+ .copied()
192
+ .unwrap_or(128009)
193
+ }
194
+ } else {
195
+ tokenizer
196
+ .get_vocab(true)
197
+ .get("</s>")
198
+ .copied()
199
+ .unwrap_or(2)
200
+ };
201
+
202
+ let cache = Cache::new(true, DType::F32, config, &device)?;
203
+
204
+ Ok(Self {
205
+ model,
206
+ tokenizer: TokenizerWrapper::new(tokenizer),
207
+ device,
208
+ model_id,
209
+ eos_token_id,
210
+ cache,
211
+ config: config.clone(),
212
+ })
213
+ }
214
+
215
+ fn generate_tokens(
216
+ &mut self,
217
+ prompt_tokens: Vec<u32>,
218
+ config: &GenerationConfig,
219
+ mut callback: Option<impl FnMut(&str)>,
220
+ ) -> CandleResult<Vec<u32>> {
221
+ let mut text_gen = TextGeneration::new(config);
222
+ text_gen.set_eos_token_id(self.eos_token_id);
223
+ text_gen.set_tokens(prompt_tokens.clone());
224
+
225
+ let mut all_tokens = prompt_tokens.clone();
226
+ let start_gen = all_tokens.len();
227
+
228
+ for index in 0..config.max_length {
229
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
230
+ let start_pos = all_tokens.len().saturating_sub(context_size);
231
+ let ctxt = &all_tokens[start_pos..];
232
+
233
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
234
+ let input = input.contiguous()?;
235
+ let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
236
+
237
+ let logits = logits.squeeze(0)?;
238
+ let logits = if logits.dims().len() == 2 {
239
+ let seq_len = logits.dim(0)?;
240
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
241
+ } else {
242
+ logits
243
+ };
244
+
245
+ let logits = logits.to_dtype(DType::F32)?;
246
+
247
+ let next_token = text_gen.sample_next_token(&logits)?;
248
+
249
+ all_tokens.push(next_token);
250
+
251
+ // Stream callback
252
+ if let Some(ref mut cb) = callback {
253
+ if config.debug_tokens {
254
+ // In debug mode, only show debug tokens
255
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
256
+ cb(&format!("[{}:{}]", next_token, token_piece));
257
+ } else {
258
+ // Normal mode: use incremental decoding for proper text
259
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
260
+ cb(&decoded_text);
261
+ }
262
+ }
263
+
264
+ // Check stop conditions
265
+ if text_gen.should_stop(next_token, config.max_length) {
266
+ break;
267
+ }
268
+
269
+ // Check if constraint is satisfied (early stopping)
270
+ if config.stop_on_constraint_satisfaction {
271
+ let satisfied = if config.stop_on_match {
272
+ text_gen.is_constraint_satisfied_stop_on_match()
273
+ } else {
274
+ text_gen.is_constraint_satisfied()
275
+ };
276
+ if satisfied {
277
+ break;
278
+ }
279
+ }
280
+
281
+ // Check stop sequences
282
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
283
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
284
+ break;
285
+ }
286
+ }
287
+
288
+ Ok(if config.include_prompt {
289
+ all_tokens
290
+ } else {
291
+ all_tokens[start_gen..].to_vec()
292
+ })
293
+ }
294
+
295
+ /// Apply chat template based on Llama version
296
+ pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
297
+ let is_llama3 = self.model_id.contains("Llama-3") || self.model_id.contains("llama-3");
298
+
299
+ if is_llama3 {
300
+ self.apply_llama3_template(messages)
301
+ } else {
302
+ self.apply_llama2_template(messages)
303
+ }
304
+ }
305
+
306
+ fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
307
+ let mut prompt = String::new();
308
+ let mut system_message = String::new();
309
+
310
+ for (i, message) in messages.iter().enumerate() {
311
+ let role = message["role"].as_str().unwrap_or("");
312
+ let content = message["content"].as_str().unwrap_or("");
313
+
314
+ match role {
315
+ "system" => {
316
+ system_message = content.to_string();
317
+ }
318
+ "user" => {
319
+ if i == 1 || (i == 0 && system_message.is_empty()) {
320
+ // First user message
321
+ if !system_message.is_empty() {
322
+ prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
323
+ } else {
324
+ prompt.push_str(&format!("[INST] {} [/INST]", content));
325
+ }
326
+ } else {
327
+ prompt.push_str(&format!(" [INST] {} [/INST]", content));
328
+ }
329
+ }
330
+ "assistant" => {
331
+ prompt.push_str(&format!(" {} </s>", content));
332
+ }
333
+ _ => {}
334
+ }
335
+ }
336
+
337
+ Ok(prompt)
338
+ }
339
+
340
+ fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
341
+ let mut prompt = String::new();
342
+
343
+ // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
344
+
345
+ for message in messages {
346
+ let role = message["role"].as_str().unwrap_or("");
347
+ let content = message["content"].as_str().unwrap_or("");
348
+
349
+ prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
350
+ }
351
+
352
+ prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
353
+
354
+ Ok(prompt)
355
+ }
356
+ }
357
+
358
+ impl TextGenerator for Llama {
359
+ fn generate(
360
+ &mut self,
361
+ prompt: &str,
362
+ config: &GenerationConfig,
363
+ ) -> CandleResult<String> {
364
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
365
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
366
+
367
+ if config.debug_tokens {
368
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
369
+ } else {
370
+ self.tokenizer.decode(&output_tokens, true)
371
+ }
372
+ }
373
+
374
+ fn generate_stream(
375
+ &mut self,
376
+ prompt: &str,
377
+ config: &GenerationConfig,
378
+ mut callback: impl FnMut(&str),
379
+ ) -> CandleResult<String> {
380
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
381
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
382
+ self.tokenizer.decode(&output_tokens, true)
383
+ }
384
+
385
+ fn model_name(&self) -> &str {
386
+ &self.model_id
387
+ }
388
+
389
+ fn device(&self) -> &Device {
390
+ &self.device
391
+ }
392
+
393
+ fn clear_cache(&mut self) {
394
+ self.clear_kv_cache();
395
+ }
396
+ }