red-candle 1.0.0.pre.5 → 1.0.0.pre.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 91a4c43a1a12d6d8960f1a1d190c9bfe8ea60db75f687233012d09b8c90b5020
-  data.tar.gz: 3f6ce143cd38856365231baebe25a188e7d5824d930ecdae79c1660b3ad6c787
+  metadata.gz: 07ca4e6eb0b65eac5b62f4b3622ed3189f203279265b7174936ccfd5ff3e5099
+  data.tar.gz: f4970f5c4376453cde1ee18b93155f69ca634ccc3e4a359a45b49d7f20379f64
 SHA512:
-  metadata.gz: 273a01c438b085509a433602097b5ad4bcdb3420fc19ebe84c4fd37bf43ae5e6e1040701ecaad9e2864c0cda31d6128d4b2df6c395290994c17f135620282be6
-  data.tar.gz: 9193d01d8bfc704b982c839f9aebac23217ae5d56b5afae6228daae73b60e219b049fe8fc8d5ad13657ceb110476489c68bbb9aa6397688bf70a976a8d62d41e
+  metadata.gz: 10ed0881ec2f67ab1e798401e857eac638049b254b20460bcb5565cee822b24ce2abe23d0ce00275dcb1d1ddebfd926d47eac7e6d54924937da4356a36211224
+  data.tar.gz: d24fa67f74cd62c87ea1666e9488f12e8773d15e2d62b806bd38ca7cb20215d819b0502d352e4d310d1377b9ab64debbdd664148c70fc1ac70f1f2e23e9b516c
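
These digests cover the two gzipped archives inside the published .gem (its metadata and its file payload). A minimal Ruby sketch for re-checking them against a locally fetched copy, assuming the standard .gem tar layout and an illustrative file name (e.g. from `gem fetch red-candle -v 1.0.0.pre.6`):

# Hedged sketch: recompute the SHA256 digests recorded in checksums.yaml.
# Assumes the usual .gem layout (a tar containing metadata.gz, data.tar.gz,
# checksums.yaml.gz); the file name below is an assumption for illustration.
require "digest"
require "rubygems/package"

File.open("red-candle-1.0.0.pre.6.gem", "rb") do |io|
  Gem::Package::TarReader.new(io) do |tar|
    tar.each do |entry|
      next unless %w[metadata.gz data.tar.gz].include?(entry.full_name)
      puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
    end
  end
end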
data/README.md CHANGED
@@ -47,6 +47,7 @@ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
 
 ### Supported Models
 
+- **Gemma**: Google's Gemma models (e.g., `google/gemma-2b`, `google/gemma-7b`, `google/gemma-2b-it`)
 - **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
 - **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
 
@@ -67,10 +68,10 @@ device = Candle::Device.cpu # CPU (default)
 device = Candle::Device.metal # Apple GPU (Metal)
 device = Candle::Device.cuda # NVIDIA GPU (CUDA)
 
-# Load a Llama model
-llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
-# Or a Mistral model
-llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
+# Load a model
+llm = Candle::LLM.from_pretrained("google/gemma-2b-it", device: device) # Gemma
+# llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device) # Llama
+# llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device) # Mistral
 
 # Generate text
 response = llm.generate("What is Ruby?", config: Candle::GenerationConfig.balanced)
data/ext/candle/src/llm/gemma.rs ADDED
@@ -0,0 +1,340 @@
+use candle_core::{DType, Device, Result as CandleResult, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::gemma::{Config, Model as GemmaModel};
+use hf_hub::{api::tokio::Api, Repo};
+use tokenizers::Tokenizer;
+
+use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+#[derive(Debug)]
+pub struct Gemma {
+    model: GemmaModel,
+    tokenizer: TokenizerWrapper,
+    device: Device,
+    model_id: String,
+    eos_token_id: u32,
+}
+
+impl Gemma {
+    /// Clear the KV cache between generations
+    pub fn clear_kv_cache(&mut self) {
+        self.model.clear_kv_cache();
+    }
+
+    /// Load a Gemma model from HuggingFace Hub
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        let api = Api::new()
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+        let repo = api.repo(Repo::model(model_id.to_string()));
+
+        // Download model files
+        let config_filename = repo
+            .get("config.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+        let tokenizer_filename = repo
+            .get("tokenizer.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+        // Try different file patterns for model weights
+        let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+            vec![single_file]
+        } else {
+            // Try to find sharded model files
+            let mut sharded_files = Vec::new();
+            let mut index = 1;
+            loop {
+                // Try common shard counts for Gemma models
+                let mut found = false;
+                for total in [2, 3, 4, 5, 6, 7, 8, 10, 12] {
+                    let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                    if let Ok(file) = repo.get(&filename).await {
+                        sharded_files.push(file);
+                        found = true;
+                        break;
+                    }
+                }
+                if !found {
+                    break;
+                }
+                index += 1;
+            }
+
+            if sharded_files.is_empty() {
+                return Err(candle_core::Error::Msg(
+                    "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
+                ));
+            }
+            sharded_files
+        };
+
+        // Load config
+        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+        // Load tokenizer
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        // Gemma uses specific tokens
+        let eos_token_id = {
+            let vocab = tokenizer.get_vocab(true);
+            vocab.get("<eos>")
+                .or_else(|| vocab.get("<end_of_turn>"))
+                .copied()
+                .unwrap_or(1) // Default Gemma EOS
+        };
+
+        // Load model weights
+        let vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+        };
+
+        let model = GemmaModel::new(false, &config, vb)?; // Don't use flash attention for now
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id: model_id.to_string(),
+            eos_token_id,
+        })
+    }
+
+    /// Create from existing components (useful for testing)
+    pub fn new(
+        model: GemmaModel,
+        tokenizer: Tokenizer,
+        device: Device,
+        model_id: String,
+    ) -> Self {
+        let eos_token_id = {
+            let vocab = tokenizer.get_vocab(true);
+            vocab.get("<eos>")
+                .or_else(|| vocab.get("<end_of_turn>"))
+                .copied()
+                .unwrap_or(1)
+        };
+
+        Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id,
+            eos_token_id,
+        }
+    }
+
+    fn generate_tokens(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos)?;
+
+            let logits = logits.squeeze(0)?;
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback
+            if let Some(ref mut cb) = callback {
+                let token_text = self.tokenizer.token_to_piece(next_token)?;
+                cb(&token_text);
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+
+    fn generate_tokens_decoded(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+        let mut previously_decoded = String::new();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos)?;
+
+            let logits = logits.squeeze(0)?;
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback with incremental decoding
+            if let Some(ref mut cb) = callback {
+                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                if current_decoded.len() > previously_decoded.len() {
+                    let new_text = &current_decoded[previously_decoded.len()..];
+                    cb(new_text);
+                    previously_decoded = current_decoded;
+                }
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = if callback.is_some() {
+                previously_decoded.clone()
+            } else {
+                self.tokenizer.decode(&all_tokens[start_gen..], true)?
+            };
+
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+
+    /// Apply Gemma chat template
+    pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        // Gemma uses a specific format:
+        // <start_of_turn>user\n{user_message}<end_of_turn>
+        // <start_of_turn>model\n{model_message}<end_of_turn>
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "system" => {
+                    // Gemma doesn't have explicit system messages, prepend to first user message
+                    prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
+                }
+                "user" => {
+                    if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
+                        prompt.push_str("<start_of_turn>user\n");
+                    }
+                    prompt.push_str(&format!("{}<end_of_turn>\n", content));
+                }
+                "assistant" | "model" => {
+                    prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
+                }
+                _ => {}
+            }
+        }
+
+        // Add the model prompt
+        prompt.push_str("<start_of_turn>model\n");
+
+        Ok(prompt)
+    }
+}
+
+impl TextGenerator for Gemma {
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        mut callback: impl FnMut(&str),
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_id
+    }
+
+    fn device(&self) -> &Device {
+        &self.device
+    }
+
+    fn clear_cache(&mut self) {
+        self.clear_kv_cache();
+    }
+}
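
The `apply_chat_template` above wraps conversation turns in Gemma's `<start_of_turn>user` / `<start_of_turn>model` / `<end_of_turn>` markers and leaves the prompt open on a model turn. A minimal Ruby sketch of the same idea through the gem's public API, assuming only the `from_pretrained` / `generate` calls shown in the README diff (whether the chat template itself is exposed to Ruby is not visible in this diff):

# Hedged sketch: hand-build a Gemma-style chat prompt using the turn markers
# from apply_chat_template in gemma.rs, then generate from it. Model id and
# config preset are taken from the README example; adapt as needed.
llm = Candle::LLM.from_pretrained("google/gemma-2b-it", device: Candle::Device.cpu)

prompt = <<~PROMPT
  <start_of_turn>user
  What is Ruby?<end_of_turn>
  <start_of_turn>model
PROMPT

response = llm.generate(prompt, config: Candle::GenerationConfig.balanced)
puts response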
@@ -3,6 +3,7 @@ use tokenizers::Tokenizer;
 
 pub mod mistral;
 pub mod llama;
+pub mod gemma;
 pub mod generation_config;
 pub mod text_generation;
 
@@ -1,7 +1,7 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;
 
-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma};
 use crate::ruby::{Result as RbResult, Device as RbDevice};
 
 // Use an enum to handle different model types instead of trait objects
@@ -9,6 +9,7 @@ use crate::ruby::{Result as RbResult, Device as RbDevice};
 enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
+    Gemma(RustGemma),
 }
 
 impl ModelType {
@@ -16,6 +17,7 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
+            ModelType::Gemma(m) => m.generate(prompt, config),
         }
     }
 
@@ -28,6 +30,7 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
+            ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
         }
     }
 
@@ -36,6 +39,7 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.model_name(),
             ModelType::Llama(m) => m.model_name(),
+            ModelType::Gemma(m) => m.model_name(),
         }
     }
 
@@ -43,6 +47,7 @@ impl ModelType {
         match self {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
+            ModelType::Gemma(m) => m.clear_cache(),
         }
     }
 
@@ -66,6 +71,7 @@ impl ModelType {
                 Ok(prompt)
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
+            ModelType::Gemma(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -208,16 +214,22 @@ impl LLM {
             })
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
             ModelType::Mistral(mistral)
-        } else if model_lower.contains("llama") || model_lower.contains("meta-llama") {
+        } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
             let llama = rt.block_on(async {
                 RustLlama::from_pretrained(&model_id, candle_device).await
             })
            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
             ModelType::Llama(llama)
+        } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+            let gemma = rt.block_on(async {
+                RustGemma::from_pretrained(&model_id, candle_device).await
+            })
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+            ModelType::Gemma(gemma)
         } else {
             return Err(Error::new(
                 magnus::exception::runtime_error(),
-                format!("Unsupported model type: {}. Currently only Mistral and Llama models are supported.", model_id),
+                format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
             ));
         };
 
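For readers tracking the Ruby-facing behaviour: the `from_pretrained` dispatch above picks a backend by substring-matching the lowercased model id ("mistral", "llama"/"meta-llama"/"tinyllama", "gemma"/"google/gemma"); anything else raises the "Unsupported model type" error. An illustrative Ruby sketch (model ids are taken from the README; the unsupported id is just an example):

# Illustrative only: backend selection is driven by the model id string.
device = Candle::Device.cpu

Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device) # -> Mistral backend
Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device) # -> Llama backend
Candle::LLM.from_pretrained("google/gemma-2b-it", device: device)                 # -> Gemma backend (new in this release)

begin
  Candle::LLM.from_pretrained("bert-base-uncased", device: device)                # no matching substring
rescue RuntimeError => e
  puts e.message # "Unsupported model type: ... Currently Mistral, Llama, and Gemma models are supported."
end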
@@ -1,3 +1,3 @@
 module Candle
-  VERSION = "1.0.0.pre.5"
+  VERSION = "1.0.0.pre.6"
 end
metadata CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.0.0.pre.5
+  version: 1.0.0.pre.6
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-07-09 00:00:00.000000000 Z
+date: 2025-07-10 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -47,6 +47,7 @@ files:
 - ext/candle/extconf.rb
 - ext/candle/rustfmt.toml
 - ext/candle/src/lib.rs
+- ext/candle/src/llm/gemma.rs
 - ext/candle/src/llm/generation_config.rs
 - ext/candle/src/llm/llama.rs
 - ext/candle/src/llm/mistral.rs