red-candle 1.2.3 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/README.md CHANGED
@@ -1,4 +1,4 @@
- <img src="/docs/assets/logo-title.png" alt="red-candle" height="80px">
+ <img src="/docs/assets/logo-title.png" alt="red-candle" height="160px">

  [![build](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml/badge.svg)](https://github.com/scientist-labs/red-candle/actions/workflows/build.yml)
  [![Gem Version](https://badge.fury.io/rb/red-candle.svg)](https://badge.fury.io/rb/red-candle)
@@ -12,8 +12,8 @@ crate-type = ["cdylib"]
  candle-core = { version = "0.9.1" }
  candle-nn = { version = "0.9.1" }
  candle-transformers = { version = "0.9.1" }
- tokenizers = { version = "0.21.1", default-features = true, features = ["fancy-regex"] }
- hf-hub = "0.4.3"
+ tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
+ hf-hub = "0.4.1"
  half = "2.6.0"
  magnus = "0.7.1"
  safetensors = "0.3"
@@ -21,7 +21,7 @@ serde_json = "1.0"
  serde = { version = "1.0", features = ["derive"] }
  tokio = { version = "1.45", features = ["rt", "macros"] }
  rand = "0.8"
- outlines-core = "0.2"
+ outlines-core = "0.2.11"

  [features]
  default = []
@@ -313,4 +313,83 @@ mod constrained_generation_tests {
          // Verify tokens are being tracked
          assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
      }
+
+     #[test]
+     fn test_constraint_satisfied_not_triggered_by_large_allowed_set() {
+         // This test verifies the fix for the bug where is_constraint_satisfied_stop_on_match
+         // would incorrectly return true when many tokens are allowed (e.g., inside a JSON string).
+         // The old buggy code had: if allowed.len() > 1000 { return true; }
+         // This caused early termination when inside strings with many valid characters.
+
+         let config = GenerationConfig::default();
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.set_eos_token_id(50256);
+
+         // Without a constraint, should not be satisfied
+         assert!(!text_gen.is_constraint_satisfied(),
+             "Without constraint, should not be satisfied");
+         assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+             "Without constraint, stop_on_match should not be satisfied");
+     }
+
+     #[test]
+     fn test_constraint_satisfied_only_when_empty_or_eos_only() {
+         // Test that constraint satisfaction only triggers when:
+         // 1. No tokens are allowed (empty set)
+         // 2. Only EOS token is allowed
+         // NOT when many tokens are allowed (like inside a JSON string)
+
+         let config = GenerationConfig::default();
+         let mut text_gen = TextGeneration::new(&config);
+         text_gen.set_eos_token_id(100); // Set EOS token
+
+         // Without constraint, should not be satisfied
+         assert!(!text_gen.is_constraint_satisfied());
+         assert!(!text_gen.is_constraint_satisfied_stop_on_match());
+
+         // The key insight: constraint satisfaction should NOT be triggered
+         // just because there are many allowed tokens. It should only trigger
+         // when the constraint is definitively complete (empty allowed set or only EOS).
+     }
+
+     #[tokio::test]
+     async fn test_constraint_with_json_schema_not_early_termination() {
+         // Integration test: Create a real JSON schema constraint and verify
+         // that being inside a string (many allowed tokens) doesn't trigger completion.
+
+         if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
+             let wrapper = TokenizerWrapper::new(tokenizer);
+             let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
+                 .expect("Should create vocabulary");
+
+             let processor = SchemaProcessor::new();
+
+             // Schema with a string field - when generating content inside the string,
+             // many characters are valid, but the constraint is NOT complete
+             let schema = r#"{
+                 "type": "object",
+                 "properties": {
+                     "name": { "type": "string" }
+                 },
+                 "required": ["name"]
+             }"#;
+
+             let index = processor.process_schema(schema, &vocabulary)
+                 .expect("Should process schema");
+
+             let mut config = GenerationConfig::default();
+             config.constraint = Some(index);
+             config.max_length = 100;
+
+             let mut text_gen = TextGeneration::new(&config);
+             text_gen.set_eos_token_id(102); // BERT's [SEP]
+
+             // At the initial state, the constraint should NOT be satisfied
+             // (we haven't generated a complete JSON object yet)
+             assert!(!text_gen.is_constraint_satisfied(),
+                 "Initial state should not be satisfied - JSON not yet generated");
+             assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+                 "Initial state should not trigger stop_on_match");
+         }
+     }
  }
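
The hunk above only adds tests; the corrected check itself is not shown in this diff. A minimal, self-contained sketch of the behaviour the tests describe, using a hypothetical free function in place of TextGeneration's real methods (names and signature are illustrative only):

    use std::collections::HashSet;

    /// Sketch: a constraint counts as satisfied only when the allowed set is
    /// definitively terminal - empty, or containing nothing but the EOS token.
    /// (Hypothetical helper; the actual logic lives inside TextGeneration.)
    fn constraint_satisfied(allowed: &HashSet<u32>, eos_token_id: u32) -> bool {
        allowed.is_empty() || (allowed.len() == 1 && allowed.contains(&eos_token_id))
    }

    fn main() {
        let eos = 50256;

        // Inside a JSON string thousands of tokens are allowed, so the old
        // `allowed.len() > 1000` shortcut would have stopped generation here.
        let inside_string: HashSet<u32> = (0..5000).collect();
        assert!(!constraint_satisfied(&inside_string, eos));

        // Only EOS remains: the schema is complete and generation may stop.
        let eos_only: HashSet<u32> = [eos].into_iter().collect();
        assert!(constraint_satisfied(&eos_only, eos));
    }
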
@@ -30,8 +30,8 @@ impl Gemma {
          &self.tokenizer
      }

-     /// Load a Gemma model from HuggingFace Hub
-     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+     /// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
+     pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          let api = Api::new()
              .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -43,10 +43,23 @@ impl Gemma {
              .await
              .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

-         let tokenizer_filename = repo
-             .get("tokenizer.json")
-             .await
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         // Download tokenizer from custom source if provided, otherwise from model repo
+         let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+             let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+             let tokenizer_filename = tokenizer_repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         } else {
+             let tokenizer_filename = repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         };

          // Try different file patterns for model weights
          let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -87,9 +100,6 @@ impl Gemma {
          let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
              .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

-         // Load tokenizer
-         let tokenizer = Tokenizer::from_file(tokenizer_filename)
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;

          // Gemma uses specific tokens
          let eos_token_id = {
@@ -116,6 +126,11 @@ impl Gemma {
          })
      }

+     /// Load a Gemma model from HuggingFace Hub (backwards compatibility)
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         Self::from_pretrained_with_tokenizer(model_id, device, None).await
+     }
+
      /// Create from existing components (useful for testing)
      pub fn new(
          model: GemmaModel,
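
The same pattern (a new from_pretrained_with_tokenizer plus a delegating from_pretrained) is repeated for Llama, Mistral, Phi and Qwen below. A rough sketch of calling the extended Gemma loader from Rust; this is not a standalone program (it assumes the crate's Gemma, Device and CandleResult types are in scope) and both repository IDs are placeholders:

    async fn load_gemma_example() -> CandleResult<()> {
        // Unchanged call path: tokenizer.json comes from the model repository itself.
        let _default = Gemma::from_pretrained("google/gemma-2b-it", Device::Cpu).await?;

        // New in 1.3.x: pull tokenizer.json from a different repository, for model
        // repos that do not ship one. Both IDs below are illustrative placeholders.
        let _custom = Gemma::from_pretrained_with_tokenizer(
            "example-org/gemma-finetune", // placeholder model repo
            Device::Cpu,
            Some("google/gemma-2b-it"),   // placeholder tokenizer repo
        )
        .await?;

        Ok(())
    }
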
@@ -37,8 +37,8 @@ impl Llama {
          &self.tokenizer
      }

-     /// Load a Llama model from HuggingFace Hub
-     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+     /// Load a Llama model from HuggingFace Hub with optional custom tokenizer
+     pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          let api = Api::new()
              .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -50,10 +50,45 @@ impl Llama {
              .await
              .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

-         let tokenizer_filename = repo
-             .get("tokenizer.json")
-             .await
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         // Download tokenizer from custom source if provided, otherwise from model repo
+         let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+             let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+             let tokenizer_filename = tokenizer_repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| {
+                     let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                         format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                     } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                         format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                     } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                         format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                     } else {
+                         format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                     };
+                     candle_core::Error::Msg(error_msg)
+                 })?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+         } else {
+             let tokenizer_filename = repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| {
+                     let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                         format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'appropriate-tokenizer-repo')", model_id, model_id)
+                     } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                         format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                     } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                         format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                     } else {
+                         format!("Failed to download tokenizer: {}", e)
+                     };
+                     candle_core::Error::Msg(error_msg)
+                 })?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+         };

          // Try different file patterns for model weights
          let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Llama {
              .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
          let config = llama_config.into_config(false); // Don't use flash attention for now

-         // Load tokenizer
-         let tokenizer = Tokenizer::from_file(tokenizer_filename)
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
          // Determine EOS token ID based on model type
          let eos_token_id = if model_id.contains("Llama-3") || model_id.contains("llama-3") {
              // Llama 3 uses different special tokens
@@ -139,6 +170,11 @@ impl Llama {
          })
      }

+     /// Load a Llama model from HuggingFace Hub (backwards compatibility)
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         Self::from_pretrained_with_tokenizer(model_id, device, None).await
+     }
+
      /// Create from existing components (useful for testing)
      pub fn new(
          model: LlamaModel,
@@ -30,8 +30,8 @@ impl Mistral {
          &self.tokenizer
      }

-     /// Load a Mistral model from HuggingFace Hub
-     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+     /// Load a Mistral model from HuggingFace Hub with optional custom tokenizer
+     pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          let api = Api::new()
              .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -43,10 +43,45 @@ impl Mistral {
              .await
              .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;

-         let tokenizer_filename = repo
-             .get("tokenizer.json")
-             .await
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         // Download tokenizer from custom source if provided, otherwise from model repo
+         let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+             let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
+             let tokenizer_filename = tokenizer_repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| {
+                     let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                         format!("Tokenizer file 'tokenizer.json' not found in repository '{}'. The repository may not have a tokenizer.json file or may use a different format (e.g., tokenizer.model for SentencePiece).", tokenizer_id)
+                     } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                         format!("Authentication required to access tokenizer '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", tokenizer_id)
+                     } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                         format!("Network error downloading tokenizer from '{}': {}. Please check your internet connection.", tokenizer_id, e)
+                     } else {
+                         format!("Failed to download tokenizer from '{}': {}", tokenizer_id, e)
+                     };
+                     candle_core::Error::Msg(error_msg)
+                 })?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+         } else {
+             let tokenizer_filename = repo
+                 .get("tokenizer.json")
+                 .await
+                 .map_err(|e| {
+                     let error_msg = if e.to_string().contains("404") || e.to_string().contains("Not Found") {
+                         format!("No tokenizer found in model repository '{}'. The model may not include a tokenizer. Try specifying a tokenizer explicitly using the 'tokenizer' parameter, e.g.: from_pretrained('{}', tokenizer: 'mistralai/Mistral-7B-Instruct-v0.2')", model_id, model_id)
+                     } else if e.to_string().contains("401") || e.to_string().contains("Unauthorized") {
+                         format!("Authentication required to access model '{}'. You may need to set HF_TOKEN environment variable with a valid Hugging Face token.", model_id)
+                     } else if e.to_string().contains("timed out") || e.to_string().contains("connection") {
+                         format!("Network error downloading tokenizer: {}. Please check your internet connection.", e)
+                     } else {
+                         format!("Failed to download tokenizer: {}", e)
+                     };
+                     candle_core::Error::Msg(error_msg)
+                 })?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer file: {}", e)))?
+         };

          // Try different file patterns for model weights
          let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
@@ -97,10 +132,6 @@ impl Mistral {
          let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
              .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

-         // Load tokenizer
-         let tokenizer = Tokenizer::from_file(tokenizer_filename)
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
-
          let eos_token_id = tokenizer
              .get_vocab(true)
              .get("</s>")
@@ -123,6 +154,11 @@ impl Mistral {
          })
      }

+     /// Load a Mistral model from HuggingFace Hub (backwards compatibility)
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         Self::from_pretrained_with_tokenizer(model_id, device, None).await
+     }
+
      /// Create from existing components (useful for testing)
      pub fn new(
          model: MistralModel,
@@ -38,8 +38,8 @@ impl Phi {
          }
      }

-     /// Load a Phi model from HuggingFace
-     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+     /// Load a Phi model from HuggingFace with optional custom tokenizer
+     pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          let api = Api::new()
              .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -50,11 +50,19 @@ impl Phi {
              .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
          let config_str = std::fs::read_to_string(config_filename)?;

-         // Download tokenizer
-         let tokenizer_filename = repo.get("tokenizer.json").await
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
-         let tokenizer = Tokenizer::from_file(tokenizer_filename)
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+         // Download tokenizer from custom source if provided, otherwise from model repo
+         let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+             let tokenizer_repo = api.model(tokenizer_id.to_string());
+             let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         } else {
+             let tokenizer_filename = repo.get("tokenizer.json").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         };

          // Determine EOS token
          let vocab = tokenizer.get_vocab(true);
@@ -104,7 +112,62 @@ impl Phi {

          let model = if is_phi3 {
              // Load Phi3 model
-             let config: Phi3Config = serde_json::from_str(&config_str)
+             // Handle config differences between Phi-3-small and Phi-3-mini
+             let mut config_str_fixed;
+
+             // Parse config as JSON for modifications
+             let mut config_json: serde_json::Value = serde_json::from_str(&config_str)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
+
+             // Phi-3-small uses ff_intermediate_size instead of intermediate_size
+             if config_json.get("ff_intermediate_size").is_some() && config_json.get("intermediate_size").is_none() {
+                 if let Some(ff_size) = config_json.get("ff_intermediate_size").cloned() {
+                     config_json["intermediate_size"] = ff_size;
+                 }
+             }
+
+             // Phi-3-small uses layer_norm_epsilon instead of rms_norm_eps
+             if config_json.get("layer_norm_epsilon").is_some() && config_json.get("rms_norm_eps").is_none() {
+                 if let Some(eps) = config_json.get("layer_norm_epsilon").cloned() {
+                     config_json["rms_norm_eps"] = eps;
+                 }
+             }
+
+             // Handle rope_scaling for long context models (Phi-3-mini-128k)
+             // Candle expects rope_scaling to be a string, but newer configs have it as an object
+             if let Some(rope_scaling) = config_json.get("rope_scaling") {
+                 if rope_scaling.is_object() {
+                     // For now, just convert to the type string - candle will use default scaling
+                     if let Some(scaling_type) = rope_scaling.get("type").and_then(|v| v.as_str()) {
+                         config_json["rope_scaling"] = serde_json::Value::String(scaling_type.to_string());
+                     } else {
+                         // Remove it if we can't determine the type
+                         config_json.as_object_mut().unwrap().remove("rope_scaling");
+                     }
+                 }
+             }
+
+             // Phi-3-small uses rope_embedding_base instead of rope_theta
+             if config_json.get("rope_embedding_base").is_some() && config_json.get("rope_theta").is_none() {
+                 if let Some(rope_base) = config_json.get("rope_embedding_base").cloned() {
+                     config_json["rope_theta"] = rope_base;
+                 }
+             }
+
+             config_str_fixed = serde_json::to_string(&config_json)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to serialize config: {}", e)))?;
+
+             // Check for unsupported gegelu activation
+             if config_str_fixed.contains("\"gegelu\"") {
+                 // For now, map gegelu to gelu_pytorch_tanh with a warning
+                 // This is not ideal but allows the model to at least load
+                 eprintln!("WARNING: This model uses 'gegelu' activation which is not fully supported.");
+                 eprintln!(" Mapping to 'gelu_pytorch_tanh' - results may be degraded.");
+                 eprintln!(" For best results, use Phi-3-mini models instead.");
+                 config_str_fixed = config_str_fixed.replace("\"gegelu\"", "\"gelu_pytorch_tanh\"");
+             }
+
+             let config: Phi3Config = serde_json::from_str(&config_str_fixed)
                  .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;

              let vb = unsafe {
@@ -134,6 +197,11 @@ impl Phi {
              eos_token_id,
          })
      }
+
+     /// Load a Phi model from HuggingFace (backwards compatibility)
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         Self::from_pretrained_with_tokenizer(model_id, device, None).await
+     }

      /// Apply Phi chat template to messages
      pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
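
For reference, a standalone sketch of the same key normalization applied to a toy Phi-3-small style config. The field values are invented; only the renames and the rope_scaling handling mirror the code above:

    use serde_json::{json, Value};

    fn main() {
        // Hypothetical Phi-3-small style config using the alternate key names.
        let mut config: Value = json!({
            "ff_intermediate_size": 14336,
            "layer_norm_epsilon": 1e-5,
            "rope_embedding_base": 1000000.0,
            "rope_scaling": { "type": "su", "long_factor": [1.0, 1.1] }
        });

        // Copy each alternate key to the name candle's Phi3Config expects.
        for (from, to) in [
            ("ff_intermediate_size", "intermediate_size"),
            ("layer_norm_epsilon", "rms_norm_eps"),
            ("rope_embedding_base", "rope_theta"),
        ] {
            if config.get(from).is_some() && config.get(to).is_none() {
                let value = config[from].clone();
                config[to] = value;
            }
        }

        // Collapse a rope_scaling object to its "type" string, as the loader does.
        let scaling_type = config
            .get("rope_scaling")
            .and_then(|r| r.get("type"))
            .and_then(|v| v.as_str())
            .map(str::to_string);
        if let Some(t) = scaling_type {
            config["rope_scaling"] = Value::String(t);
        }

        println!("{}", serde_json::to_string_pretty(&config).unwrap());
    }
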
@@ -30,8 +30,8 @@ impl Qwen {
          self.model.clear_kv_cache();
      }

-     /// Load a Qwen model from HuggingFace
-     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+     /// Load a Qwen model from HuggingFace with optional custom tokenizer
+     pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          let api = Api::new()
              .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;

@@ -44,19 +44,27 @@ impl Qwen {
          let config: Config = serde_json::from_str(&config_str)
              .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;

-         // Download tokenizer
-         let tokenizer_filename = repo.get("tokenizer.json").await
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
-         let tokenizer = Tokenizer::from_file(tokenizer_filename)
-             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+         // Download tokenizer from custom source if provided, otherwise from model repo
+         let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
+             let tokenizer_repo = api.model(tokenizer_id.to_string());
+             let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         } else {
+             let tokenizer_filename = repo.get("tokenizer.json").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+             Tokenizer::from_file(tokenizer_filename)
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
+         };

          // Determine EOS token
          let vocab = tokenizer.get_vocab(true);
-         let eos_token_id = vocab.get("<|endoftext|>")
-             .or_else(|| vocab.get("<|im_end|>"))
+         let eos_token_id = vocab.get("<|im_end|>")
+             .or_else(|| vocab.get("<|endoftext|>"))
              .or_else(|| vocab.get("</s>"))
              .copied()
-             .unwrap_or(151643); // Default Qwen3 EOS token
+             .unwrap_or(151645); // Default Qwen2.5 EOS token

          // Download model weights
          // NOTE: Qwen uses hardcoded shard counts based on model size rather than
@@ -97,6 +105,11 @@ impl Qwen {
          })
      }

+     /// Load a Qwen model from HuggingFace (backwards compatibility)
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         Self::from_pretrained_with_tokenizer(model_id, device, None).await
+     }
+
      /// Apply Qwen chat template to messages
      pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
          let mut prompt = String::new();
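
To make the Qwen EOS change concrete, a small self-contained sketch of the new lookup order; the vocab map below is a stand-in for the real tokenizer.get_vocab(true):

    use std::collections::HashMap;

    // Sketch of the new resolution order: <|im_end|> is preferred over
    // <|endoftext|>, and 151645 is the hardcoded Qwen2.5 fallback.
    fn resolve_qwen_eos(vocab: &HashMap<String, u32>) -> u32 {
        vocab.get("<|im_end|>")
            .or_else(|| vocab.get("<|endoftext|>"))
            .or_else(|| vocab.get("</s>"))
            .copied()
            .unwrap_or(151645)
    }

    fn main() {
        let mut vocab = HashMap::new();
        vocab.insert("<|endoftext|>".to_string(), 151643);
        vocab.insert("<|im_end|>".to_string(), 151645);

        // With both special tokens present, the chat end-of-turn token now wins.
        assert_eq!(resolve_qwen_eos(&vocab), 151645);

        // With an empty vocab, the hardcoded default is used.
        assert_eq!(resolve_qwen_eos(&HashMap::new()), 151645);
    }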