red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,734 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_core::quantized::gguf_file;
|
|
3
|
+
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
|
|
4
|
+
use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
|
|
5
|
+
use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
|
|
6
|
+
use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
|
|
7
|
+
use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
|
|
8
|
+
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Model;
|
|
9
|
+
use candle_transformers::models::quantized_glm4::ModelWeights as QuantizedGlm4Model;
|
|
10
|
+
use hf_hub::api::tokio::{Api, ApiRepo};
|
|
11
|
+
use tokenizers::Tokenizer;
|
|
12
|
+
use std::io::Seek;
|
|
13
|
+
|
|
14
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
15
|
+
|
|
16
|
+
/// Unified GGUF model that can load any GGUF file and detect the architecture
|
|
17
|
+
pub struct QuantizedGGUF {
|
|
18
|
+
model: ModelType,
|
|
19
|
+
tokenizer: TokenizerWrapper,
|
|
20
|
+
device: Device,
|
|
21
|
+
model_id: String,
|
|
22
|
+
eos_token_id: u32,
|
|
23
|
+
pub architecture: String,
|
|
24
|
+
_chat_template: Option<String>,
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
enum ModelType {
|
|
28
|
+
Llama(QuantizedLlamaModel),
|
|
29
|
+
Gemma(QuantizedGemmaModel),
|
|
30
|
+
Qwen(QuantizedQwenModel),
|
|
31
|
+
Qwen3(QuantizedQwen3Model),
|
|
32
|
+
Phi(QuantizedPhiModel),
|
|
33
|
+
Phi3(QuantizedPhi3Model),
|
|
34
|
+
Glm4(QuantizedGlm4Model),
|
|
35
|
+
// Mistral uses Llama loader due to tensor naming compatibility
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
impl QuantizedGGUF {
|
|
39
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
40
|
+
self.eos_token_id
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
/// Get the tokenizer
|
|
44
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
45
|
+
&self.tokenizer
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
/// Load a quantized model from a GGUF file
|
|
49
|
+
pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
50
|
+
// Check if user specified an exact GGUF filename
|
|
51
|
+
let (actual_model_id, gguf_file) = if let Some(pos) = model_id.find('@') {
|
|
52
|
+
let (id, filename) = model_id.split_at(pos);
|
|
53
|
+
(id, Some(&filename[1..]))
|
|
54
|
+
} else {
|
|
55
|
+
(model_id, None)
|
|
56
|
+
};
|
|
57
|
+
|
|
58
|
+
let api = Api::new()
|
|
59
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
60
|
+
|
|
61
|
+
let repo = api.model(actual_model_id.to_string());
|
|
62
|
+
|
|
63
|
+
// Download GGUF file
|
|
64
|
+
let gguf_filename = if let Some(filename) = gguf_file {
|
|
65
|
+
// User specified exact filename
|
|
66
|
+
repo.get(filename).await
|
|
67
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download GGUF file '{}': {}", filename, e)))?
|
|
68
|
+
.to_string_lossy().to_string()
|
|
69
|
+
} else {
|
|
70
|
+
// Let Ruby handle the search, for now just try a common name
|
|
71
|
+
return Err(candle_core::Error::Msg(
|
|
72
|
+
"Please specify a GGUF filename using gguf_file parameter".to_string()
|
|
73
|
+
));
|
|
74
|
+
};
|
|
75
|
+
|
|
76
|
+
// Read GGUF metadata to determine architecture
|
|
77
|
+
let mut file = std::fs::File::open(&gguf_filename)?;
|
|
78
|
+
let content = gguf_file::Content::read(&mut file)?;
|
|
79
|
+
|
|
80
|
+
// Detect architecture from metadata
|
|
81
|
+
let architecture = Self::detect_architecture(&content, actual_model_id)?;
|
|
82
|
+
|
|
83
|
+
// For Gemma 3 models, we might need to adjust the architecture
|
|
84
|
+
let architecture = if actual_model_id.contains("gemma-3") || actual_model_id.contains("gemma3") {
|
|
85
|
+
"gemma3".to_string()
|
|
86
|
+
} else {
|
|
87
|
+
architecture
|
|
88
|
+
};
|
|
89
|
+
|
|
90
|
+
// Download tokenizer - either from specified source or with fallback
|
|
91
|
+
let tokenizer_filename = if let Some(source) = tokenizer_source {
|
|
92
|
+
Self::download_tokenizer_from_source(&api, source).await?
|
|
93
|
+
} else {
|
|
94
|
+
Self::download_tokenizer(&api, &repo, actual_model_id, &architecture).await?
|
|
95
|
+
};
|
|
96
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)
|
|
97
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
|
|
98
|
+
|
|
99
|
+
// Determine EOS token based on architecture and model
|
|
100
|
+
let eos_token_id = Self::determine_eos_token(&tokenizer, &architecture, actual_model_id);
|
|
101
|
+
|
|
102
|
+
// Load the appropriate model based on architecture
|
|
103
|
+
file.seek(std::io::SeekFrom::Start(0))?;
|
|
104
|
+
let content = gguf_file::Content::read(&mut file)?;
|
|
105
|
+
|
|
106
|
+
let model = match architecture.as_str() {
|
|
107
|
+
"llama" | "mistral" => {
|
|
108
|
+
// Both use the same GGUF format with llama.cpp tensor names
|
|
109
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
|
110
|
+
ModelType::Llama(model)
|
|
111
|
+
}
|
|
112
|
+
"qwen" | "qwen2" | "qwen3" => {
|
|
113
|
+
// Try different loaders based on what metadata is available
|
|
114
|
+
if content.metadata.contains_key("llama.attention.head_count") {
|
|
115
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
|
116
|
+
ModelType::Llama(model)
|
|
117
|
+
} else if content.metadata.contains_key("qwen2.attention.head_count") {
|
|
118
|
+
let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
|
|
119
|
+
ModelType::Qwen(model)
|
|
120
|
+
} else if content.metadata.contains_key("qwen3.attention.head_count") {
|
|
121
|
+
let model = QuantizedQwen3Model::from_gguf(content, &mut file, &device)?;
|
|
122
|
+
ModelType::Qwen3(model)
|
|
123
|
+
} else {
|
|
124
|
+
// Last resort: try llama loader anyway, as it's the most common
|
|
125
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
|
126
|
+
ModelType::Llama(model)
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
"gemma" | "gemma2" | "gemma3" => {
|
|
130
|
+
// Try Gemma-specific loader first, fall back to Llama if it fails
|
|
131
|
+
match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
|
|
132
|
+
Ok(model) => ModelType::Gemma(model),
|
|
133
|
+
Err(e) if e.to_string().contains("gemma3.attention.head_count") => {
|
|
134
|
+
// This might be an older Gemma GGUF that uses llama format
|
|
135
|
+
// Note: Some Gemma GGUF files may not be compatible
|
|
136
|
+
file.seek(std::io::SeekFrom::Start(0))?;
|
|
137
|
+
let content = gguf_file::Content::read(&mut file)?;
|
|
138
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
|
139
|
+
ModelType::Llama(model)
|
|
140
|
+
}
|
|
141
|
+
Err(e) => return Err(e),
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
"phi" | "phi2" => {
|
|
145
|
+
let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
|
|
146
|
+
ModelType::Phi(model)
|
|
147
|
+
}
|
|
148
|
+
"phi3" => {
|
|
149
|
+
// QuantizedPhi3Model requires an additional `approx` parameter
|
|
150
|
+
// Setting to false to avoid performance issues without flash-attn
|
|
151
|
+
let approx = false;
|
|
152
|
+
let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
|
|
153
|
+
ModelType::Phi3(model)
|
|
154
|
+
}
|
|
155
|
+
"glm4" => {
|
|
156
|
+
if content.metadata.contains_key("glm4.attention.head_count") {
|
|
157
|
+
let model = QuantizedGlm4Model::from_gguf(content, &mut file, &device, DType::F32)?;
|
|
158
|
+
ModelType::Glm4(model)
|
|
159
|
+
} else {
|
|
160
|
+
return Err(candle_core::Error::Msg(
|
|
161
|
+
"GLM-4 GGUF file does not contain expected glm4.* metadata keys".to_string()
|
|
162
|
+
));
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
_ => {
|
|
166
|
+
return Err(candle_core::Error::Msg(format!(
|
|
167
|
+
"Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3, glm4",
|
|
168
|
+
architecture
|
|
169
|
+
)));
|
|
170
|
+
}
|
|
171
|
+
};
|
|
172
|
+
|
|
173
|
+
// Detect chat template (for now, use defaults based on architecture)
|
|
174
|
+
let chat_template = Self::detect_chat_template(&tokenizer, &architecture, actual_model_id);
|
|
175
|
+
|
|
176
|
+
Ok(Self {
|
|
177
|
+
model,
|
|
178
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
179
|
+
device,
|
|
180
|
+
model_id: actual_model_id.to_string(),
|
|
181
|
+
eos_token_id,
|
|
182
|
+
architecture: architecture.clone(),
|
|
183
|
+
_chat_template: chat_template,
|
|
184
|
+
})
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
/// Detect architecture from GGUF metadata or model name
|
|
188
|
+
fn detect_architecture(content: &gguf_file::Content, model_id: &str) -> CandleResult<String> {
|
|
189
|
+
// First try to get from metadata
|
|
190
|
+
if let Some(gguf_file::Value::String(arch)) = content.metadata.get("general.architecture") {
|
|
191
|
+
return Ok(arch.clone());
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
// Fallback to model name detection
|
|
195
|
+
let model_lower = model_id.to_lowercase();
|
|
196
|
+
if model_lower.contains("llama") || model_lower.contains("tinyllama") {
|
|
197
|
+
Ok("llama".to_string())
|
|
198
|
+
} else if model_lower.contains("mistral") {
|
|
199
|
+
Ok("mistral".to_string())
|
|
200
|
+
} else if model_lower.contains("gemma") {
|
|
201
|
+
Ok("gemma".to_string())
|
|
202
|
+
} else if model_lower.contains("qwen") {
|
|
203
|
+
Ok("qwen".to_string())
|
|
204
|
+
} else if model_lower.contains("glm") {
|
|
205
|
+
Ok("glm4".to_string())
|
|
206
|
+
} else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
|
|
207
|
+
Ok("phi3".to_string())
|
|
208
|
+
} else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
|
|
209
|
+
Ok("phi2".to_string())
|
|
210
|
+
} else if model_lower.contains("phi") {
|
|
211
|
+
Ok("phi".to_string())
|
|
212
|
+
} else {
|
|
213
|
+
Err(candle_core::Error::Msg(
|
|
214
|
+
"Could not determine model architecture from metadata or name".to_string()
|
|
215
|
+
))
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
/// Download tokenizer from a specific source
|
|
220
|
+
async fn download_tokenizer_from_source(
|
|
221
|
+
api: &Api,
|
|
222
|
+
source: &str
|
|
223
|
+
) -> CandleResult<std::path::PathBuf> {
|
|
224
|
+
// Check if it's a local file path
|
|
225
|
+
if source.ends_with(".json") && std::path::Path::new(source).exists() {
|
|
226
|
+
return Ok(std::path::PathBuf::from(source));
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
// Otherwise treat it as a HuggingFace repo
|
|
230
|
+
let repo = api.model(source.to_string());
|
|
231
|
+
|
|
232
|
+
// Try tokenizer.json first
|
|
233
|
+
if let Ok(path) = repo.get("tokenizer.json").await {
|
|
234
|
+
return Ok(path);
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
// Try tokenizer.model (for models that use sentencepiece)
|
|
238
|
+
if let Ok(path) = repo.get("tokenizer.model").await {
|
|
239
|
+
return Ok(path);
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
Err(candle_core::Error::Msg(format!(
|
|
243
|
+
"Failed to find tokenizer in specified source: {}",
|
|
244
|
+
source
|
|
245
|
+
)))
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
/// Download tokenizer with architecture-specific fallbacks
|
|
249
|
+
async fn download_tokenizer(
|
|
250
|
+
_api: &Api,
|
|
251
|
+
repo: &ApiRepo,
|
|
252
|
+
model_id: &str,
|
|
253
|
+
_architecture: &str
|
|
254
|
+
) -> CandleResult<std::path::PathBuf> {
|
|
255
|
+
// First try to get tokenizer.json from the GGUF repo
|
|
256
|
+
if let Ok(path) = repo.get("tokenizer.json").await {
|
|
257
|
+
return Ok(path);
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
// Try tokenizer.model (for models that use sentencepiece)
|
|
261
|
+
if let Ok(path) = repo.get("tokenizer.model").await {
|
|
262
|
+
return Ok(path);
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
// If no tokenizer found in GGUF repo, return error
|
|
266
|
+
// Ruby will handle the fallback logic
|
|
267
|
+
Err(candle_core::Error::Msg(format!(
|
|
268
|
+
"No tokenizer found in GGUF repository {}. Please specify a tokenizer source.",
|
|
269
|
+
model_id
|
|
270
|
+
)))
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
/// Determine EOS token based on architecture and model
|
|
274
|
+
fn determine_eos_token(tokenizer: &Tokenizer, architecture: &str, model_id: &str) -> u32 {
|
|
275
|
+
let vocab = tokenizer.get_vocab(true);
|
|
276
|
+
|
|
277
|
+
match architecture {
|
|
278
|
+
"llama" | "mistral" => {
|
|
279
|
+
// Check if it's Llama 3
|
|
280
|
+
if model_id.contains("Llama-3") || model_id.contains("llama-3") {
|
|
281
|
+
vocab.get("<|eot_id|>")
|
|
282
|
+
.or_else(|| vocab.get("<|end_of_text|>"))
|
|
283
|
+
.copied()
|
|
284
|
+
.unwrap_or(128009)
|
|
285
|
+
} else {
|
|
286
|
+
// Llama 2 and Mistral
|
|
287
|
+
vocab.get("</s>")
|
|
288
|
+
.copied()
|
|
289
|
+
.unwrap_or(2)
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
"gemma" => {
|
|
293
|
+
vocab.get("<eos>")
|
|
294
|
+
.or_else(|| vocab.get("<end_of_turn>"))
|
|
295
|
+
.copied()
|
|
296
|
+
.unwrap_or(1)
|
|
297
|
+
}
|
|
298
|
+
"qwen" | "qwen2" | "qwen3" => {
|
|
299
|
+
vocab.get("<|endoftext|>")
|
|
300
|
+
.or_else(|| vocab.get("<|im_end|>"))
|
|
301
|
+
.or_else(|| vocab.get("</s>"))
|
|
302
|
+
.copied()
|
|
303
|
+
.unwrap_or(151643) // Default Qwen3 EOS token
|
|
304
|
+
}
|
|
305
|
+
"phi" | "phi2" | "phi3" => {
|
|
306
|
+
vocab.get("<|endoftext|>")
|
|
307
|
+
.or_else(|| vocab.get("<|end|>"))
|
|
308
|
+
.or_else(|| vocab.get("</s>"))
|
|
309
|
+
.copied()
|
|
310
|
+
.unwrap_or(50256) // Default GPT-2 style EOS token
|
|
311
|
+
}
|
|
312
|
+
"glm4" => {
|
|
313
|
+
vocab.get("<|endoftext|>")
|
|
314
|
+
.or_else(|| vocab.get("<|user|>"))
|
|
315
|
+
.or_else(|| vocab.get("</s>"))
|
|
316
|
+
.copied()
|
|
317
|
+
.unwrap_or(151329)
|
|
318
|
+
}
|
|
319
|
+
_ => 2, // Default
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
/// Detect chat template based on model
|
|
324
|
+
fn detect_chat_template(_tokenizer: &Tokenizer, _architecture: &str, _model_id: &str) -> Option<String> {
|
|
325
|
+
// For now, return None and handle templates in apply_chat_template
|
|
326
|
+
// In the future, this could read from tokenizer config
|
|
327
|
+
None
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
/// Apply chat template based on detected architecture
|
|
331
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
332
|
+
// Check model name since Mistral GGUF reports as llama architecture
|
|
333
|
+
let model_lower = self.model_id.to_lowercase();
|
|
334
|
+
|
|
335
|
+
if model_lower.contains("tinyllama") {
|
|
336
|
+
self.apply_chatml_template(messages)
|
|
337
|
+
} else if model_lower.contains("mistral") {
|
|
338
|
+
self.apply_mistral_template(messages)
|
|
339
|
+
} else if model_lower.contains("gemma") {
|
|
340
|
+
// Always use Gemma template for Gemma models, regardless of loader used
|
|
341
|
+
self.apply_gemma_template(messages)
|
|
342
|
+
} else if model_lower.contains("qwen") {
|
|
343
|
+
self.apply_qwen_template(messages)
|
|
344
|
+
} else if model_lower.contains("glm") {
|
|
345
|
+
self.apply_glm4_template(messages)
|
|
346
|
+
} else if model_lower.contains("phi") {
|
|
347
|
+
self.apply_phi_template(messages)
|
|
348
|
+
} else {
|
|
349
|
+
match self.architecture.as_str() {
|
|
350
|
+
"llama" => {
|
|
351
|
+
if self.model_id.contains("Llama-3") || self.model_id.contains("llama-3") {
|
|
352
|
+
self.apply_llama3_template(messages)
|
|
353
|
+
} else {
|
|
354
|
+
self.apply_llama2_template(messages)
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
"gemma" => {
|
|
358
|
+
self.apply_gemma_template(messages)
|
|
359
|
+
}
|
|
360
|
+
"qwen" | "qwen2" | "qwen3" => {
|
|
361
|
+
self.apply_qwen_template(messages)
|
|
362
|
+
}
|
|
363
|
+
"phi" | "phi2" | "phi3" => {
|
|
364
|
+
self.apply_phi_template(messages)
|
|
365
|
+
}
|
|
366
|
+
"glm4" => {
|
|
367
|
+
self.apply_glm4_template(messages)
|
|
368
|
+
}
|
|
369
|
+
_ => Ok(self.apply_generic_template(messages))
|
|
370
|
+
}
|
|
371
|
+
}
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
375
|
+
let mut prompt = String::new();
|
|
376
|
+
let mut system_message = String::new();
|
|
377
|
+
|
|
378
|
+
for (i, message) in messages.iter().enumerate() {
|
|
379
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
380
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
381
|
+
|
|
382
|
+
match role {
|
|
383
|
+
"system" => {
|
|
384
|
+
system_message = content.to_string();
|
|
385
|
+
}
|
|
386
|
+
"user" => {
|
|
387
|
+
if i == 1 || (i == 0 && system_message.is_empty()) {
|
|
388
|
+
if !system_message.is_empty() {
|
|
389
|
+
prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
|
|
390
|
+
} else {
|
|
391
|
+
prompt.push_str(&format!("[INST] {} [/INST]", content));
|
|
392
|
+
}
|
|
393
|
+
} else {
|
|
394
|
+
prompt.push_str(&format!(" [INST] {} [/INST]", content));
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
"assistant" => {
|
|
398
|
+
prompt.push_str(&format!(" {} </s>", content));
|
|
399
|
+
}
|
|
400
|
+
_ => {}
|
|
401
|
+
}
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
Ok(prompt)
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
408
|
+
let mut prompt = String::new();
|
|
409
|
+
// BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
|
|
410
|
+
|
|
411
|
+
for message in messages {
|
|
412
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
413
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
414
|
+
prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
|
|
418
|
+
Ok(prompt)
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
fn apply_mistral_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
422
|
+
let mut prompt = String::new();
|
|
423
|
+
|
|
424
|
+
for message in messages {
|
|
425
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
426
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
427
|
+
|
|
428
|
+
match role {
|
|
429
|
+
"user" => prompt.push_str(&format!("[INST] {} [/INST]", content)),
|
|
430
|
+
"assistant" => prompt.push_str(&format!(" {}</s>", content)),
|
|
431
|
+
"system" => prompt.push_str(&format!("[INST] {} [/INST]\n", content)),
|
|
432
|
+
_ => {}
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
Ok(prompt)
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
fn apply_gemma_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
440
|
+
let mut prompt = String::new();
|
|
441
|
+
|
|
442
|
+
for message in messages {
|
|
443
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
444
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
445
|
+
|
|
446
|
+
match role {
|
|
447
|
+
"system" => {
|
|
448
|
+
prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
|
|
449
|
+
}
|
|
450
|
+
"user" => {
|
|
451
|
+
if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
|
|
452
|
+
prompt.push_str("<start_of_turn>user\n");
|
|
453
|
+
}
|
|
454
|
+
prompt.push_str(&format!("{}<end_of_turn>\n", content));
|
|
455
|
+
}
|
|
456
|
+
"assistant" | "model" => {
|
|
457
|
+
prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
|
|
458
|
+
}
|
|
459
|
+
_ => {}
|
|
460
|
+
}
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
prompt.push_str("<start_of_turn>model\n");
|
|
464
|
+
Ok(prompt)
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
468
|
+
let mut prompt = String::new();
|
|
469
|
+
|
|
470
|
+
for message in messages {
|
|
471
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
472
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
473
|
+
|
|
474
|
+
match role {
|
|
475
|
+
"system" => {
|
|
476
|
+
prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
|
|
477
|
+
}
|
|
478
|
+
"user" => {
|
|
479
|
+
prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
|
|
480
|
+
}
|
|
481
|
+
"assistant" => {
|
|
482
|
+
prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
|
|
483
|
+
}
|
|
484
|
+
"tool" => {
|
|
485
|
+
prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
|
|
486
|
+
}
|
|
487
|
+
_ => {}
|
|
488
|
+
}
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
// Add generation prompt
|
|
492
|
+
prompt.push_str("<|im_start|>assistant\n");
|
|
493
|
+
Ok(prompt)
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
497
|
+
let mut prompt = String::new();
|
|
498
|
+
|
|
499
|
+
// Check if it's Phi-3 (newer format) or Phi-2/Phi (simpler format)
|
|
500
|
+
let is_phi3 = self.model_id.contains("phi-3") || self.model_id.contains("Phi-3") || self.architecture == "phi3";
|
|
501
|
+
|
|
502
|
+
if is_phi3 {
|
|
503
|
+
// Phi-3 format
|
|
504
|
+
for message in messages {
|
|
505
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
506
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
507
|
+
|
|
508
|
+
match role {
|
|
509
|
+
"system" => {
|
|
510
|
+
prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
|
|
511
|
+
}
|
|
512
|
+
"user" => {
|
|
513
|
+
prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
|
|
514
|
+
}
|
|
515
|
+
"assistant" => {
|
|
516
|
+
prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
|
|
517
|
+
}
|
|
518
|
+
_ => {}
|
|
519
|
+
}
|
|
520
|
+
}
|
|
521
|
+
prompt.push_str("<|assistant|>\n");
|
|
522
|
+
} else {
|
|
523
|
+
// Phi-2 format
|
|
524
|
+
for message in messages {
|
|
525
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
526
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
527
|
+
|
|
528
|
+
match role {
|
|
529
|
+
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
|
530
|
+
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
|
531
|
+
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
|
532
|
+
_ => {}
|
|
533
|
+
}
|
|
534
|
+
}
|
|
535
|
+
prompt.push_str("Assistant: ");
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
Ok(prompt)
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
fn apply_glm4_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
542
|
+
let mut prompt = String::new();
|
|
543
|
+
|
|
544
|
+
prompt.push_str("[gMASK]<sop>");
|
|
545
|
+
|
|
546
|
+
for message in messages {
|
|
547
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
548
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
549
|
+
|
|
550
|
+
match role {
|
|
551
|
+
"system" => {
|
|
552
|
+
prompt.push_str(&format!("<|system|>\n{}", content));
|
|
553
|
+
}
|
|
554
|
+
"user" => {
|
|
555
|
+
prompt.push_str(&format!("<|user|>\n{}", content));
|
|
556
|
+
}
|
|
557
|
+
"assistant" => {
|
|
558
|
+
prompt.push_str(&format!("<|assistant|>\n{}", content));
|
|
559
|
+
}
|
|
560
|
+
_ => {}
|
|
561
|
+
}
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
prompt.push_str("<|assistant|>\n");
|
|
565
|
+
Ok(prompt)
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
fn apply_chatml_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
569
|
+
let mut prompt = String::new();
|
|
570
|
+
|
|
571
|
+
for message in messages {
|
|
572
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
573
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
574
|
+
|
|
575
|
+
prompt.push_str(&format!("<|{}|>\n{}</s>\n", role, content));
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
prompt.push_str("<|assistant|>");
|
|
579
|
+
Ok(prompt)
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
|
|
583
|
+
let mut prompt = String::new();
|
|
584
|
+
|
|
585
|
+
for message in messages {
|
|
586
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
587
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
588
|
+
prompt.push_str(&format!("{}: {}\n", role, content));
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
prompt.push_str("assistant: ");
|
|
592
|
+
prompt
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
/// Clear the KV cache between generations
|
|
596
|
+
pub fn clear_kv_cache(&mut self) {
|
|
597
|
+
// Only some quantized models expose clear_kv_cache
|
|
598
|
+
if let ModelType::Qwen3(model) = &mut self.model {
|
|
599
|
+
model.clear_kv_cache();
|
|
600
|
+
}
|
|
601
|
+
// Other quantized models (Llama, Gemma, Qwen2, Phi, Phi3) don't expose
|
|
602
|
+
// clear_kv_cache in candle-transformers. For these, the generate() method
|
|
603
|
+
// in Ruby calls clear_cache which recreates the model if needed.
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
fn generate_tokens(
|
|
607
|
+
&mut self,
|
|
608
|
+
prompt_tokens: Vec<u32>,
|
|
609
|
+
config: &GenerationConfig,
|
|
610
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
611
|
+
) -> CandleResult<Vec<u32>> {
|
|
612
|
+
let mut text_gen = TextGeneration::new(config);
|
|
613
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
614
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
615
|
+
|
|
616
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
617
|
+
let start_gen = all_tokens.len();
|
|
618
|
+
|
|
619
|
+
for index in 0..config.max_length {
|
|
620
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
621
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
622
|
+
let ctxt = &all_tokens[start_pos..];
|
|
623
|
+
|
|
624
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
625
|
+
let input = input.contiguous()?;
|
|
626
|
+
|
|
627
|
+
let logits = match &mut self.model {
|
|
628
|
+
ModelType::Llama(model) => model.forward(&input, start_pos)?,
|
|
629
|
+
ModelType::Gemma(model) => model.forward(&input, start_pos)?,
|
|
630
|
+
ModelType::Qwen(model) => model.forward(&input, start_pos)?,
|
|
631
|
+
ModelType::Qwen3(model) => model.forward(&input, start_pos)?,
|
|
632
|
+
ModelType::Phi(model) => model.forward(&input, start_pos)?,
|
|
633
|
+
ModelType::Phi3(model) => model.forward(&input, start_pos)?,
|
|
634
|
+
ModelType::Glm4(model) => model.forward(&input, start_pos)?,
|
|
635
|
+
};
|
|
636
|
+
|
|
637
|
+
let logits = logits.squeeze(0)?;
|
|
638
|
+
let logits = if logits.dims().len() == 2 {
|
|
639
|
+
let seq_len = logits.dim(0)?;
|
|
640
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
641
|
+
} else {
|
|
642
|
+
logits
|
|
643
|
+
};
|
|
644
|
+
|
|
645
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
646
|
+
|
|
647
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
648
|
+
|
|
649
|
+
all_tokens.push(next_token);
|
|
650
|
+
|
|
651
|
+
// Stream callback
|
|
652
|
+
if let Some(ref mut cb) = callback {
|
|
653
|
+
if config.debug_tokens {
|
|
654
|
+
// In debug mode, only show debug tokens
|
|
655
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
656
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
657
|
+
} else {
|
|
658
|
+
// Normal mode: use incremental decoding for proper text
|
|
659
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
660
|
+
cb(&decoded_text);
|
|
661
|
+
}
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
// Check stop conditions
|
|
665
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
666
|
+
break;
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
// Check if constraint is satisfied (early stopping)
|
|
670
|
+
if config.stop_on_constraint_satisfaction {
|
|
671
|
+
let satisfied = if config.stop_on_match {
|
|
672
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
673
|
+
} else {
|
|
674
|
+
text_gen.is_constraint_satisfied()
|
|
675
|
+
};
|
|
676
|
+
if satisfied {
|
|
677
|
+
break;
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
// Check stop sequences
|
|
682
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
683
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
684
|
+
break;
|
|
685
|
+
}
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
Ok(if config.include_prompt {
|
|
689
|
+
all_tokens
|
|
690
|
+
} else {
|
|
691
|
+
all_tokens[start_gen..].to_vec()
|
|
692
|
+
})
|
|
693
|
+
}
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
impl TextGenerator for QuantizedGGUF {
|
|
697
|
+
fn generate(
|
|
698
|
+
&mut self,
|
|
699
|
+
prompt: &str,
|
|
700
|
+
config: &GenerationConfig,
|
|
701
|
+
) -> CandleResult<String> {
|
|
702
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
703
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
704
|
+
|
|
705
|
+
if config.debug_tokens {
|
|
706
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
707
|
+
} else {
|
|
708
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
709
|
+
}
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
fn generate_stream(
|
|
713
|
+
&mut self,
|
|
714
|
+
prompt: &str,
|
|
715
|
+
config: &GenerationConfig,
|
|
716
|
+
mut callback: impl FnMut(&str),
|
|
717
|
+
) -> CandleResult<String> {
|
|
718
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
719
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
720
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
fn model_name(&self) -> &str {
|
|
724
|
+
&self.model_id
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
fn device(&self) -> &Device {
|
|
728
|
+
&self.device
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
fn clear_cache(&mut self) {
|
|
732
|
+
self.clear_kv_cache();
|
|
733
|
+
}
|
|
734
|
+
}
|