red-candle 1.2.3 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/README.md CHANGED
@@ -1,4 +1,4 @@
- <img src="/docs/assets/logo-title.png" alt="red-candle" height="80px">
+ <img src="/docs/assets/logo-title.png" alt="red-candle" height="160px">

  [![build](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml/badge.svg)](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml)
  [![Gem Version](https://badge.fury.io/rb/red-candle.svg)](https://badge.fury.io/rb/red-candle)
@@ -12,8 +12,8 @@ crate-type = ["cdylib"]
  candle-core = { version = "0.9.1" }
  candle-nn = { version = "0.9.1" }
  candle-transformers = { version = "0.9.1" }
- tokenizers = { version = "0.21.1", default-features = true, features = ["fancy-regex"] }
- hf-hub = "0.4.3"
+ tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
+ hf-hub = "0.4.1"
  half = "2.6.0"
  magnus = "0.7.1"
  safetensors = "0.3"
@@ -21,7 +21,7 @@ serde_json = "1.0"
  serde = { version = "1.0", features = ["derive"] }
  tokio = { version = "1.45", features = ["rt", "macros"] }
  rand = "0.8"
- outlines-core = "0.2"
+ outlines-core = "0.2.11"

  [features]
  default = []
@@ -30,8 +30,8 @@ impl Gemma {
  &self.tokenizer
  }

- /// Load a Gemma model from HuggingFace Hub
- pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ /// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
  let api = Api::new()
  .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -43,10 +43,23 @@ impl Gemma {
  .await
  .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

- let tokenizer_filename = repo
- .get("tokenizer.json")
- .await
- .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ // Download tokenizer from custom source if provided, otherwise from model repo
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+ let tokenizer_filename = tokenizer_repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ } else {
+ let tokenizer_filename = repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ };

  // Try different file patterns for model weights
  let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -87,9 +100,6 @@ impl Gemma {
  let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

- // Load tokenizer
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;

  // Gemma uses specific tokens
  let eos_token_id = {
@@ -116,6 +126,11 @@ impl Gemma {
  })
  }

+ /// Load a Gemma model from HuggingFace Hub (backwards compatibility)
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
+ }
+
  /// Create from existing components (useful for testing)
  pub fn new(
  model: GemmaModel,
@@ -37,8 +37,8 @@ impl Llama {
  &self.tokenizer
  }

- /// Load a Llama model from HuggingFace Hub
- pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ /// Load a Llama model from HuggingFace Hub with optional custom tokenizer
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
  let api = Api::new()
  .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -50,10 +50,45 @@ impl Llama {
  .await
  .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

- let tokenizer_filename = repo
- .get("tokenizer.json")
- .await
- .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ // Download tokenizer from custom source if provided, otherwise from model repo
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+ let tokenizer_filename = tokenizer_repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| {
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+ format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+ format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+ format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+ } else {
+ format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+ };
+ candle_core::Error::Msg(error_msg)
+ })?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+ } else {
+ let tokenizer_filename = repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| {
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+ format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'appropriate-tokenizer-repo')", model_id, model_id)
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+ format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+ format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+ } else {
+ format!("Failed to download tokenizer: {}", e)
+ };
+ candle_core::Error::Msg(error_msg)
+ })?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+ };

  // Try different file patterns for model weights
  let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
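The Llama and Mistral loaders above classify Hub download failures by matching on the error text. A standalone, illustrative sketch of that classification (the helper name is hypothetical; the matched fragments are taken from the branches above):

```rust
// Illustrative sketch: bucket a Hugging Face Hub download error by its
// message text, mirroring the 404 / 401 / network branches in this release.
fn classify_hub_error(msg: &str) -> &'static str {
    if msg.contains("404") || msg.contains("Not Found") {
        "missing file: the repo has no tokenizer.json (it may ship tokenizer.model instead)"
    } else if msg.contains("401") || msg.contains("Unauthorized") {
        "authentication required: set the HF_TOKEN environment variable"
    } else if msg.contains("timed out") || msg.contains("connection") {
        "network error: check the internet connection"
    } else {
        "other failure"
    }
}

fn main() {
    assert_eq!(
        classify_hub_error("request failed: 401 Unauthorized"),
        "authentication required: set the HF_TOKEN environment variable"
    );
}
```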
@@ -97,10 +132,6 @@ impl Llama {
  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
  let config = llama_config.into_config(false); // Don't use flash attention for now

- // Load tokenizer
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
  // Determine EOS token ID based on model type
  let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
  // Llama 3 uses different special tokens
@@ -139,6 +170,11 @@ impl Llama {
  })
  }

+ /// Load a Llama model from HuggingFace Hub (backwards compatibility)
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
+ }
+
  /// Create from existing components (useful for testing)
  pub fn new(
  model: LlamaModel,
@@ -30,8 +30,8 @@ impl Mistral {
  &self.tokenizer
  }

- /// Load a Mistral model from HuggingFace Hub
- pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ /// Load a Mistral model from HuggingFace Hub with optional custom tokenizer
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
  let api = Api::new()
  .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -43,10 +43,45 @@ impl Mistral {
  .await
  .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

- let tokenizer_filename = repo
- .get("tokenizer.json")
- .await
- .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ // Download tokenizer from custom source if provided, otherwise from model repo
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+ let tokenizer_filename = tokenizer_repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| {
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+ format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+ format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+ format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+ } else {
+ format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+ };
+ candle_core::Error::Msg(error_msg)
+ })?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+ } else {
+ let tokenizer_filename = repo
+ .get("tokenizer.json")
+ .await
+ .map_err(|e| {
+ let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+ format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'mistralai/Mistral-7B-Instruct-v0.2')", model_id, model_id)
+ } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+ format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+ } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+ format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+ } else {
+ format!("Failed to download tokenizer: {}", e)
+ };
+ candle_core::Error::Msg(error_msg)
+ })?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+ };

  // Try different file patterns for model weights
  let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Mistral {
  let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

- // Load tokenizer
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
  let eos_token_id = tokenizer
  .get_vocab(true)
  .get("</s>")
@@ -123,6 +154,11 @@ impl Mistral {
  })
  }

+ /// Load a Mistral model from HuggingFace Hub (backwards compatibility)
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
+ }
+
  /// Create from existing components (useful for testing)
  pub fn new(
  model: MistralModel,
@@ -38,8 +38,8 @@ impl Phi {
  }
  }

- /// Load a Phi model from HuggingFace
- pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ /// Load a Phi model from HuggingFace with optional custom tokenizer
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
  let api = Api::new()
  .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -50,11 +50,19 @@ impl Phi {
  .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
  let config_str = std::fs::read_to_string(config_filename)?;

- // Download tokenizer
- let tokenizer_filename = repo.get("tokenizer.json").await
- .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+ // Download tokenizer from custom source if provided, otherwise from model repo
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+ let tokenizer_repo = api.model(tokenizer_id.to_string());
+ let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ } else {
+ let tokenizer_filename = repo.get("tokenizer.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ };

  // Determine EOS token
  let vocab = tokenizer.get_vocab(true);
@@ -104,7 +112,62 @@ impl Phi {

  let model = if is_phi3 {
  // Load Phi3 model
- let config: Phi3Config = serde_json::from_str(&config_str)
+ // Handle config differences between Phi-3-small and Phi-3-mini
+ let mut config_str_fixed;
+
+ // Parse config as JSON for modifications
+ let mut config_json: serde_json::Value = serde_json::from_str(&config_str)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
+
+ // Phi-3-small uses ff_intermediate_size instead of intermediate_size
+ if config_json.get("ff_intermediate_size").is_some() && config_json.get("intermediate_size").is_none() {
+ if let Some(ff_size) = config_json.get("ff_intermediate_size").cloned() {
+ config_json["intermediate_size"] = ff_size;
+ }
+ }
+
+ // Phi-3-small uses layer_norm_epsilon instead of rms_norm_eps
+ if config_json.get("layer_norm_epsilon").is_some() && config_json.get("rms_norm_eps").is_none() {
+ if let Some(eps) = config_json.get("layer_norm_epsilon").cloned() {
+ config_json["rms_norm_eps"] = eps;
+ }
+ }
+
+ // Handle rope_scaling for long context models (Phi-3-mini-128k)
+ // Candle expects rope_scaling to be a string, but newer configs have it as an object
+ if let Some(rope_scaling) = config_json.get("rope_scaling") {
+ if rope_scaling.is_object() {
+ // For now, just convert to the type string - candle will use default scaling
+ if let Some(scaling_type) = rope_scaling.get("type").and_then(|v| v.as_str()) {
+ config_json["rope_scaling"] = serde_json::Value::String(scaling_type.to_string());
+ } else {
+ // Remove it if we can't determine the type
+ config_json.as_object_mut().unwrap().remove("rope_scaling");
+ }
+ }
+ }
+
+ // Phi-3-small uses rope_embedding_base instead of rope_theta
+ if config_json.get("rope_embedding_base").is_some() && config_json.get("rope_theta").is_none() {
+ if let Some(rope_base) = config_json.get("rope_embedding_base").cloned() {
+ config_json["rope_theta"] = rope_base;
+ }
+ }
+
+ config_str_fixed = serde_json::to_string(&config_json)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to serialize config: {}", e)))?;
+
+ // Check for unsupported gegelu activation
+ if config_str_fixed.contains("\"gegelu\"") {
+ // For now, map gegelu to gelu_pytorch_tanh with a warning
+ // This is not ideal but allows the model to at least load
+ eprintln!("WARNING: This model uses 'gegelu' activation which is not fully supported.");
+ eprintln!("         Mapping to 'gelu_pytorch_tanh' - results may be degraded.");
+ eprintln!("         For best results, use Phi-3-mini models instead.");
+ config_str_fixed = config_str_fixed.replace("\"gegelu\"", "\"gelu_pytorch_tanh\"");
+ }
+
+ let config: Phi3Config = serde_json::from_str(&config_str_fixed)
  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;

  let vb = unsafe {
@@ -134,6 +197,11 @@ impl Phi {
  eos_token_id,
  })
  }
+
+ /// Load a Phi model from HuggingFace (backwards compatibility)
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
+ }

  /// Apply Phi chat template to messages
  pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
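The Phi-3 config handling above renames several Phi-3-small keys to the names Candle's Phi3Config expects. A compact, illustrative restatement of those aliases using serde_json (the helper name and sample values are made up):

```rust
// Illustrative only: apply the Phi-3-small -> Phi3Config key aliases
// described above to a made-up config fragment.
use serde_json::{json, Value};

fn normalize_phi3_config(mut cfg: Value) -> Value {
    for (from, to) in [
        ("ff_intermediate_size", "intermediate_size"),
        ("layer_norm_epsilon", "rms_norm_eps"),
        ("rope_embedding_base", "rope_theta"),
    ] {
        if cfg.get(from).is_some() && cfg.get(to).is_none() {
            let v = cfg[from].clone();
            cfg[to] = v;
        }
    }
    cfg
}

fn main() {
    let cfg = json!({ "ff_intermediate_size": 14336, "layer_norm_epsilon": 1e-5 });
    let fixed = normalize_phi3_config(cfg);
    assert_eq!(fixed["intermediate_size"], json!(14336));
    assert_eq!(fixed["rms_norm_eps"], json!(1e-5));
}
```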
@@ -30,8 +30,8 @@ impl Qwen {
  self.model.clear_kv_cache();
  }

- /// Load a Qwen model from HuggingFace
- pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ /// Load a Qwen model from HuggingFace with optional custom tokenizer
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
  let api = Api::new()
  .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -44,19 +44,27 @@ impl Qwen {
  let config: Config = serde_json::from_str(&config_str)
  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

- // Download tokenizer
- let tokenizer_filename = repo.get("tokenizer.json").await
- .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+ // Download tokenizer from custom source if provided, otherwise from model repo
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+ let tokenizer_repo = api.model(tokenizer_id.to_string());
+ let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ } else {
+ let tokenizer_filename = repo.get("tokenizer.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+ };

  // Determine EOS token
  let vocab = tokenizer.get_vocab(true);
- let eos_token_id = vocab.get("<|endoftext|>")
- .or_else(|| vocab.get("<|im_end|>"))
+ let eos_token_id = vocab.get("<|im_end|>")
+ .or_else(|| vocab.get("<|endoftext|>"))
  .or_else(|| vocab.get("</s>"))
  .copied()
- .unwrap_or(151643); // Default Qwen3 EOS token
+ .unwrap_or(151645); // Default Qwen2.5 EOS token

  // Download model weights
  // NOTE: Qwen uses hardcoded shard counts based on model size rather than
@@ -97,6 +105,11 @@ impl Qwen {
  })
  }

+ /// Load a Qwen model from HuggingFace (backwards compatibility)
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
+ }
+
  /// Apply Qwen chat template to messages
  pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
  let mut prompt = String::new();
@@ -257,14 +257,15 @@ impl LLM {
  let model_lower = model_id.to_lowercase();
  let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");

+ // Extract tokenizer source if provided in model_id (for both GGUF and regular models)
+ let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+ let (id, _tok) = model_id.split_at(pos);
+ (id.to_string(), Some(&model_id[pos+2..]))
+ } else {
+ (model_id.clone(), None)
+ };
+
  let model = if is_quantized {
- // Extract tokenizer source if provided in model_id
- let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
- let (id, _tok) = model_id.split_at(pos);
- (id.to_string(), Some(&model_id[pos+2..]))
- } else {
- (model_id.clone(), None)
- };

  // Use unified GGUF loader for all quantized models
  let gguf_model = rt.block_on(async {
@@ -273,41 +274,73 @@ impl LLM {
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
  ModelType::QuantizedGGUF(gguf_model)
  } else {
- // Load non-quantized models
- if model_lower.contains("mistral") {
- let mistral = rt.block_on(async {
- RustMistral::from_pretrained(&model_id, candle_device).await
- })
+ // Load non-quantized models based on type
+ let model_lower_clean = model_id_clean.to_lowercase();
+
+ if model_lower_clean.contains("mistral") {
+ let mistral = if tokenizer_source.is_some() {
+ rt.block_on(async {
+ RustMistral::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+ })
+ } else {
+ rt.block_on(async {
+ RustMistral::from_pretrained(&model_id_clean, candle_device).await
+ })
+ }
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
  ModelType::Mistral(mistral)
- } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
- let llama = rt.block_on(async {
- RustLlama::from_pretrained(&model_id, candle_device).await
- })
+ } else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") {
+ let llama = if tokenizer_source.is_some() {
+ rt.block_on(async {
+ RustLlama::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+ })
+ } else {
+ rt.block_on(async {
+ RustLlama::from_pretrained(&model_id_clean, candle_device).await
+ })
+ }
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
  ModelType::Llama(llama)
- } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
- let gemma = rt.block_on(async {
- RustGemma::from_pretrained(&model_id, candle_device).await
- })
+ } else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
+ let gemma = if tokenizer_source.is_some() {
+ rt.block_on(async {
+ RustGemma::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+ })
+ } else {
+ rt.block_on(async {
+ RustGemma::from_pretrained(&model_id_clean, candle_device).await
+ })
+ }
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
  ModelType::Gemma(gemma)
- } else if model_lower.contains("qwen") {
- let qwen = rt.block_on(async {
- RustQwen::from_pretrained(&model_id, candle_device).await
- })
+ } else if model_lower_clean.contains("qwen") {
+ let qwen = if tokenizer_source.is_some() {
+ rt.block_on(async {
+ RustQwen::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+ })
+ } else {
+ rt.block_on(async {
+ RustQwen::from_pretrained(&model_id_clean, candle_device).await
+ })
+ }
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
  ModelType::Qwen(qwen)
- } else if model_lower.contains("phi") {
- let phi = rt.block_on(async {
- RustPhi::from_pretrained(&model_id, candle_device).await
- })
+ } else if model_lower_clean.contains("phi") {
+ let phi = if tokenizer_source.is_some() {
+ rt.block_on(async {
+ RustPhi::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+ })
+ } else {
+ rt.block_on(async {
+ RustPhi::from_pretrained(&model_id_clean, candle_device).await
+ })
+ }
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
  ModelType::Phi(phi)
  } else {
  return Err(Error::new(
  magnus::exception::runtime_error(),
- format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id),
+ format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id_clean),
  ));
  }
  };
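With this change the `model@@tokenizer` convention is parsed once, before the quantized/non-quantized branch, so a custom tokenizer repo now applies to both GGUF and full-precision models. A minimal sketch of the split, with an illustrative helper name and example repo IDs (not taken from the diff):

```rust
// Minimal sketch of the "model@@tokenizer" split performed above.
// Everything before "@@" is the model repo; everything after is the tokenizer repo.
fn split_model_and_tokenizer(model_id: &str) -> (String, Option<&str>) {
    if let Some(pos) = model_id.find("@@") {
        (model_id[..pos].to_string(), Some(&model_id[pos + 2..]))
    } else {
        (model_id.to_string(), None)
    }
}

fn main() {
    let (model, tok) = split_model_and_tokenizer(
        "TheBloke/Mistral-7B-Instruct-v0.2-GGUF@@mistralai/Mistral-7B-Instruct-v0.2",
    );
    assert_eq!(model, "TheBloke/Mistral-7B-Instruct-v0.2-GGUF");
    assert_eq!(tok, Some("mistralai/Mistral-7B-Instruct-v0.2"));
    assert_eq!(split_model_and_tokenizer("google/gemma-2b").1, None);
}
```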
@@ -1,5 +1,5 @@
  # :nocov:
  module Candle
-  VERSION = "1.2.3"
+  VERSION = "1.3.0"
  end
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: red-candle
  version: !ruby/object:Gem::Version
-   version: 1.2.3
+   version: 1.3.0
  platform: ruby
  authors:
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2025-09-07 00:00:00.000000000 Z
+ date: 2025-09-13 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rb_sys
@@ -151,7 +151,9 @@ dependencies:
    - - "~>"
      - !ruby/object:Gem::Version
        version: '3.13'
- description: huggingface/candle for Ruby
+ description: Ruby gem for running state-of-the-art language models locally. Access
+   LLMs, embeddings, rerankers, and NER models directly from Ruby using Rust-powered
+   Candle with Metal/CUDA acceleration.
  email:
  - chris@petersen.io
  - 2xijok@gmail.com
@@ -204,12 +206,6 @@ files:
  - ext/candle/src/structured/vocabulary_adapter_simple_test.rs
  - ext/candle/src/tokenizer/loader.rs
  - ext/candle/src/tokenizer/mod.rs
- - ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt
- - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs
- - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs
- - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs
- - ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs
- - ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs
  - ext/candle/tests/device_tests.rs
  - ext/candle/tests/tensor_tests.rs
  - lib/candle.rb
@@ -237,16 +233,18 @@ required_ruby_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
-     version: 3.2.0
+     version: 3.1.0
  required_rubygems_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
-     version: 3.3.26
+     version: '3.3'
  requirements:
  - Rust >= 1.85
- rubygems_version: 3.5.3
+ rubygems_version: 3.3.3
  signing_key:
  specification_version: 4
- summary: huggingface/candle for Ruby
+ summary: Ruby gem for running state-of-the-art language models locally. Access LLMs,
+   embeddings, rerankers, and NER models directly from Ruby using Rust-powered Candle
+   with Metal/CUDA acceleration.
  test_files: []