red-candle 1.2.3 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +460 -379
- data/README.md +1 -1
- data/ext/candle/Cargo.toml +3 -3
- data/ext/candle/src/llm/gemma.rs +24 -9
- data/ext/candle/src/llm/llama.rs +46 -10
- data/ext/candle/src/llm/mistral.rs +46 -10
- data/ext/candle/src/llm/phi.rs +76 -8
- data/ext/candle/src/llm/qwen.rs +23 -10
- data/ext/candle/src/ruby/llm.rs +62 -29
- data/lib/candle/version.rb +1 -1
- metadata +11 -13
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +0 -1
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +0 -355
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +0 -276
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +0 -49
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +0 -2748
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +0 -8902
data/README.md
CHANGED
@@ -1,4 +1,4 @@
-<img src="/docs/assets/logo-title.png" alt="red-candle" height="
+<img src="/docs/assets/logo-title.png" alt="red-candle" height="160px">
 
 [](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml)
 [](https://badge.fury.io/rb/red-candle)
data/ext/candle/Cargo.toml
CHANGED
@@ -12,8 +12,8 @@ crate-type = ["cdylib"]
 candle-core = { version = "0.9.1" }
 candle-nn = { version = "0.9.1" }
 candle-transformers = { version = "0.9.1" }
-tokenizers = { version = "0.
-hf-hub = "0.4.
+tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
+hf-hub = "0.4.1"
 half = "2.6.0"
 magnus = "0.7.1"
 safetensors = "0.3"
@@ -21,7 +21,7 @@ serde_json = "1.0"
 serde = { version = "1.0", features = ["derive"] }
 tokio = { version = "1.45", features = ["rt", "macros"] }
 rand = "0.8"
-outlines-core = "0.2"
+outlines-core = "0.2.11"
 
 [features]
 default = []
data/ext/candle/src/llm/gemma.rs
CHANGED
@@ -30,8 +30,8 @@ impl Gemma {
         &self.tokenizer
     }
 
-    /// Load a Gemma model from HuggingFace Hub
-    pub async fn
+    /// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -43,10 +43,23 @@ impl Gemma {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -87,9 +100,6 @@ impl Gemma {
         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
 
         // Gemma uses specific tokens
         let eos_token_id = {
@@ -116,6 +126,11 @@ impl Gemma {
         })
     }
 
+    /// Load a Gemma model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: GemmaModel,
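The same pattern is applied to every model struct in this release: `from_pretrained_with_tokenizer` accepts an optional tokenizer repo, and `from_pretrained` is kept as a thin wrapper that passes `None`, so existing callers are unaffected. A minimal sketch of the two call shapes from inside the extension crate (module paths, the `CandleResult` alias, and the repo ids are illustrative assumptions, not part of the diff):

// Sketch only; paths and repo ids are illustrative assumptions.
use candle_core::Device;

async fn load_gemma_variants() -> CandleResult<()> {
    // Default: tokenizer.json is fetched from the model repo itself.
    let _default = Gemma::from_pretrained("google/gemma-2b-it", Device::Cpu).await?;

    // New in 1.3.0: fetch tokenizer.json from a separate repo instead.
    let _custom = Gemma::from_pretrained_with_tokenizer(
        "google/gemma-2b-it",
        Device::Cpu,
        Some("google/gemma-2b"), // hypothetical tokenizer repo
    )
    .await?;

    Ok(())
}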
data/ext/candle/src/llm/llama.rs
CHANGED
@@ -37,8 +37,8 @@ impl Llama {
         &self.tokenizer
     }
 
-    /// Load a Llama model from HuggingFace Hub
-    pub async fn
+    /// Load a Llama model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -50,10 +50,45 @@ impl Llama {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                    } else {
+                        format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'appropriate-tokenizer-repo')", model_id, model_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                    } else {
+                        format!("Failed to download tokenizer: {}", e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Llama {
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
         let config = llama_config.into_config(false); // Don't use flash attention for now
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
         // Determine EOS token ID based on model type
         let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
             // Llama 3 uses different special tokens
@@ -139,6 +170,11 @@ impl Llama {
         })
     }
 
+    /// Load a Llama model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: LlamaModel,
data/ext/candle/src/llm/mistral.rs
CHANGED
@@ -30,8 +30,8 @@ impl Mistral {
         &self.tokenizer
     }
 
-    /// Load a Mistral model from HuggingFace Hub
-    pub async fn
+    /// Load a Mistral model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -43,10 +43,45 @@ impl Mistral {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                    } else {
+                        format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'mistralai/Mistral-7B-Instruct-v0.2')", model_id, model_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                    } else {
+                        format!("Failed to download tokenizer: {}", e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Mistral {
         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
         let eos_token_id = tokenizer
             .get_vocab(true)
             .get("</s>")
@@ -123,6 +154,11 @@ impl Mistral {
         })
     }
 
+    /// Load a Mistral model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: MistralModel,
data/ext/candle/src/llm/phi.rs
CHANGED
@@ -38,8 +38,8 @@ impl Phi {
         }
     }
 
-    /// Load a Phi model from HuggingFace
-    pub async fn
+    /// Load a Phi model from HuggingFace with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -50,11 +50,19 @@ impl Phi {
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
         let config_str = std::fs::read_to_string(config_filename)?;
 
-        // Download tokenizer
-        let
-            .
-
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.model(tokenizer_id.to_string());
+            let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Determine EOS token
         let vocab = tokenizer.get_vocab(true);
@@ -104,7 +112,62 @@ impl Phi {
 
         let model = if is_phi3 {
             // Load Phi3 model
-
+            // Handle config differences between Phi-3-small and Phi-3-mini
+            let mut config_str_fixed;
+
+            // Parse config as JSON for modifications
+            let mut config_json: serde_json::Value = serde_json::from_str(&config_str)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
+
+            // Phi-3-small uses ff_intermediate_size instead of intermediate_size
+            if config_json.get("ff_intermediate_size").is_some() && config_json.get("intermediate_size").is_none() {
+                if let Some(ff_size) = config_json.get("ff_intermediate_size").cloned() {
+                    config_json["intermediate_size"] = ff_size;
+                }
+            }
+
+            // Phi-3-small uses layer_norm_epsilon instead of rms_norm_eps
+            if config_json.get("layer_norm_epsilon").is_some() && config_json.get("rms_norm_eps").is_none() {
+                if let Some(eps) = config_json.get("layer_norm_epsilon").cloned() {
+                    config_json["rms_norm_eps"] = eps;
+                }
+            }
+
+            // Handle rope_scaling for long context models (Phi-3-mini-128k)
+            // Candle expects rope_scaling to be a string, but newer configs have it as an object
+            if let Some(rope_scaling) = config_json.get("rope_scaling") {
+                if rope_scaling.is_object() {
+                    // For now, just convert to the type string - candle will use default scaling
+                    if let Some(scaling_type) = rope_scaling.get("type").and_then(|v| v.as_str()) {
+                        config_json["rope_scaling"] = serde_json::Value::String(scaling_type.to_string());
+                    } else {
+                        // Remove it if we can't determine the type
+                        config_json.as_object_mut().unwrap().remove("rope_scaling");
+                    }
+                }
+            }
+
+            // Phi-3-small uses rope_embedding_base instead of rope_theta
+            if config_json.get("rope_embedding_base").is_some() && config_json.get("rope_theta").is_none() {
+                if let Some(rope_base) = config_json.get("rope_embedding_base").cloned() {
+                    config_json["rope_theta"] = rope_base;
+                }
+            }
+
+            config_str_fixed = serde_json::to_string(&config_json)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to serialize config: {}", e)))?;
+
+            // Check for unsupported gegelu activation
+            if config_str_fixed.contains("\"gegelu\"") {
+                // For now, map gegelu to gelu_pytorch_tanh with a warning
+                // This is not ideal but allows the model to at least load
+                eprintln!("WARNING: This model uses 'gegelu' activation which is not fully supported.");
+                eprintln!("         Mapping to 'gelu_pytorch_tanh' - results may be degraded.");
+                eprintln!("         For best results, use Phi-3-mini models instead.");
+                config_str_fixed = config_str_fixed.replace("\"gegelu\"", "\"gelu_pytorch_tanh\"");
+            }
+
+            let config: Phi3Config = serde_json::from_str(&config_str_fixed)
                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
 
             let vb = unsafe {
@@ -134,6 +197,11 @@ impl Phi {
             eos_token_id,
         })
     }
+
+    /// Load a Phi model from HuggingFace (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
 
     /// Apply Phi chat template to messages
     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
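The new Phi loading path rewrites Phi-3-small style config keys into the names candle's `Phi3Config` deserializer expects before parsing. The renames themselves need nothing beyond serde_json; a standalone sketch with made-up values (not taken from any real model config):

use serde_json::{json, Value};

fn main() {
    // Toy Phi-3-small style config; values are illustrative only.
    let mut cfg: Value = json!({
        "ff_intermediate_size": 8192,
        "layer_norm_epsilon": 1e-5,
        "rope_embedding_base": 10000.0,
        "rope_scaling": { "type": "su" }
    });

    // ff_intermediate_size -> intermediate_size
    if cfg.get("intermediate_size").is_none() {
        if let Some(v) = cfg.get("ff_intermediate_size").cloned() {
            cfg["intermediate_size"] = v;
        }
    }
    // layer_norm_epsilon -> rms_norm_eps
    if cfg.get("rms_norm_eps").is_none() {
        if let Some(v) = cfg.get("layer_norm_epsilon").cloned() {
            cfg["rms_norm_eps"] = v;
        }
    }
    // rope_scaling object -> plain type string (only the "type" field is kept)
    if let Some(t) = cfg.get("rope_scaling").and_then(|v| v.get("type")).and_then(|v| v.as_str()).map(|s| s.to_string()) {
        cfg["rope_scaling"] = Value::String(t);
    }
    // rope_embedding_base -> rope_theta
    if cfg.get("rope_theta").is_none() {
        if let Some(v) = cfg.get("rope_embedding_base").cloned() {
            cfg["rope_theta"] = v;
        }
    }

    println!("{}", serde_json::to_string_pretty(&cfg).unwrap());
}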
data/ext/candle/src/llm/qwen.rs
CHANGED
@@ -30,8 +30,8 @@ impl Qwen {
         self.model.clear_kv_cache();
     }
 
-    /// Load a Qwen model from HuggingFace
-    pub async fn
+    /// Load a Qwen model from HuggingFace with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -44,19 +44,27 @@ impl Qwen {
         let config: Config = serde_json::from_str(&config_str)
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Download tokenizer
-        let
-            .
-
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.model(tokenizer_id.to_string());
+            let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Determine EOS token
         let vocab = tokenizer.get_vocab(true);
-        let eos_token_id = vocab.get("<|
-            .or_else(|| vocab.get("<|
+        let eos_token_id = vocab.get("<|im_end|>")
+            .or_else(|| vocab.get("<|endoftext|>"))
             .or_else(|| vocab.get("</s>"))
             .copied()
-            .unwrap_or(
+            .unwrap_or(151645); // Default Qwen2.5 EOS token
 
         // Download model weights
         // NOTE: Qwen uses hardcoded shard counts based on model size rather than
@@ -97,6 +105,11 @@ impl Qwen {
         })
     }
 
+    /// Load a Qwen model from HuggingFace (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Apply Qwen chat template to messages
     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
         let mut prompt = String::new();
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -257,14 +257,15 @@ impl LLM {
         let model_lower = model_id.to_lowercase();
         let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
 
+        // Extract tokenizer source if provided in model_id (for both GGUF and regular models)
+        let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+            let (id, _tok) = model_id.split_at(pos);
+            (id.to_string(), Some(&model_id[pos+2..]))
+        } else {
+            (model_id.clone(), None)
+        };
+
         let model = if is_quantized {
-            // Extract tokenizer source if provided in model_id
-            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
-                let (id, _tok) = model_id.split_at(pos);
-                (id.to_string(), Some(&model_id[pos+2..]))
-            } else {
-                (model_id.clone(), None)
-            };
 
             // Use unified GGUF loader for all quantized models
             let gguf_model = rt.block_on(async {
@@ -273,41 +274,73 @@ impl LLM {
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
             ModelType::QuantizedGGUF(gguf_model)
         } else {
-            // Load non-quantized models
-
-
-
-
+            // Load non-quantized models based on type
+            let model_lower_clean = model_id_clean.to_lowercase();
+
+            if model_lower_clean.contains("mistral") {
+                let mistral = if tokenizer_source.is_some() {
+                    rt.block_on(async {
+                        RustMistral::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+                    })
+                } else {
+                    rt.block_on(async {
+                        RustMistral::from_pretrained(&model_id_clean, candle_device).await
+                    })
+                }
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
                 ModelType::Mistral(mistral)
-            } else if
-                let llama =
-
-
+            } else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") {
+                let llama = if tokenizer_source.is_some() {
+                    rt.block_on(async {
+                        RustLlama::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+                    })
+                } else {
+                    rt.block_on(async {
+                        RustLlama::from_pretrained(&model_id_clean, candle_device).await
+                    })
+                }
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
                 ModelType::Llama(llama)
-            } else if
-                let gemma =
-
-
+            } else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
+                let gemma = if tokenizer_source.is_some() {
+                    rt.block_on(async {
+                        RustGemma::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+                    })
+                } else {
+                    rt.block_on(async {
+                        RustGemma::from_pretrained(&model_id_clean, candle_device).await
+                    })
+                }
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
                 ModelType::Gemma(gemma)
-            } else if
-                let qwen =
-
-
+            } else if model_lower_clean.contains("qwen") {
+                let qwen = if tokenizer_source.is_some() {
+                    rt.block_on(async {
+                        RustQwen::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+                    })
+                } else {
+                    rt.block_on(async {
+                        RustQwen::from_pretrained(&model_id_clean, candle_device).await
+                    })
+                }
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
                 ModelType::Qwen(qwen)
-            } else if
-                let phi =
-
-
+            } else if model_lower_clean.contains("phi") {
+                let phi = if tokenizer_source.is_some() {
+                    rt.block_on(async {
+                        RustPhi::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
+                    })
+                } else {
+                    rt.block_on(async {
+                        RustPhi::from_pretrained(&model_id_clean, candle_device).await
+                    })
+                }
                 .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
                 ModelType::Phi(phi)
             } else {
                 return Err(Error::new(
                     magnus::exception::runtime_error(),
-                    format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.",
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id_clean),
                 ));
             }
         };
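Hoisting the `@@` split out of the quantized-only branch is what makes the custom tokenizer override work for safetensors models as well: whatever follows `@@` in the model id is treated as the tokenizer repo for every supported architecture. A standalone sketch of that split (the helper name is made up; the repo ids are only examples):

// Sketch of the "model@@tokenizer" convention used above; not the gem's public API.
fn split_model_id(model_id: &str) -> (String, Option<&str>) {
    match model_id.find("@@") {
        Some(pos) => (model_id[..pos].to_string(), Some(&model_id[pos + 2..])),
        None => (model_id.to_string(), None),
    }
}

fn main() {
    let (id, tok) = split_model_id("mistralai/Mistral-7B-v0.1@@mistralai/Mistral-7B-Instruct-v0.2");
    assert_eq!(id, "mistralai/Mistral-7B-v0.1");
    assert_eq!(tok, Some("mistralai/Mistral-7B-Instruct-v0.2"));

    let (id, tok) = split_model_id("google/gemma-2b-it");
    assert_eq!(id, "google/gemma-2b-it");
    assert_eq!(tok, None);
}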
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.
+  version: 1.3.0
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-09-
+date: 2025-09-13 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -151,7 +151,9 @@ dependencies:
     - - "~>"
       - !ruby/object:Gem::Version
         version: '3.13'
-description:
+description: Ruby gem for running state-of-the-art language models locally. Access
+  LLMs, embeddings, rerankers, and NER models directly from Ruby using Rust-powered
+  Candle with Metal/CUDA acceleration.
 email:
 - chris@petersen.io
 - 2xijok@gmail.com
@@ -204,12 +206,6 @@ files:
 - ext/candle/src/structured/vocabulary_adapter_simple_test.rs
 - ext/candle/src/tokenizer/loader.rs
 - ext/candle/src/tokenizer/mod.rs
-- ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt
-- ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs
-- ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs
-- ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs
-- ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs
-- ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs
 - ext/candle/tests/device_tests.rs
 - ext/candle/tests/tensor_tests.rs
 - lib/candle.rb
@@ -237,16 +233,18 @@ required_ruby_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="
     - !ruby/object:Gem::Version
-      version: 3.
+      version: 3.1.0
 required_rubygems_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="
     - !ruby/object:Gem::Version
-      version: 3.3
+      version: '3.3'
 requirements:
 - Rust >= 1.85
-rubygems_version: 3.
+rubygems_version: 3.3.3
 signing_key:
 specification_version: 4
-summary:
+summary: Ruby gem for running state-of-the-art language models locally. Access LLMs,
+  embeddings, rerankers, and NER models directly from Ruby using Rust-powered Candle
+  with Metal/CUDA acceleration.
 test_files: []
data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt
REMOVED
@@ -1 +0,0 @@
-aarch64-apple-darwin
|