red-candle 1.2.3 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +460 -379
- data/README.md +1 -1
- data/ext/candle/Cargo.toml +3 -3
- data/ext/candle/src/llm/constrained_generation_test.rs +79 -0
- data/ext/candle/src/llm/gemma.rs +24 -9
- data/ext/candle/src/llm/llama.rs +46 -10
- data/ext/candle/src/llm/mistral.rs +46 -10
- data/ext/candle/src/llm/phi.rs +76 -8
- data/ext/candle/src/llm/qwen.rs +23 -10
- data/ext/candle/src/llm/text_generation.rs +40 -50
- data/ext/candle/src/ruby/llm.rs +62 -29
- data/ext/candle/src/ruby/structured.rs +54 -10
- data/lib/candle/llm.rb +77 -3
- data/lib/candle/version.rb +1 -1
- metadata +11 -13
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +0 -1
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +0 -355
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +0 -276
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +0 -49
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +0 -2748
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +0 -8902
data/README.md
CHANGED

@@ -1,4 +1,4 @@
-<img src="/docs/assets/logo-title.png" alt="red-candle" height="
+<img src="/docs/assets/logo-title.png" alt="red-candle" height="160px">
 
 [](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml)
 [](https://badge.fury.io/rb/red-candle)
data/ext/candle/Cargo.toml
CHANGED

@@ -12,8 +12,8 @@ crate-type = ["cdylib"]
 candle-core = { version = "0.9.1" }
 candle-nn = { version = "0.9.1" }
 candle-transformers = { version = "0.9.1" }
-tokenizers = { version = "0.
-hf-hub = "0.4.
+tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
+hf-hub = "0.4.1"
 half = "2.6.0"
 magnus = "0.7.1"
 safetensors = "0.3"
@@ -21,7 +21,7 @@ serde_json = "1.0"
 serde = { version = "1.0", features = ["derive"] }
 tokio = { version = "1.45", features = ["rt", "macros"] }
 rand = "0.8"
-outlines-core = "0.2"
+outlines-core = "0.2.11"
 
 [features]
 default = []
data/ext/candle/src/llm/constrained_generation_test.rs
CHANGED

@@ -313,4 +313,83 @@ mod constrained_generation_tests {
         // Verify tokens are being tracked
         assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
     }
+
+    #[test]
+    fn test_constraint_satisfied_not_triggered_by_large_allowed_set() {
+        // This test verifies the fix for the bug where is_constraint_satisfied_stop_on_match
+        // would incorrectly return true when many tokens are allowed (e.g., inside a JSON string).
+        // The old buggy code had: if allowed.len() > 1000 { return true; }
+        // This caused early termination when inside strings with many valid characters.
+
+        let config = GenerationConfig::default();
+        let mut text_gen = TextGeneration::new(&config);
+        text_gen.set_eos_token_id(50256);
+
+        // Without a constraint, should not be satisfied
+        assert!(!text_gen.is_constraint_satisfied(),
+            "Without constraint, should not be satisfied");
+        assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+            "Without constraint, stop_on_match should not be satisfied");
+    }
+
+    #[test]
+    fn test_constraint_satisfied_only_when_empty_or_eos_only() {
+        // Test that constraint satisfaction only triggers when:
+        // 1. No tokens are allowed (empty set)
+        // 2. Only EOS token is allowed
+        // NOT when many tokens are allowed (like inside a JSON string)
+
+        let config = GenerationConfig::default();
+        let mut text_gen = TextGeneration::new(&config);
+        text_gen.set_eos_token_id(100); // Set EOS token
+
+        // Without constraint, should not be satisfied
+        assert!(!text_gen.is_constraint_satisfied());
+        assert!(!text_gen.is_constraint_satisfied_stop_on_match());
+
+        // The key insight: constraint satisfaction should NOT be triggered
+        // just because there are many allowed tokens. It should only trigger
+        // when the constraint is definitively complete (empty allowed set or only EOS).
+    }
+
+    #[tokio::test]
+    async fn test_constraint_with_json_schema_not_early_termination() {
+        // Integration test: Create a real JSON schema constraint and verify
+        // that being inside a string (many allowed tokens) doesn't trigger completion.
+
+        if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
+            let wrapper = TokenizerWrapper::new(tokenizer);
+            let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
+                .expect("Should create vocabulary");
+
+            let processor = SchemaProcessor::new();
+
+            // Schema with a string field - when generating content inside the string,
+            // many characters are valid, but the constraint is NOT complete
+            let schema = r#"{
+                "type": "object",
+                "properties": {
+                    "name": { "type": "string" }
+                },
+                "required": ["name"]
+            }"#;
+
+            let index = processor.process_schema(schema, &vocabulary)
+                .expect("Should process schema");
+
+            let mut config = GenerationConfig::default();
+            config.constraint = Some(index);
+            config.max_length = 100;
+
+            let mut text_gen = TextGeneration::new(&config);
+            text_gen.set_eos_token_id(102); // BERT's [SEP]
+
+            // At the initial state, the constraint should NOT be satisfied
+            // (we haven't generated a complete JSON object yet)
+            assert!(!text_gen.is_constraint_satisfied(),
+                "Initial state should not be satisfied - JSON not yet generated");
+            assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+                "Initial state should not trigger stop_on_match");
+        }
+    }
 }
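These new tests pin down the corrected completion rule for constrained generation: a constraint counts as satisfied only when the allowed-token set is empty, or when the only token still allowed is EOS; it must never be reported as satisfied merely because many tokens are currently allowed (as happens inside a JSON string). A minimal sketch of that rule, using illustrative names (constraint_complete, allowed_tokens, eos_token_id) rather than the crate's actual API:

    // Sketch of the completion rule described by the tests above (names are illustrative).
    fn constraint_complete(allowed_tokens: &[u32], eos_token_id: Option<u32>) -> bool {
        match (allowed_tokens, eos_token_id) {
            ([], _) => true,                             // nothing can be generated: done
            ([only], Some(eos)) if *only == eos => true, // only EOS remains: done
            _ => false,                                  // many tokens allowed (e.g. inside a string): keep going
        }
    }

    fn main() {
        assert!(constraint_complete(&[], None));
        assert!(constraint_complete(&[102], Some(102)));
        assert!(!constraint_complete(&[10, 11, 12], Some(102)));
    }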
data/ext/candle/src/llm/gemma.rs
CHANGED

@@ -30,8 +30,8 @@ impl Gemma {
         &self.tokenizer
     }
 
-    /// Load a Gemma model from HuggingFace Hub
-    pub async fn
+    /// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -43,10 +43,23 @@ impl Gemma {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -87,9 +100,6 @@ impl Gemma {
         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
 
         // Gemma uses specific tokens
         let eos_token_id = {
@@ -116,6 +126,11 @@ impl Gemma {
         })
     }
 
+    /// Load a Gemma model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: GemmaModel,
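The loaders below (Llama, Mistral, Phi, Qwen) gain the same pair of entry points as Gemma above: a new from_pretrained_with_tokenizer that can fetch tokenizer.json from a separate repository, plus a from_pretrained wrapper that keeps the old call signature. A rough usage sketch, assuming the crate's Gemma/Llama types and the CandleResult alias shown in this diff; the repository IDs are placeholders, not real model references:

    use candle_core::Device;

    // Sketch only: Gemma, Llama, and CandleResult are the crate items shown in this diff.
    async fn load_models() -> CandleResult<()> {
        // Unchanged path: tokenizer.json is taken from the model repository itself.
        let _gemma = Gemma::from_pretrained("example-org/gemma-model", Device::Cpu).await?;

        // New in 1.3: fetch the tokenizer from a different repository, useful when the
        // model repo ships no tokenizer.json (both IDs here are placeholders).
        let _llama = Llama::from_pretrained_with_tokenizer(
            "example-org/llama-finetune",
            Device::Cpu,
            Some("example-org/llama-tokenizer"),
        )
        .await?;

        Ok(())
    }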
data/ext/candle/src/llm/llama.rs
CHANGED

@@ -37,8 +37,8 @@ impl Llama {
         &self.tokenizer
     }
 
-    /// Load a Llama model from HuggingFace Hub
-    pub async fn
+    /// Load a Llama model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -50,10 +50,45 @@ impl Llama {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                    } else {
+                        format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'appropriate-tokenizer-repo')", model_id, model_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                    } else {
+                        format!("Failed to download tokenizer: {}", e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Llama {
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
         let config = llama_config.into_config(false); // Don't use flash attention for now
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
         // Determine EOS token ID based on model type
         let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
             // Llama 3 uses different special tokens
@@ -139,6 +170,11 @@ impl Llama {
         })
     }
 
+    /// Load a Llama model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: LlamaModel,
data/ext/candle/src/llm/mistral.rs
CHANGED

@@ -30,8 +30,8 @@ impl Mistral {
         &self.tokenizer
     }
 
-    /// Load a Mistral model from HuggingFace Hub
-    pub async fn
+    /// Load a Mistral model from HuggingFace Hub with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -43,10 +43,45 @@ impl Mistral {
             .await
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
 
-
-
-            .
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+            let tokenizer_filename = tokenizer_repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                    } else {
+                        format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        } else {
+            let tokenizer_filename = repo
+                .get("tokenizer.json")
+                .await
+                .map_err(|e| {
+                    let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                        format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'mistralai/Mistral-7B-Instruct-v0.2')", model_id, model_id)
+                    } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                        format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                    } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                        format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                    } else {
+                        format!("Failed to download tokenizer: {}", e)
+                    };
+                    candle_core::Error::Msg(error_msg)
+                })?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+        };
 
         // Try different file patterns for model weights
         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Mistral {
         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Load tokenizer
-        let tokenizer = Tokenizer::from_file(tokenizer_filename)
-            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
         let eos_token_id = tokenizer
             .get_vocab(true)
             .get("</s>")
@@ -123,6 +154,11 @@ impl Mistral {
         })
     }
 
+    /// Load a Mistral model from HuggingFace Hub (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Create from existing components (useful for testing)
     pub fn new(
         model: MistralModel,
data/ext/candle/src/llm/phi.rs
CHANGED

@@ -38,8 +38,8 @@ impl Phi {
         }
     }
 
-    /// Load a Phi model from HuggingFace
-    pub async fn
+    /// Load a Phi model from HuggingFace with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -50,11 +50,19 @@ impl Phi {
             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
         let config_str = std::fs::read_to_string(config_filename)?;
 
-        // Download tokenizer
-        let
-            .
-
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.model(tokenizer_id.to_string());
+            let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Determine EOS token
         let vocab = tokenizer.get_vocab(true);
@@ -104,7 +112,62 @@ impl Phi {
 
         let model = if is_phi3 {
             // Load Phi3 model
-
+            // Handle config differences between Phi-3-small and Phi-3-mini
+            let mut config_str_fixed;
+
+            // Parse config as JSON for modifications
+            let mut config_json: serde_json::Value = serde_json::from_str(&config_str)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
+
+            // Phi-3-small uses ff_intermediate_size instead of intermediate_size
+            if config_json.get("ff_intermediate_size").is_some() && config_json.get("intermediate_size").is_none() {
+                if let Some(ff_size) = config_json.get("ff_intermediate_size").cloned() {
+                    config_json["intermediate_size"] = ff_size;
+                }
+            }
+
+            // Phi-3-small uses layer_norm_epsilon instead of rms_norm_eps
+            if config_json.get("layer_norm_epsilon").is_some() && config_json.get("rms_norm_eps").is_none() {
+                if let Some(eps) = config_json.get("layer_norm_epsilon").cloned() {
+                    config_json["rms_norm_eps"] = eps;
+                }
+            }
+
+            // Handle rope_scaling for long context models (Phi-3-mini-128k)
+            // Candle expects rope_scaling to be a string, but newer configs have it as an object
+            if let Some(rope_scaling) = config_json.get("rope_scaling") {
+                if rope_scaling.is_object() {
+                    // For now, just convert to the type string - candle will use default scaling
+                    if let Some(scaling_type) = rope_scaling.get("type").and_then(|v| v.as_str()) {
+                        config_json["rope_scaling"] = serde_json::Value::String(scaling_type.to_string());
+                    } else {
+                        // Remove it if we can't determine the type
+                        config_json.as_object_mut().unwrap().remove("rope_scaling");
+                    }
+                }
+            }
+
+            // Phi-3-small uses rope_embedding_base instead of rope_theta
+            if config_json.get("rope_embedding_base").is_some() && config_json.get("rope_theta").is_none() {
+                if let Some(rope_base) = config_json.get("rope_embedding_base").cloned() {
+                    config_json["rope_theta"] = rope_base;
+                }
+            }
+
+            config_str_fixed = serde_json::to_string(&config_json)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to serialize config: {}", e)))?;
+
+            // Check for unsupported gegelu activation
+            if config_str_fixed.contains("\"gegelu\"") {
+                // For now, map gegelu to gelu_pytorch_tanh with a warning
+                // This is not ideal but allows the model to at least load
+                eprintln!("WARNING: This model uses 'gegelu' activation which is not fully supported.");
+                eprintln!("         Mapping to 'gelu_pytorch_tanh' - results may be degraded.");
+                eprintln!("         For best results, use Phi-3-mini models instead.");
+                config_str_fixed = config_str_fixed.replace("\"gegelu\"", "\"gelu_pytorch_tanh\"");
+            }
+
+            let config: Phi3Config = serde_json::from_str(&config_str_fixed)
                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
 
             let vb = unsafe {
@@ -134,6 +197,11 @@ impl Phi {
             eos_token_id,
         })
     }
+
+    /// Load a Phi model from HuggingFace (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
 
     /// Apply Phi chat template to messages
     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
data/ext/candle/src/llm/qwen.rs
CHANGED

@@ -30,8 +30,8 @@ impl Qwen {
         self.model.clear_kv_cache();
     }
 
-    /// Load a Qwen model from HuggingFace
-    pub async fn
+    /// Load a Qwen model from HuggingFace with optional custom tokenizer
+    pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         let api = Api::new()
             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
 
@@ -44,19 +44,27 @@ impl Qwen {
         let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
 
-        // Download tokenizer
-        let
-            .
-
-
+        // Download tokenizer from custom source if provided, otherwise from model repo
+        let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+            let tokenizer_repo = api.model(tokenizer_id.to_string());
+            let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        } else {
+            let tokenizer_filename = repo.get("tokenizer.json").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+            Tokenizer::from_file(tokenizer_filename)
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+        };
 
         // Determine EOS token
         let vocab = tokenizer.get_vocab(true);
-        let eos_token_id = vocab.get("<|
-            .or_else(|| vocab.get("<|
+        let eos_token_id = vocab.get("<|im_end|>")
+            .or_else(|| vocab.get("<|endoftext|>"))
             .or_else(|| vocab.get("</s>"))
             .copied()
-            .unwrap_or(
+            .unwrap_or(151645); // Default Qwen2.5 EOS token
 
         // Download model weights
         // NOTE: Qwen uses hardcoded shard counts based on model size rather than
@@ -97,6 +105,11 @@ impl Qwen {
         })
     }
 
+    /// Load a Qwen model from HuggingFace (backwards compatibility)
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        Self::from_pretrained_with_tokenizer(model_id, device, None).await
+    }
+
     /// Apply Qwen chat template to messages
     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
         let mut prompt = String::new();