red-candle 1.8.0.pre3-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,734 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_core::quantized::gguf_file;
3
+ use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
4
+ use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
5
+ use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
6
+ use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
7
+ use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
8
+ use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Model;
9
+ use candle_transformers::models::quantized_glm4::ModelWeights as QuantizedGlm4Model;
10
+ use hf_hub::api::tokio::{Api, ApiRepo};
11
+ use tokenizers::Tokenizer;
12
+ use std::io::Seek;
13
+
14
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
15
+
16
+ /// Unified GGUF model that can load any GGUF file and detect the architecture
17
+ pub struct QuantizedGGUF {
18
+ model: ModelType,
19
+ tokenizer: TokenizerWrapper,
20
+ device: Device,
21
+ model_id: String,
22
+ eos_token_id: u32,
23
+ pub architecture: String,
24
+ _chat_template: Option<String>,
25
+ }
26
+
27
+ enum ModelType {
28
+ Llama(QuantizedLlamaModel),
29
+ Gemma(QuantizedGemmaModel),
30
+ Qwen(QuantizedQwenModel),
31
+ Qwen3(QuantizedQwen3Model),
32
+ Phi(QuantizedPhiModel),
33
+ Phi3(QuantizedPhi3Model),
34
+ Glm4(QuantizedGlm4Model),
35
+ // Mistral uses Llama loader due to tensor naming compatibility
36
+ }
37
+
38
+ impl QuantizedGGUF {
39
+ pub fn eos_token_id(&self) -> u32 {
40
+ self.eos_token_id
41
+ }
42
+
43
+ /// Get the tokenizer
44
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
45
+ &self.tokenizer
46
+ }
47
+
48
+ /// Load a quantized model from a GGUF file
49
+ pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
50
+ // Check if user specified an exact GGUF filename
51
+ let (actual_model_id, gguf_file) = if let Some(pos) = model_id.find('@') {
52
+ let (id, filename) = model_id.split_at(pos);
53
+ (id, Some(&filename[1..]))
54
+ } else {
55
+ (model_id, None)
56
+ };
57
+
58
+ let api = Api::new()
59
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
60
+
61
+ let repo = api.model(actual_model_id.to_string());
62
+
63
+ // Download GGUF file
64
+ let gguf_filename = if let Some(filename) = gguf_file {
65
+ // User specified exact filename
66
+ repo.get(filename).await
67
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download GGUF file '{}': {}", filename, e)))?
68
+ .to_string_lossy().to_string()
69
+ } else {
70
+ // Let Ruby handle the search, for now just try a common name
71
+ return Err(candle_core::Error::Msg(
72
+ "Please specify a GGUF filename using gguf_file parameter".to_string()
73
+ ));
74
+ };
75
+
76
+ // Read GGUF metadata to determine architecture
77
+ let mut file = std::fs::File::open(&gguf_filename)?;
78
+ let content = gguf_file::Content::read(&mut file)?;
79
+
80
+ // Detect architecture from metadata
81
+ let architecture = Self::detect_architecture(&content, actual_model_id)?;
82
+
83
+ // For Gemma 3 models, we might need to adjust the architecture
84
+ let architecture = if actual_model_id.contains("gemma-3") || actual_model_id.contains("gemma3") {
85
+ "gemma3".to_string()
86
+ } else {
87
+ architecture
88
+ };
89
+
90
+ // Download tokenizer - either from specified source or with fallback
91
+ let tokenizer_filename = if let Some(source) = tokenizer_source {
92
+ Self::download_tokenizer_from_source(&api, source).await?
93
+ } else {
94
+ Self::download_tokenizer(&api, &repo, actual_model_id, &architecture).await?
95
+ };
96
+ let tokenizer = Tokenizer::from_file(tokenizer_filename)
97
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
98
+
99
+ // Determine EOS token based on architecture and model
100
+ let eos_token_id = Self::determine_eos_token(&tokenizer, &architecture, actual_model_id);
101
+
102
+ // Load the appropriate model based on architecture
103
+ file.seek(std::io::SeekFrom::Start(0))?;
104
+ let content = gguf_file::Content::read(&mut file)?;
105
+
106
+ let model = match architecture.as_str() {
107
+ "llama" | "mistral" => {
108
+ // Both use the same GGUF format with llama.cpp tensor names
109
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
110
+ ModelType::Llama(model)
111
+ }
112
+ "qwen" | "qwen2" | "qwen3" => {
113
+ // Try different loaders based on what metadata is available
114
+ if content.metadata.contains_key("llama.attention.head_count") {
115
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
116
+ ModelType::Llama(model)
117
+ } else if content.metadata.contains_key("qwen2.attention.head_count") {
118
+ let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
119
+ ModelType::Qwen(model)
120
+ } else if content.metadata.contains_key("qwen3.attention.head_count") {
121
+ let model = QuantizedQwen3Model::from_gguf(content, &mut file, &device)?;
122
+ ModelType::Qwen3(model)
123
+ } else {
124
+ // Last resort: try llama loader anyway, as it's the most common
125
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
126
+ ModelType::Llama(model)
127
+ }
128
+ }
129
+ "gemma" | "gemma2" | "gemma3" => {
130
+ // Try Gemma-specific loader first, fall back to Llama if it fails
131
+ match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
132
+ Ok(model) => ModelType::Gemma(model),
133
+ Err(e) if e.to_string().contains("gemma3.attention.head_count") => {
134
+ // This might be an older Gemma GGUF that uses llama format
135
+ // Note: Some Gemma GGUF files may not be compatible
136
+ file.seek(std::io::SeekFrom::Start(0))?;
137
+ let content = gguf_file::Content::read(&mut file)?;
138
+ let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
139
+ ModelType::Llama(model)
140
+ }
141
+ Err(e) => return Err(e),
142
+ }
143
+ }
144
+ "phi" | "phi2" => {
145
+ let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
146
+ ModelType::Phi(model)
147
+ }
148
+ "phi3" => {
149
+ // QuantizedPhi3Model requires an additional `approx` parameter
150
+ // Setting to false to avoid performance issues without flash-attn
151
+ let approx = false;
152
+ let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
153
+ ModelType::Phi3(model)
154
+ }
155
+ "glm4" => {
156
+ if content.metadata.contains_key("glm4.attention.head_count") {
157
+ let model = QuantizedGlm4Model::from_gguf(content, &mut file, &device, DType::F32)?;
158
+ ModelType::Glm4(model)
159
+ } else {
160
+ return Err(candle_core::Error::Msg(
161
+ "GLM-4 GGUF file does not contain expected glm4.* metadata keys".to_string()
162
+ ));
163
+ }
164
+ }
165
+ _ => {
166
+ return Err(candle_core::Error::Msg(format!(
167
+ "Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3, glm4",
168
+ architecture
169
+ )));
170
+ }
171
+ };
172
+
173
+ // Detect chat template (for now, use defaults based on architecture)
174
+ let chat_template = Self::detect_chat_template(&tokenizer, &architecture, actual_model_id);
175
+
176
+ Ok(Self {
177
+ model,
178
+ tokenizer: TokenizerWrapper::new(tokenizer),
179
+ device,
180
+ model_id: actual_model_id.to_string(),
181
+ eos_token_id,
182
+ architecture: architecture.clone(),
183
+ _chat_template: chat_template,
184
+ })
185
+ }
186
+
187
+ /// Detect architecture from GGUF metadata or model name
188
+ fn detect_architecture(content: &gguf_file::Content, model_id: &str) -> CandleResult<String> {
189
+ // First try to get from metadata
190
+ if let Some(gguf_file::Value::String(arch)) = content.metadata.get("general.architecture") {
191
+ return Ok(arch.clone());
192
+ }
193
+
194
+ // Fallback to model name detection
195
+ let model_lower = model_id.to_lowercase();
196
+ if model_lower.contains("llama") || model_lower.contains("tinyllama") {
197
+ Ok("llama".to_string())
198
+ } else if model_lower.contains("mistral") {
199
+ Ok("mistral".to_string())
200
+ } else if model_lower.contains("gemma") {
201
+ Ok("gemma".to_string())
202
+ } else if model_lower.contains("qwen") {
203
+ Ok("qwen".to_string())
204
+ } else if model_lower.contains("glm") {
205
+ Ok("glm4".to_string())
206
+ } else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
207
+ Ok("phi3".to_string())
208
+ } else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
209
+ Ok("phi2".to_string())
210
+ } else if model_lower.contains("phi") {
211
+ Ok("phi".to_string())
212
+ } else {
213
+ Err(candle_core::Error::Msg(
214
+ "Could not determine model architecture from metadata or name".to_string()
215
+ ))
216
+ }
217
+ }
218
+
219
+ /// Download tokenizer from a specific source
220
+ async fn download_tokenizer_from_source(
221
+ api: &Api,
222
+ source: &str
223
+ ) -> CandleResult<std::path::PathBuf> {
224
+ // Check if it's a local file path
225
+ if source.ends_with(".json") && std::path::Path::new(source).exists() {
226
+ return Ok(std::path::PathBuf::from(source));
227
+ }
228
+
229
+ // Otherwise treat it as a HuggingFace repo
230
+ let repo = api.model(source.to_string());
231
+
232
+ // Try tokenizer.json first
233
+ if let Ok(path) = repo.get("tokenizer.json").await {
234
+ return Ok(path);
235
+ }
236
+
237
+ // Try tokenizer.model (for models that use sentencepiece)
238
+ if let Ok(path) = repo.get("tokenizer.model").await {
239
+ return Ok(path);
240
+ }
241
+
242
+ Err(candle_core::Error::Msg(format!(
243
+ "Failed to find tokenizer in specified source: {}",
244
+ source
245
+ )))
246
+ }
247
+
248
+ /// Download tokenizer with architecture-specific fallbacks
249
+ async fn download_tokenizer(
250
+ _api: &Api,
251
+ repo: &ApiRepo,
252
+ model_id: &str,
253
+ _architecture: &str
254
+ ) -> CandleResult<std::path::PathBuf> {
255
+ // First try to get tokenizer.json from the GGUF repo
256
+ if let Ok(path) = repo.get("tokenizer.json").await {
257
+ return Ok(path);
258
+ }
259
+
260
+ // Try tokenizer.model (for models that use sentencepiece)
261
+ if let Ok(path) = repo.get("tokenizer.model").await {
262
+ return Ok(path);
263
+ }
264
+
265
+ // If no tokenizer found in GGUF repo, return error
266
+ // Ruby will handle the fallback logic
267
+ Err(candle_core::Error::Msg(format!(
268
+ "No tokenizer found in GGUF repository {}. Please specify a tokenizer source.",
269
+ model_id
270
+ )))
271
+ }
272
+
273
+ /// Determine EOS token based on architecture and model
274
+ fn determine_eos_token(tokenizer: &Tokenizer, architecture: &str, model_id: &str) -> u32 {
275
+ let vocab = tokenizer.get_vocab(true);
276
+
277
+ match architecture {
278
+ "llama" | "mistral" => {
279
+ // Check if it's Llama 3
280
+ if model_id.contains("Llama-3") || model_id.contains("llama-3") {
281
+ vocab.get("<|eot_id|>")
282
+ .or_else(|| vocab.get("<|end_of_text|>"))
283
+ .copied()
284
+ .unwrap_or(128009)
285
+ } else {
286
+ // Llama 2 and Mistral
287
+ vocab.get("</s>")
288
+ .copied()
289
+ .unwrap_or(2)
290
+ }
291
+ }
292
+ "gemma" => {
293
+ vocab.get("<eos>")
294
+ .or_else(|| vocab.get("<end_of_turn>"))
295
+ .copied()
296
+ .unwrap_or(1)
297
+ }
298
+ "qwen" | "qwen2" | "qwen3" => {
299
+ vocab.get("<|endoftext|>")
300
+ .or_else(|| vocab.get("<|im_end|>"))
301
+ .or_else(|| vocab.get("</s>"))
302
+ .copied()
303
+ .unwrap_or(151643) // Default Qwen3 EOS token
304
+ }
305
+ "phi" | "phi2" | "phi3" => {
306
+ vocab.get("<|endoftext|>")
307
+ .or_else(|| vocab.get("<|end|>"))
308
+ .or_else(|| vocab.get("</s>"))
309
+ .copied()
310
+ .unwrap_or(50256) // Default GPT-2 style EOS token
311
+ }
312
+ "glm4" => {
313
+ vocab.get("<|endoftext|>")
314
+ .or_else(|| vocab.get("<|user|>"))
315
+ .or_else(|| vocab.get("</s>"))
316
+ .copied()
317
+ .unwrap_or(151329)
318
+ }
319
+ _ => 2, // Default
320
+ }
321
+ }
322
+
323
+ /// Detect chat template based on model
324
+ fn detect_chat_template(_tokenizer: &Tokenizer, _architecture: &str, _model_id: &str) -> Option<String> {
325
+ // For now, return None and handle templates in apply_chat_template
326
+ // In the future, this could read from tokenizer config
327
+ None
328
+ }
329
+
330
+ /// Apply chat template based on detected architecture
331
+ pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
332
+ // Check model name since Mistral GGUF reports as llama architecture
333
+ let model_lower = self.model_id.to_lowercase();
334
+
335
+ if model_lower.contains("tinyllama") {
336
+ self.apply_chatml_template(messages)
337
+ } else if model_lower.contains("mistral") {
338
+ self.apply_mistral_template(messages)
339
+ } else if model_lower.contains("gemma") {
340
+ // Always use Gemma template for Gemma models, regardless of loader used
341
+ self.apply_gemma_template(messages)
342
+ } else if model_lower.contains("qwen") {
343
+ self.apply_qwen_template(messages)
344
+ } else if model_lower.contains("glm") {
345
+ self.apply_glm4_template(messages)
346
+ } else if model_lower.contains("phi") {
347
+ self.apply_phi_template(messages)
348
+ } else {
349
+ match self.architecture.as_str() {
350
+ "llama" => {
351
+ if self.model_id.contains("Llama-3") || self.model_id.contains("llama-3") {
352
+ self.apply_llama3_template(messages)
353
+ } else {
354
+ self.apply_llama2_template(messages)
355
+ }
356
+ }
357
+ "gemma" => {
358
+ self.apply_gemma_template(messages)
359
+ }
360
+ "qwen" | "qwen2" | "qwen3" => {
361
+ self.apply_qwen_template(messages)
362
+ }
363
+ "phi" | "phi2" | "phi3" => {
364
+ self.apply_phi_template(messages)
365
+ }
366
+ "glm4" => {
367
+ self.apply_glm4_template(messages)
368
+ }
369
+ _ => Ok(self.apply_generic_template(messages))
370
+ }
371
+ }
372
+ }
373
+
374
+ fn apply_llama2_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
375
+ let mut prompt = String::new();
376
+ let mut system_message = String::new();
377
+
378
+ for (i, message) in messages.iter().enumerate() {
379
+ let role = message["role"].as_str().unwrap_or("");
380
+ let content = message["content"].as_str().unwrap_or("");
381
+
382
+ match role {
383
+ "system" => {
384
+ system_message = content.to_string();
385
+ }
386
+ "user" => {
387
+ if i == 1 || (i == 0 && system_message.is_empty()) {
388
+ if !system_message.is_empty() {
389
+ prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
390
+ } else {
391
+ prompt.push_str(&format!("[INST] {} [/INST]", content));
392
+ }
393
+ } else {
394
+ prompt.push_str(&format!(" [INST] {} [/INST]", content));
395
+ }
396
+ }
397
+ "assistant" => {
398
+ prompt.push_str(&format!(" {} </s>", content));
399
+ }
400
+ _ => {}
401
+ }
402
+ }
403
+
404
+ Ok(prompt)
405
+ }
406
+
407
+ fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
408
+ let mut prompt = String::new();
409
+ // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
410
+
411
+ for message in messages {
412
+ let role = message["role"].as_str().unwrap_or("");
413
+ let content = message["content"].as_str().unwrap_or("");
414
+ prompt.push_str(&format!("<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, content));
415
+ }
416
+
417
+ prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
418
+ Ok(prompt)
419
+ }
420
+
421
+ fn apply_mistral_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
422
+ let mut prompt = String::new();
423
+
424
+ for message in messages {
425
+ let role = message["role"].as_str().unwrap_or("");
426
+ let content = message["content"].as_str().unwrap_or("");
427
+
428
+ match role {
429
+ "user" => prompt.push_str(&format!("[INST] {} [/INST]", content)),
430
+ "assistant" => prompt.push_str(&format!(" {}</s>", content)),
431
+ "system" => prompt.push_str(&format!("[INST] {} [/INST]\n", content)),
432
+ _ => {}
433
+ }
434
+ }
435
+
436
+ Ok(prompt)
437
+ }
438
+
439
+ fn apply_gemma_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
440
+ let mut prompt = String::new();
441
+
442
+ for message in messages {
443
+ let role = message["role"].as_str().unwrap_or("");
444
+ let content = message["content"].as_str().unwrap_or("");
445
+
446
+ match role {
447
+ "system" => {
448
+ prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
449
+ }
450
+ "user" => {
451
+ if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
452
+ prompt.push_str("<start_of_turn>user\n");
453
+ }
454
+ prompt.push_str(&format!("{}<end_of_turn>\n", content));
455
+ }
456
+ "assistant" | "model" => {
457
+ prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
458
+ }
459
+ _ => {}
460
+ }
461
+ }
462
+
463
+ prompt.push_str("<start_of_turn>model\n");
464
+ Ok(prompt)
465
+ }
466
+
467
+ fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
468
+ let mut prompt = String::new();
469
+
470
+ for message in messages {
471
+ let role = message["role"].as_str().unwrap_or("");
472
+ let content = message["content"].as_str().unwrap_or("");
473
+
474
+ match role {
475
+ "system" => {
476
+ prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
477
+ }
478
+ "user" => {
479
+ prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
480
+ }
481
+ "assistant" => {
482
+ prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
483
+ }
484
+ "tool" => {
485
+ prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
486
+ }
487
+ _ => {}
488
+ }
489
+ }
490
+
491
+ // Add generation prompt
492
+ prompt.push_str("<|im_start|>assistant\n");
493
+ Ok(prompt)
494
+ }
495
+
496
+ fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
497
+ let mut prompt = String::new();
498
+
499
+ // Check if it's Phi-3 (newer format) or Phi-2/Phi (simpler format)
500
+ let is_phi3 = self.model_id.contains("phi-3") || self.model_id.contains("Phi-3") || self.architecture == "phi3";
501
+
502
+ if is_phi3 {
503
+ // Phi-3 format
504
+ for message in messages {
505
+ let role = message["role"].as_str().unwrap_or("");
506
+ let content = message["content"].as_str().unwrap_or("");
507
+
508
+ match role {
509
+ "system" => {
510
+ prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
511
+ }
512
+ "user" => {
513
+ prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
514
+ }
515
+ "assistant" => {
516
+ prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
517
+ }
518
+ _ => {}
519
+ }
520
+ }
521
+ prompt.push_str("<|assistant|>\n");
522
+ } else {
523
+ // Phi-2 format
524
+ for message in messages {
525
+ let role = message["role"].as_str().unwrap_or("");
526
+ let content = message["content"].as_str().unwrap_or("");
527
+
528
+ match role {
529
+ "system" => prompt.push_str(&format!("System: {}\n", content)),
530
+ "user" => prompt.push_str(&format!("User: {}\n", content)),
531
+ "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
532
+ _ => {}
533
+ }
534
+ }
535
+ prompt.push_str("Assistant: ");
536
+ }
537
+
538
+ Ok(prompt)
539
+ }
540
+
541
+ fn apply_glm4_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
542
+ let mut prompt = String::new();
543
+
544
+ prompt.push_str("[gMASK]<sop>");
545
+
546
+ for message in messages {
547
+ let role = message["role"].as_str().unwrap_or("");
548
+ let content = message["content"].as_str().unwrap_or("");
549
+
550
+ match role {
551
+ "system" => {
552
+ prompt.push_str(&format!("<|system|>\n{}", content));
553
+ }
554
+ "user" => {
555
+ prompt.push_str(&format!("<|user|>\n{}", content));
556
+ }
557
+ "assistant" => {
558
+ prompt.push_str(&format!("<|assistant|>\n{}", content));
559
+ }
560
+ _ => {}
561
+ }
562
+ }
563
+
564
+ prompt.push_str("<|assistant|>\n");
565
+ Ok(prompt)
566
+ }
567
+
568
+ fn apply_chatml_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
569
+ let mut prompt = String::new();
570
+
571
+ for message in messages {
572
+ let role = message["role"].as_str().unwrap_or("");
573
+ let content = message["content"].as_str().unwrap_or("");
574
+
575
+ prompt.push_str(&format!("<|{}|>\n{}</s>\n", role, content));
576
+ }
577
+
578
+ prompt.push_str("<|assistant|>");
579
+ Ok(prompt)
580
+ }
581
+
582
+ fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
583
+ let mut prompt = String::new();
584
+
585
+ for message in messages {
586
+ let role = message["role"].as_str().unwrap_or("");
587
+ let content = message["content"].as_str().unwrap_or("");
588
+ prompt.push_str(&format!("{}: {}\n", role, content));
589
+ }
590
+
591
+ prompt.push_str("assistant: ");
592
+ prompt
593
+ }
594
+
595
+ /// Clear the KV cache between generations
596
+ pub fn clear_kv_cache(&mut self) {
597
+ // Only some quantized models expose clear_kv_cache
598
+ if let ModelType::Qwen3(model) = &mut self.model {
599
+ model.clear_kv_cache();
600
+ }
601
+ // Other quantized models (Llama, Gemma, Qwen2, Phi, Phi3) don't expose
602
+ // clear_kv_cache in candle-transformers. For these, the generate() method
603
+ // in Ruby calls clear_cache which recreates the model if needed.
604
+ }
605
+
606
+ fn generate_tokens(
607
+ &mut self,
608
+ prompt_tokens: Vec<u32>,
609
+ config: &GenerationConfig,
610
+ mut callback: Option<impl FnMut(&str)>,
611
+ ) -> CandleResult<Vec<u32>> {
612
+ let mut text_gen = TextGeneration::new(config);
613
+ text_gen.set_eos_token_id(self.eos_token_id);
614
+ text_gen.set_tokens(prompt_tokens.clone());
615
+
616
+ let mut all_tokens = prompt_tokens.clone();
617
+ let start_gen = all_tokens.len();
618
+
619
+ for index in 0..config.max_length {
620
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
621
+ let start_pos = all_tokens.len().saturating_sub(context_size);
622
+ let ctxt = &all_tokens[start_pos..];
623
+
624
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
625
+ let input = input.contiguous()?;
626
+
627
+ let logits = match &mut self.model {
628
+ ModelType::Llama(model) => model.forward(&input, start_pos)?,
629
+ ModelType::Gemma(model) => model.forward(&input, start_pos)?,
630
+ ModelType::Qwen(model) => model.forward(&input, start_pos)?,
631
+ ModelType::Qwen3(model) => model.forward(&input, start_pos)?,
632
+ ModelType::Phi(model) => model.forward(&input, start_pos)?,
633
+ ModelType::Phi3(model) => model.forward(&input, start_pos)?,
634
+ ModelType::Glm4(model) => model.forward(&input, start_pos)?,
635
+ };
636
+
637
+ let logits = logits.squeeze(0)?;
638
+ let logits = if logits.dims().len() == 2 {
639
+ let seq_len = logits.dim(0)?;
640
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
641
+ } else {
642
+ logits
643
+ };
644
+
645
+ let logits = logits.to_dtype(DType::F32)?;
646
+
647
+ let next_token = text_gen.sample_next_token(&logits)?;
648
+
649
+ all_tokens.push(next_token);
650
+
651
+ // Stream callback
652
+ if let Some(ref mut cb) = callback {
653
+ if config.debug_tokens {
654
+ // In debug mode, only show debug tokens
655
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
656
+ cb(&format!("[{}:{}]", next_token, token_piece));
657
+ } else {
658
+ // Normal mode: use incremental decoding for proper text
659
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
660
+ cb(&decoded_text);
661
+ }
662
+ }
663
+
664
+ // Check stop conditions
665
+ if text_gen.should_stop(next_token, config.max_length) {
666
+ break;
667
+ }
668
+
669
+ // Check if constraint is satisfied (early stopping)
670
+ if config.stop_on_constraint_satisfaction {
671
+ let satisfied = if config.stop_on_match {
672
+ text_gen.is_constraint_satisfied_stop_on_match()
673
+ } else {
674
+ text_gen.is_constraint_satisfied()
675
+ };
676
+ if satisfied {
677
+ break;
678
+ }
679
+ }
680
+
681
+ // Check stop sequences
682
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
683
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
684
+ break;
685
+ }
686
+ }
687
+
688
+ Ok(if config.include_prompt {
689
+ all_tokens
690
+ } else {
691
+ all_tokens[start_gen..].to_vec()
692
+ })
693
+ }
694
+ }
695
+
696
+ impl TextGenerator for QuantizedGGUF {
697
+ fn generate(
698
+ &mut self,
699
+ prompt: &str,
700
+ config: &GenerationConfig,
701
+ ) -> CandleResult<String> {
702
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
703
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
704
+
705
+ if config.debug_tokens {
706
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
707
+ } else {
708
+ self.tokenizer.decode(&output_tokens, true)
709
+ }
710
+ }
711
+
712
+ fn generate_stream(
713
+ &mut self,
714
+ prompt: &str,
715
+ config: &GenerationConfig,
716
+ mut callback: impl FnMut(&str),
717
+ ) -> CandleResult<String> {
718
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
719
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
720
+ self.tokenizer.decode(&output_tokens, true)
721
+ }
722
+
723
+ fn model_name(&self) -> &str {
724
+ &self.model_id
725
+ }
726
+
727
+ fn device(&self) -> &Device {
728
+ &self.device
729
+ }
730
+
731
+ fn clear_cache(&mut self) {
732
+ self.clear_kv_cache();
733
+ }
734
+ }