red-candle 1.0.2 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,6 @@
  use std::time::{SystemTime, UNIX_EPOCH};
+ use std::sync::Arc;
+ use crate::structured::Index;

  /// Configuration for text generation
  #[derive(Debug, Clone)]
@@ -23,6 +25,8 @@ pub struct GenerationConfig {
  pub include_prompt: bool,
  /// Whether to show raw tokens during generation (for debugging)
  pub debug_tokens: bool,
+ /// Optional constraint index for structured generation
+ pub constraint: Option<Arc<Index>>,
  }

  /// Generate a random seed based on current time
@@ -46,6 +50,7 @@ impl Default for GenerationConfig {
  stop_sequences: vec![],
  include_prompt: false,
  debug_tokens: false,
+ constraint: None,
  }
  }
  }
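The new constraint field carries a precompiled structured-generation index (crate::structured::Index) into the generation config. A minimal sketch of how code inside the crate might set it, assuming an Index built elsewhere (its constructor is not part of this diff); the Arc lets one compiled index be shared across generations:

use std::sync::Arc;

// Sketch only: `index` is assumed to be a crate::structured::Index built
// elsewhere (e.g. compiled from a JSON schema or regex); that API is not
// shown in this diff.
fn constrained_config(index: crate::structured::Index) -> GenerationConfig {
    GenerationConfig {
        constraint: Some(Arc::new(index)), // share the compiled index across runs
        ..GenerationConfig::default()
    }
}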
@@ -3,6 +3,8 @@ use candle_core::{Device, Result as CandleResult};
  pub mod mistral;
  pub mod llama;
  pub mod gemma;
+ pub mod qwen;
+ pub mod phi;
  pub mod generation_config;
  pub mod text_generation;
  pub mod quantized_gguf;
@@ -12,6 +14,9 @@ pub use text_generation::TextGeneration;
  pub use quantized_gguf::QuantizedGGUF;
  pub use crate::tokenizer::TokenizerWrapper;

+ #[cfg(test)]
+ mod constrained_generation_test;
+
  /// Trait for text generation models
  pub trait TextGenerator: Send + Sync {
  /// Generate text from a prompt
@@ -0,0 +1,285 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_transformers::models::phi::{Config, Model as PhiModel};
+ use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
+ use hf_hub::api::tokio::Api;
+ use tokenizers::Tokenizer;
+
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ /// Phi model wrapper for text generation
+ pub struct Phi {
+ model: PhiVariant,
+ tokenizer: TokenizerWrapper,
+ device: Device,
+ model_id: String,
+ eos_token_id: u32,
+ }
+
+ enum PhiVariant {
+ Phi2(PhiModel),
+ Phi3(Phi3Model),
+ }
+
+ impl Phi {
+ /// Get the tokenizer
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
+ &self.tokenizer
+ }
+
+ /// Clear the KV cache between generations
+ pub fn clear_kv_cache(&mut self) {
+ match &mut self.model {
+ PhiVariant::Phi2(model) => model.clear_kv_cache(),
+ PhiVariant::Phi3(model) => model.clear_kv_cache(),
+ }
+ }
+
+ /// Load a Phi model from HuggingFace
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+ let api = Api::new()
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+ let repo = api.model(model_id.to_string());
+
+ // Download configuration
+ let config_filename = repo.get("config.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+ let config_str = std::fs::read_to_string(config_filename)?;
+
+ // Download tokenizer
+ let tokenizer_filename = repo.get("tokenizer.json").await
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+ let tokenizer = Tokenizer::from_file(tokenizer_filename)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+ // Determine EOS token
+ let vocab = tokenizer.get_vocab(true);
+ let eos_token_id = vocab.get("<|endoftext|>")
+ .or_else(|| vocab.get("<|end|>"))
+ .or_else(|| vocab.get("</s>"))
+ .copied()
+ .unwrap_or(50256); // Default GPT-2 style EOS token
+
+ // Determine model variant based on model_id or config
+ let is_phi3 = model_id.contains("phi-3") || model_id.contains("Phi-3");
+
+ // Download model weights (handle both single and sharded files)
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+ vec![single_file]
+ } else {
+ // Try to find sharded model files
+ let mut sharded_files = Vec::new();
+ let mut index = 1;
+ loop {
+ // Try common shard counts
+ let mut found = false;
+ for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
+ let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+ if let Ok(file) = repo.get(&filename).await {
+ sharded_files.push(file);
+ found = true;
+ break;
+ }
+ }
+ if !found {
+ break;
+ }
+ index += 1;
+ }
+
+ if sharded_files.is_empty() {
+ return Err(candle_core::Error::Msg(
+ "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
+ ));
+ }
+ sharded_files
+ };
+
+ let model = if is_phi3 {
+ // Load Phi3 model
+ let config: Phi3Config = serde_json::from_str(&config_str)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
+
+ let vb = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+ };
+
+ let model = Phi3Model::new(&config, vb)?;
+ PhiVariant::Phi3(model)
+ } else {
+ // Load Phi2 model
+ let config: Config = serde_json::from_str(&config_str)
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi config: {}", e)))?;
+
+ let vb = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+ };
+
+ let model = PhiModel::new(&config, vb)?;
+ PhiVariant::Phi2(model)
+ };
+
+ Ok(Self {
+ model,
+ tokenizer: TokenizerWrapper::new(tokenizer),
+ device,
+ model_id: model_id.to_string(),
+ eos_token_id,
+ })
+ }
+
+ /// Apply Phi chat template to messages
+ pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+ let mut prompt = String::new();
+
+ // Phi-3 uses a specific format
+ if matches!(self.model, PhiVariant::Phi3(_)) {
+ for message in messages {
+ let role = message["role"].as_str().unwrap_or("");
+ let content = message["content"].as_str().unwrap_or("");
+
+ match role {
+ "system" => {
+ prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+ }
+ "user" => {
+ prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+ }
+ "assistant" => {
+ prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+ }
+ _ => {}
+ }
+ }
+ prompt.push_str("<|assistant|>\n");
+ } else {
+ // Phi-2 uses a simpler format
+ for message in messages {
+ let role = message["role"].as_str().unwrap_or("");
+ let content = message["content"].as_str().unwrap_or("");
+
+ match role {
+ "system" => prompt.push_str(&format!("System: {}\n", content)),
+ "user" => prompt.push_str(&format!("User: {}\n", content)),
+ "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+ _ => {}
+ }
+ }
+ prompt.push_str("Assistant: ");
+ }
+
+ Ok(prompt)
+ }
+
+ fn generate_tokens(
+ &mut self,
+ prompt_tokens: Vec<u32>,
+ config: &GenerationConfig,
+ mut callback: Option<impl FnMut(&str)>,
+ ) -> CandleResult<Vec<u32>> {
+ let mut text_gen = TextGeneration::from_config(config);
+ text_gen.set_eos_token_id(self.eos_token_id);
+ text_gen.set_tokens(prompt_tokens.clone());
+
+ let mut all_tokens = prompt_tokens.clone();
+ let start_gen = all_tokens.len();
+
+ for index in 0..config.max_length {
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
+ let start_pos = all_tokens.len().saturating_sub(context_size);
+ let ctxt = &all_tokens[start_pos..];
+
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+ let logits = match &mut self.model {
+ PhiVariant::Phi2(model) => model.forward(&input)?,
+ PhiVariant::Phi3(model) => model.forward(&input, start_pos)?,
+ };
+ let logits = logits.squeeze(0)?;
+
+ // Handle different output shapes
+ let logits = if logits.dims().len() == 2 {
+ let seq_len = logits.dim(0)?;
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+ } else {
+ logits
+ };
+
+ let logits = logits.to_dtype(DType::F32)?;
+
+ let next_token = text_gen.sample_next_token(
+ &logits,
+ Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+ )?;
+
+ all_tokens.push(next_token);
+
+ // Stream callback
+ if let Some(ref mut cb) = callback {
+ if config.debug_tokens {
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
+ cb(&format!("[{}:{}]", next_token, token_piece));
+ } else {
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+ cb(&decoded_text);
+ }
+ }
+
+ // Check stop conditions
+ if text_gen.should_stop(next_token, config.max_length) {
+ break;
+ }
+
+ // Check stop sequences
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+ break;
+ }
+ }
+
+ Ok(if config.include_prompt {
+ all_tokens
+ } else {
+ all_tokens[start_gen..].to_vec()
+ })
+ }
+ }
+
+ impl TextGenerator for Phi {
+ fn generate(
+ &mut self,
+ prompt: &str,
+ config: &GenerationConfig,
+ ) -> CandleResult<String> {
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+ if config.debug_tokens {
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
+ } else {
+ self.tokenizer.decode(&output_tokens, true)
+ }
+ }
+
+ fn generate_stream(
+ &mut self,
+ prompt: &str,
+ config: &GenerationConfig,
+ mut callback: impl FnMut(&str),
+ ) -> CandleResult<String> {
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+ self.tokenizer.decode(&output_tokens, true)
+ }
+
+ fn model_name(&self) -> &str {
+ &self.model_id
+ }
+
+ fn device(&self) -> &Device {
+ &self.device
+ }
+
+ fn clear_cache(&mut self) {
+ self.clear_kv_cache();
+ }
+ }
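The new Phi wrapper mirrors the existing model wrappers: an async from_pretrained that downloads config, tokenizer, and safetensors from the Hub, and the TextGenerator trait for blocking or streaming generation. A rough usage sketch; the model id and CPU device are illustrative, not taken from this diff:

use candle_core::Device;
use crate::llm::{GenerationConfig, TextGenerator, phi::Phi};

// Sketch only: "microsoft/phi-2" is an example id; any Phi-2/Phi-3 safetensors
// repo with config.json and tokenizer.json should follow the same code path.
async fn demo() -> candle_core::Result<String> {
    let device = Device::Cpu;
    let mut model = Phi::from_pretrained("microsoft/phi-2", device).await?;
    let config = GenerationConfig::default();
    // Generation itself is synchronous once the weights are loaded.
    model.generate("Write a haiku about Rust.", &config)
}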
@@ -2,6 +2,9 @@ use candle_core::{DType, Device, Result as CandleResult, Tensor};
  use candle_core::quantized::gguf_file;
  use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
  use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
+ use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
+ use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
+ use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
  use hf_hub::api::tokio::{Api, ApiRepo};
  use tokenizers::Tokenizer;
  use std::io::Seek;
@@ -9,7 +12,6 @@ use std::io::Seek;
  use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};

  /// Unified GGUF model that can load any GGUF file and detect the architecture
- #[derive(Debug)]
  pub struct QuantizedGGUF {
  model: ModelType,
  tokenizer: TokenizerWrapper,
@@ -20,10 +22,12 @@ pub struct QuantizedGGUF {
  _chat_template: Option<String>,
  }

- #[derive(Debug)]
  enum ModelType {
  Llama(QuantizedLlamaModel),
  Gemma(QuantizedGemmaModel),
+ Qwen(QuantizedQwenModel),
+ Phi(QuantizedPhiModel),
+ Phi3(QuantizedPhi3Model),
  // Mistral uses Llama loader due to tensor naming compatibility
  }
 
@@ -97,6 +101,34 @@ impl QuantizedGGUF {
  let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
  ModelType::Llama(model)
  }
+ "qwen" | "qwen2" | "qwen3" => {
+ // Try different loaders based on what metadata is available
+ if content.metadata.contains_key("llama.attention.head_count") {
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+ ModelType::Llama(model)
+ } else if content.metadata.contains_key("qwen2.attention.head_count") {
+ let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
+ ModelType::Qwen(model)
+ } else if content.metadata.contains_key("qwen3.attention.head_count") {
+ // Qwen3 GGUF files use a different metadata format
+ // The quantized_qwen3 module is not yet in the released version of candle-transformers
+ return Err(candle_core::Error::Msg(format!(
+ "Qwen3 GGUF format detected but not yet fully supported.\n\n\
+ The file contains qwen3.* metadata keys which require candle-transformers > 0.9.1.\n\n\
+ Current alternatives:\n\
+ 1. Use Qwen2.5 GGUF models which work well:\n\
+ - Qwen/Qwen2.5-7B-Instruct-GGUF (recommended)\n\
+ - Qwen/Qwen2.5-32B-Instruct-GGUF\n\
+ 2. Use non-quantized Qwen models with safetensors\n\
+ 3. Wait for candle-transformers update with quantized_qwen3 support\n\n\
+ Note: Qwen2.5 models have similar capabilities to Qwen3."
+ )));
+ } else {
+ // Last resort: try llama loader anyway, as it's the most common
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
+ ModelType::Llama(model)
+ }
+ }
  "gemma" | "gemma2" | "gemma3" => {
  // Try Gemma-specific loader first, fall back to Llama if it fails
  match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
@@ -112,9 +144,20 @@ impl QuantizedGGUF {
  Err(e) => return Err(e),
  }
  }
+ "phi" | "phi2" => {
+ let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
+ ModelType::Phi(model)
+ }
+ "phi3" => {
+ // QuantizedPhi3Model requires an additional `approx` parameter
+ // Setting to false to avoid performance issues without flash-attn
+ let approx = false;
+ let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
+ ModelType::Phi3(model)
+ }
  _ => {
  return Err(candle_core::Error::Msg(format!(
- "Unsupported architecture: {}. Supported: llama, mistral, gemma",
+ "Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3",
  architecture
  )));
  }
@@ -149,6 +192,14 @@ impl QuantizedGGUF {
  Ok("mistral".to_string())
  } else if model_lower.contains("gemma") {
  Ok("gemma".to_string())
+ } else if model_lower.contains("qwen") {
+ Ok("qwen".to_string())
+ } else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
+ Ok("phi3".to_string())
+ } else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
+ Ok("phi2".to_string())
+ } else if model_lower.contains("phi") {
+ Ok("phi".to_string())
  } else {
  Err(candle_core::Error::Msg(
  "Could not determine model architecture from metadata or name".to_string()
@@ -235,6 +286,20 @@ impl QuantizedGGUF {
  .copied()
  .unwrap_or(1)
  }
+ "qwen" | "qwen2" | "qwen3" => {
+ vocab.get("<|endoftext|>")
+ .or_else(|| vocab.get("<|im_end|>"))
+ .or_else(|| vocab.get("</s>"))
+ .copied()
+ .unwrap_or(151643) // Default Qwen3 EOS token
+ }
+ "phi" | "phi2" | "phi3" => {
+ vocab.get("<|endoftext|>")
+ .or_else(|| vocab.get("<|end|>"))
+ .or_else(|| vocab.get("</s>"))
+ .copied()
+ .unwrap_or(50256) // Default GPT-2 style EOS token
+ }
  _ => 2, // Default
  }
  }
@@ -256,6 +321,10 @@ impl QuantizedGGUF {
  } else if model_lower.contains("gemma") {
  // Always use Gemma template for Gemma models, regardless of loader used
  self.apply_gemma_template(messages)
+ } else if model_lower.contains("qwen") {
+ self.apply_qwen_template(messages)
+ } else if model_lower.contains("phi") {
+ self.apply_phi_template(messages)
  } else {
  match self.architecture.as_str() {
  "llama" => {
@@ -268,6 +337,12 @@ impl QuantizedGGUF {
  "gemma" => {
  self.apply_gemma_template(messages)
  }
+ "qwen" | "qwen2" | "qwen3" => {
+ self.apply_qwen_template(messages)
+ }
+ "phi" | "phi2" | "phi3" => {
+ self.apply_phi_template(messages)
+ }
  _ => Ok(self.apply_generic_template(messages))
  }
  }
@@ -366,6 +441,77 @@ impl QuantizedGGUF {
  Ok(prompt)
  }

+ fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+ let mut prompt = String::new();
+
+ for message in messages {
+ let role = message["role"].as_str().unwrap_or("");
+ let content = message["content"].as_str().unwrap_or("");
+
+ match role {
+ "system" => {
+ prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+ }
+ "user" => {
+ prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+ }
+ "assistant" => {
+ prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+ }
+ _ => {}
+ }
+ }
+
+ // Add generation prompt
+ prompt.push_str("<|im_start|>assistant\n");
+ Ok(prompt)
+ }
+
+ fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+ let mut prompt = String::new();
+
+ // Check if it's Phi-3 (newer format) or Phi-2/Phi (simpler format)
+ let is_phi3 = self.model_id.contains("phi-3") || self.model_id.contains("Phi-3") || self.architecture == "phi3";
+
+ if is_phi3 {
+ // Phi-3 format
+ for message in messages {
+ let role = message["role"].as_str().unwrap_or("");
+ let content = message["content"].as_str().unwrap_or("");
+
+ match role {
+ "system" => {
+ prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
+ }
+ "user" => {
+ prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
+ }
+ "assistant" => {
+ prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
+ }
+ _ => {}
+ }
+ }
+ prompt.push_str("<|assistant|>\n");
+ } else {
+ // Phi-2 format
+ for message in messages {
+ let role = message["role"].as_str().unwrap_or("");
+ let content = message["content"].as_str().unwrap_or("");
+
+ match role {
+ "system" => prompt.push_str(&format!("System: {}\n", content)),
+ "user" => prompt.push_str(&format!("User: {}\n", content)),
+ "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+ _ => {}
+ }
+ }
+ prompt.push_str("Assistant: ");
+ }
+
+ Ok(prompt)
+ }
+
  fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
  let mut prompt = String::new();
 
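For reference, derived directly from the two template functions above: a single user message rendered by apply_qwen_template uses ChatML delimiters, while the Phi-3 branch of apply_phi_template uses the <|user|>/<|end|> markers. The expected strings, as a small sketch:

// Derived from the templates above for messages = [{"role": "user", "content": "Hello"}].
fn expected_prompts() -> (&'static str, &'static str) {
    // apply_qwen_template: ChatML turn plus the trailing generation prompt.
    let qwen = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n";
    // apply_phi_template, Phi-3 branch.
    let phi3 = "<|user|>\nHello<|end|>\n<|assistant|>\n";
    (qwen, phi3)
}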
@@ -381,7 +527,9 @@ impl QuantizedGGUF {

  /// Clear the KV cache between generations
  pub fn clear_kv_cache(&mut self) {
- // Quantized models manage cache internally
+ // Quantized models don't expose cache clearing methods
+ // Phi3 GGUF models have a known issue where the KV cache
+ // cannot be cleared, leading to errors on subsequent generations
  }

  fn generate_tokens(
@@ -408,6 +556,9 @@ impl QuantizedGGUF {
  let logits = match &mut self.model {
  ModelType::Llama(model) => model.forward(&input, start_pos)?,
  ModelType::Gemma(model) => model.forward(&input, start_pos)?,
+ ModelType::Qwen(model) => model.forward(&input, start_pos)?,
+ ModelType::Phi(model) => model.forward(&input, start_pos)?,
+ ModelType::Phi3(model) => model.forward(&input, start_pos)?,
  };

  let logits = logits.squeeze(0)?;