red-candle 1.8.0-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_transformers::models::qwen2::{Config, Model as QwenModel};
|
|
3
|
+
use hf_hub::api::tokio::Api;
|
|
4
|
+
use tokenizers::Tokenizer;
|
|
5
|
+
|
|
6
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
7
|
+
|
|
8
|
+
/// Qwen model wrapper for text generation
|
|
9
|
+
#[derive(Debug)]
|
|
10
|
+
pub struct Qwen {
|
|
11
|
+
model: QwenModel,
|
|
12
|
+
tokenizer: TokenizerWrapper,
|
|
13
|
+
device: Device,
|
|
14
|
+
model_id: String,
|
|
15
|
+
eos_token_id: u32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
impl Qwen {
|
|
19
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
20
|
+
self.eos_token_id
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
/// Get the tokenizer
|
|
24
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
25
|
+
&self.tokenizer
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Clear the KV cache between generations
|
|
29
|
+
pub fn clear_kv_cache(&mut self) {
|
|
30
|
+
self.model.clear_kv_cache();
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/// Load a Qwen model from HuggingFace with optional custom tokenizer
|
|
34
|
+
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
|
+
let api = Api::new()
|
|
36
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
+
|
|
38
|
+
let repo = api.model(model_id.to_string());
|
|
39
|
+
|
|
40
|
+
// Download configuration
|
|
41
|
+
let config_filename = repo.get("config.json").await
|
|
42
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
43
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
44
|
+
let config: Config = serde_json::from_str(&config_str)
|
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
46
|
+
|
|
47
|
+
// Download tokenizer from custom source if provided, otherwise from model repo
|
|
48
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
49
|
+
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
|
50
|
+
let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
|
|
51
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
|
|
52
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
53
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
54
|
+
} else {
|
|
55
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
|
56
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
|
57
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
58
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
// Determine EOS token
|
|
62
|
+
let vocab = tokenizer.get_vocab(true);
|
|
63
|
+
let eos_token_id = vocab.get("<|im_end|>")
|
|
64
|
+
.or_else(|| vocab.get("<|endoftext|>"))
|
|
65
|
+
.or_else(|| vocab.get("</s>"))
|
|
66
|
+
.copied()
|
|
67
|
+
.unwrap_or(151645); // Default Qwen2.5 EOS token
|
|
68
|
+
|
|
69
|
+
// Download model weights
|
|
70
|
+
// NOTE: Qwen uses hardcoded shard counts based on model size rather than
|
|
71
|
+
// reading model.safetensors.index.json. This works for official Qwen models
|
|
72
|
+
// but may fail for custom configurations with different shard counts.
|
|
73
|
+
let mut filenames = vec![];
|
|
74
|
+
let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
|
|
75
|
+
else if model_id.contains("14b") || model_id.contains("14B") { 3 }
|
|
76
|
+
else { 1 };
|
|
77
|
+
|
|
78
|
+
if num_shards == 1 {
|
|
79
|
+
// Single file model
|
|
80
|
+
let filename = repo.get("model.safetensors").await
|
|
81
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
|
|
82
|
+
filenames.push(filename);
|
|
83
|
+
} else {
|
|
84
|
+
// Sharded model
|
|
85
|
+
for shard_idx in 1..=num_shards {
|
|
86
|
+
let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
|
|
87
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
|
|
88
|
+
filenames.push(filename);
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
// Load the model
|
|
93
|
+
let vb = unsafe {
|
|
94
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
|
|
95
|
+
};
|
|
96
|
+
|
|
97
|
+
let model = QwenModel::new(&config, vb)?;
|
|
98
|
+
|
|
99
|
+
Ok(Self {
|
|
100
|
+
model,
|
|
101
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
102
|
+
device,
|
|
103
|
+
model_id: model_id.to_string(),
|
|
104
|
+
eos_token_id,
|
|
105
|
+
})
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
/// Load a Qwen model from HuggingFace (backwards compatibility)
|
|
109
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
110
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/// Apply Qwen chat template to messages
|
|
114
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
115
|
+
let mut prompt = String::new();
|
|
116
|
+
|
|
117
|
+
for message in messages {
|
|
118
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
119
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
120
|
+
|
|
121
|
+
match role {
|
|
122
|
+
"system" => {
|
|
123
|
+
prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
|
|
124
|
+
}
|
|
125
|
+
"user" => {
|
|
126
|
+
prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
|
|
127
|
+
}
|
|
128
|
+
"assistant" => {
|
|
129
|
+
prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
|
|
130
|
+
}
|
|
131
|
+
"tool" => {
|
|
132
|
+
prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
|
|
133
|
+
}
|
|
134
|
+
_ => {}
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
// Add generation prompt
|
|
139
|
+
prompt.push_str("<|im_start|>assistant\n");
|
|
140
|
+
|
|
141
|
+
Ok(prompt)
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
fn generate_tokens(
|
|
145
|
+
&mut self,
|
|
146
|
+
prompt_tokens: Vec<u32>,
|
|
147
|
+
config: &GenerationConfig,
|
|
148
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
149
|
+
) -> CandleResult<Vec<u32>> {
|
|
150
|
+
let mut text_gen = TextGeneration::new(config);
|
|
151
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
152
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
153
|
+
|
|
154
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
155
|
+
let start_gen = all_tokens.len();
|
|
156
|
+
|
|
157
|
+
for index in 0..config.max_length {
|
|
158
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
159
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
160
|
+
let ctxt = &all_tokens[start_pos..];
|
|
161
|
+
|
|
162
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
163
|
+
let logits = self.model.forward(&input, start_pos, None)?;
|
|
164
|
+
let logits = logits.squeeze(0)?;
|
|
165
|
+
|
|
166
|
+
// Handle different output shapes
|
|
167
|
+
let logits = if logits.dims().len() == 2 {
|
|
168
|
+
let seq_len = logits.dim(0)?;
|
|
169
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
170
|
+
} else {
|
|
171
|
+
logits
|
|
172
|
+
};
|
|
173
|
+
|
|
174
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
175
|
+
|
|
176
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
177
|
+
|
|
178
|
+
all_tokens.push(next_token);
|
|
179
|
+
|
|
180
|
+
// Stream callback
|
|
181
|
+
if let Some(ref mut cb) = callback {
|
|
182
|
+
if config.debug_tokens {
|
|
183
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
184
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
185
|
+
} else {
|
|
186
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
187
|
+
cb(&decoded_text);
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// Check stop conditions
|
|
192
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
193
|
+
break;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
// Check if constraint is satisfied (early stopping)
|
|
197
|
+
if config.stop_on_constraint_satisfaction {
|
|
198
|
+
let satisfied = if config.stop_on_match {
|
|
199
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
200
|
+
} else {
|
|
201
|
+
text_gen.is_constraint_satisfied()
|
|
202
|
+
};
|
|
203
|
+
if satisfied {
|
|
204
|
+
break;
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
// Check stop sequences
|
|
209
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
210
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
211
|
+
break;
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
Ok(if config.include_prompt {
|
|
216
|
+
all_tokens
|
|
217
|
+
} else {
|
|
218
|
+
all_tokens[start_gen..].to_vec()
|
|
219
|
+
})
|
|
220
|
+
}
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
impl TextGenerator for Qwen {
|
|
224
|
+
fn generate(
|
|
225
|
+
&mut self,
|
|
226
|
+
prompt: &str,
|
|
227
|
+
config: &GenerationConfig,
|
|
228
|
+
) -> CandleResult<String> {
|
|
229
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
230
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
231
|
+
|
|
232
|
+
if config.debug_tokens {
|
|
233
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
234
|
+
} else {
|
|
235
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
fn generate_stream(
|
|
240
|
+
&mut self,
|
|
241
|
+
prompt: &str,
|
|
242
|
+
config: &GenerationConfig,
|
|
243
|
+
mut callback: impl FnMut(&str),
|
|
244
|
+
) -> CandleResult<String> {
|
|
245
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
246
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
247
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
fn model_name(&self) -> &str {
|
|
251
|
+
&self.model_id
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
fn device(&self) -> &Device {
|
|
255
|
+
&self.device
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
fn clear_cache(&mut self) {
|
|
259
|
+
self.clear_kv_cache();
|
|
260
|
+
}
|
|
261
|
+
}
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_transformers::models::qwen3::{Config, ModelForCausalLM as Qwen3Model};
|
|
3
|
+
use hf_hub::api::tokio::Api;
|
|
4
|
+
use tokenizers::Tokenizer;
|
|
5
|
+
|
|
6
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
7
|
+
|
|
8
|
+
/// Qwen3 model wrapper for text generation
|
|
9
|
+
#[derive(Debug)]
|
|
10
|
+
pub struct Qwen3 {
|
|
11
|
+
model: Qwen3Model,
|
|
12
|
+
tokenizer: TokenizerWrapper,
|
|
13
|
+
device: Device,
|
|
14
|
+
model_id: String,
|
|
15
|
+
eos_token_id: u32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
impl Qwen3 {
|
|
19
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
20
|
+
self.eos_token_id
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
/// Get the tokenizer
|
|
24
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
25
|
+
&self.tokenizer
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Clear the KV cache between generations
|
|
29
|
+
pub fn clear_kv_cache(&mut self) {
|
|
30
|
+
self.model.clear_kv_cache();
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/// Load a Qwen3 model from HuggingFace with optional custom tokenizer
|
|
34
|
+
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
|
+
let api = Api::new()
|
|
36
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
+
|
|
38
|
+
let repo = api.model(model_id.to_string());
|
|
39
|
+
|
|
40
|
+
// Download configuration
|
|
41
|
+
let config_filename = repo.get("config.json").await
|
|
42
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
43
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
44
|
+
let config: Config = serde_json::from_str(&config_str)
|
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
46
|
+
|
|
47
|
+
// Download tokenizer from custom source if provided, otherwise from model repo
|
|
48
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
49
|
+
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
|
50
|
+
let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
|
|
51
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
|
|
52
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
53
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
54
|
+
} else {
|
|
55
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
|
56
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
|
57
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
58
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
// Determine EOS token
|
|
62
|
+
let vocab = tokenizer.get_vocab(true);
|
|
63
|
+
let eos_token_id = vocab.get("<|im_end|>")
|
|
64
|
+
.or_else(|| vocab.get("<|endoftext|>"))
|
|
65
|
+
.or_else(|| vocab.get("</s>"))
|
|
66
|
+
.copied()
|
|
67
|
+
.unwrap_or(151645);
|
|
68
|
+
|
|
69
|
+
// Download model weights
|
|
70
|
+
let mut filenames = vec![];
|
|
71
|
+
let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
|
|
72
|
+
else if model_id.contains("14b") || model_id.contains("14B") { 3 }
|
|
73
|
+
else if model_id.contains("32b") || model_id.contains("32B") { 4 }
|
|
74
|
+
else { 1 };
|
|
75
|
+
|
|
76
|
+
if num_shards == 1 {
|
|
77
|
+
let filename = repo.get("model.safetensors").await
|
|
78
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
|
|
79
|
+
filenames.push(filename);
|
|
80
|
+
} else {
|
|
81
|
+
for shard_idx in 1..=num_shards {
|
|
82
|
+
let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
|
|
83
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
|
|
84
|
+
filenames.push(filename);
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
// Load the model
|
|
89
|
+
let vb = unsafe {
|
|
90
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
|
|
91
|
+
};
|
|
92
|
+
|
|
93
|
+
let model = Qwen3Model::new(&config, vb)?;
|
|
94
|
+
|
|
95
|
+
Ok(Self {
|
|
96
|
+
model,
|
|
97
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
98
|
+
device,
|
|
99
|
+
model_id: model_id.to_string(),
|
|
100
|
+
eos_token_id,
|
|
101
|
+
})
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/// Load a Qwen3 model from HuggingFace (backwards compatibility)
|
|
105
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
106
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
/// Apply Qwen3 chat template to messages (same format as Qwen2)
|
|
110
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
111
|
+
let mut prompt = String::new();
|
|
112
|
+
|
|
113
|
+
for message in messages {
|
|
114
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
115
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
116
|
+
|
|
117
|
+
match role {
|
|
118
|
+
"system" => {
|
|
119
|
+
prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
|
|
120
|
+
}
|
|
121
|
+
"user" => {
|
|
122
|
+
prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
|
|
123
|
+
}
|
|
124
|
+
"assistant" => {
|
|
125
|
+
prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
|
|
126
|
+
}
|
|
127
|
+
"tool" => {
|
|
128
|
+
prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
|
|
129
|
+
}
|
|
130
|
+
_ => {}
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
// Add generation prompt
|
|
135
|
+
prompt.push_str("<|im_start|>assistant\n");
|
|
136
|
+
|
|
137
|
+
Ok(prompt)
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
fn generate_tokens(
|
|
141
|
+
&mut self,
|
|
142
|
+
prompt_tokens: Vec<u32>,
|
|
143
|
+
config: &GenerationConfig,
|
|
144
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
145
|
+
) -> CandleResult<Vec<u32>> {
|
|
146
|
+
let mut text_gen = TextGeneration::new(config);
|
|
147
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
148
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
149
|
+
|
|
150
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
151
|
+
let start_gen = all_tokens.len();
|
|
152
|
+
|
|
153
|
+
for index in 0..config.max_length {
|
|
154
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
155
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
156
|
+
let ctxt = &all_tokens[start_pos..];
|
|
157
|
+
|
|
158
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
159
|
+
let logits = self.model.forward(&input, start_pos)?;
|
|
160
|
+
let logits = logits.squeeze(0)?;
|
|
161
|
+
|
|
162
|
+
// Handle different output shapes
|
|
163
|
+
let logits = if logits.dims().len() == 2 {
|
|
164
|
+
let seq_len = logits.dim(0)?;
|
|
165
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
166
|
+
} else {
|
|
167
|
+
logits
|
|
168
|
+
};
|
|
169
|
+
|
|
170
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
171
|
+
|
|
172
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
173
|
+
|
|
174
|
+
all_tokens.push(next_token);
|
|
175
|
+
|
|
176
|
+
// Stream callback
|
|
177
|
+
if let Some(ref mut cb) = callback {
|
|
178
|
+
if config.debug_tokens {
|
|
179
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
180
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
181
|
+
} else {
|
|
182
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
183
|
+
cb(&decoded_text);
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
// Check stop conditions
|
|
188
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
189
|
+
break;
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
// Check if constraint is satisfied (early stopping)
|
|
193
|
+
if config.stop_on_constraint_satisfaction {
|
|
194
|
+
let satisfied = if config.stop_on_match {
|
|
195
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
196
|
+
} else {
|
|
197
|
+
text_gen.is_constraint_satisfied()
|
|
198
|
+
};
|
|
199
|
+
if satisfied {
|
|
200
|
+
break;
|
|
201
|
+
}
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
// Check stop sequences
|
|
205
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
206
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
207
|
+
break;
|
|
208
|
+
}
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
Ok(if config.include_prompt {
|
|
212
|
+
all_tokens
|
|
213
|
+
} else {
|
|
214
|
+
all_tokens[start_gen..].to_vec()
|
|
215
|
+
})
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
impl TextGenerator for Qwen3 {
|
|
220
|
+
fn generate(
|
|
221
|
+
&mut self,
|
|
222
|
+
prompt: &str,
|
|
223
|
+
config: &GenerationConfig,
|
|
224
|
+
) -> CandleResult<String> {
|
|
225
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
226
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
227
|
+
|
|
228
|
+
if config.debug_tokens {
|
|
229
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
230
|
+
} else {
|
|
231
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
fn generate_stream(
|
|
236
|
+
&mut self,
|
|
237
|
+
prompt: &str,
|
|
238
|
+
config: &GenerationConfig,
|
|
239
|
+
mut callback: impl FnMut(&str),
|
|
240
|
+
) -> CandleResult<String> {
|
|
241
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
242
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
243
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
fn model_name(&self) -> &str {
|
|
247
|
+
&self.model_id
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
fn device(&self) -> &Device {
|
|
251
|
+
&self.device
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
fn clear_cache(&mut self) {
|
|
255
|
+
self.clear_kv_cache();
|
|
256
|
+
}
|
|
257
|
+
}
|