red-candle 1.0.1 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +57 -4
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/build.rs +6 -5
- data/ext/candle/extconf.rb +5 -6
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +123 -0
- data/ext/candle/src/llm/generation_config.rs +5 -0
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +285 -0
- data/ext/candle/src/llm/quantized_gguf.rs +155 -4
- data/ext/candle/src/llm/qwen.rs +229 -0
- data/ext/candle/src/llm/text_generation.rs +66 -2
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +42 -4
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/build_info.rb +2 -2
- data/lib/candle/llm.rb +109 -3
- data/lib/candle/version.rb +1 -1
- data/lib/red-candle.rb +1 -0
- metadata +15 -4
@@ -2,6 +2,9 @@ use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
2
|
use candle_core::quantized::gguf_file;
|
3
3
|
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaModel;
|
4
4
|
use candle_transformers::models::quantized_gemma3::ModelWeights as QuantizedGemmaModel;
|
5
|
+
use candle_transformers::models::quantized_qwen2::ModelWeights as QuantizedQwenModel;
|
6
|
+
use candle_transformers::models::quantized_phi::ModelWeights as QuantizedPhiModel;
|
7
|
+
use candle_transformers::models::quantized_phi3::ModelWeights as QuantizedPhi3Model;
|
5
8
|
use hf_hub::api::tokio::{Api, ApiRepo};
|
6
9
|
use tokenizers::Tokenizer;
|
7
10
|
use std::io::Seek;
|
@@ -9,7 +12,6 @@ use std::io::Seek;
|
|
9
12
|
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
10
13
|
|
11
14
|
/// Unified GGUF model that can load any GGUF file and detect the architecture
|
12
|
-
#[derive(Debug)]
|
13
15
|
pub struct QuantizedGGUF {
|
14
16
|
model: ModelType,
|
15
17
|
tokenizer: TokenizerWrapper,
|
@@ -20,10 +22,12 @@ pub struct QuantizedGGUF {
|
|
20
22
|
_chat_template: Option<String>,
|
21
23
|
}
|
22
24
|
|
23
|
-
#[derive(Debug)]
|
24
25
|
/// The concrete quantized weights loaded from a GGUF file, one variant per
/// loader this module can dispatch to.
enum ModelType {
    /// Llama-family loader. Besides Llama itself, the loader also uses it for
    /// Mistral and as a fallback for some other architectures (e.g. Qwen files
    /// carrying `llama.*` metadata, or Gemma files the Gemma loader rejects).
    Llama(QuantizedLlamaModel),
    /// Gemma / Gemma2 / Gemma3 weights.
    Gemma(QuantizedGemmaModel),
    /// Qwen2-style weights (files with `qwen2.*` metadata).
    Qwen(QuantizedQwenModel),
    /// Phi / Phi-2 weights.
    Phi(QuantizedPhiModel),
    /// Phi-3 weights.
    Phi3(QuantizedPhi3Model),
    // Mistral uses Llama loader due to tensor naming compatibility
}
|
29
33
|
|
@@ -97,6 +101,34 @@ impl QuantizedGGUF {
|
|
97
101
|
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
98
102
|
ModelType::Llama(model)
|
99
103
|
}
|
104
|
+
"qwen" | "qwen2" | "qwen3" => {
|
105
|
+
// Try different loaders based on what metadata is available
|
106
|
+
if content.metadata.contains_key("llama.attention.head_count") {
|
107
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
108
|
+
ModelType::Llama(model)
|
109
|
+
} else if content.metadata.contains_key("qwen2.attention.head_count") {
|
110
|
+
let model = QuantizedQwenModel::from_gguf(content, &mut file, &device)?;
|
111
|
+
ModelType::Qwen(model)
|
112
|
+
} else if content.metadata.contains_key("qwen3.attention.head_count") {
|
113
|
+
// Qwen3 GGUF files use a different metadata format
|
114
|
+
// The quantized_qwen3 module is not yet in the released version of candle-transformers
|
115
|
+
return Err(candle_core::Error::Msg(format!(
|
116
|
+
"Qwen3 GGUF format detected but not yet fully supported.\n\n\
|
117
|
+
The file contains qwen3.* metadata keys which require candle-transformers > 0.9.1.\n\n\
|
118
|
+
Current alternatives:\n\
|
119
|
+
1. Use Qwen2.5 GGUF models which work well:\n\
|
120
|
+
- Qwen/Qwen2.5-7B-Instruct-GGUF (recommended)\n\
|
121
|
+
- Qwen/Qwen2.5-32B-Instruct-GGUF\n\
|
122
|
+
2. Use non-quantized Qwen models with safetensors\n\
|
123
|
+
3. Wait for candle-transformers update with quantized_qwen3 support\n\n\
|
124
|
+
Note: Qwen2.5 models have similar capabilities to Qwen3."
|
125
|
+
)));
|
126
|
+
} else {
|
127
|
+
// Last resort: try llama loader anyway, as it's the most common
|
128
|
+
let model = QuantizedLlamaModel::from_gguf(content, &mut file, &device)?;
|
129
|
+
ModelType::Llama(model)
|
130
|
+
}
|
131
|
+
}
|
100
132
|
"gemma" | "gemma2" | "gemma3" => {
|
101
133
|
// Try Gemma-specific loader first, fall back to Llama if it fails
|
102
134
|
match QuantizedGemmaModel::from_gguf(content, &mut file, &device) {
|
@@ -112,9 +144,20 @@ impl QuantizedGGUF {
|
|
112
144
|
Err(e) => return Err(e),
|
113
145
|
}
|
114
146
|
}
|
147
|
+
"phi" | "phi2" => {
|
148
|
+
let model = QuantizedPhiModel::from_gguf(content, &mut file, &device)?;
|
149
|
+
ModelType::Phi(model)
|
150
|
+
}
|
151
|
+
"phi3" => {
|
152
|
+
// QuantizedPhi3Model requires an additional `approx` parameter
|
153
|
+
// Setting to false to avoid performance issues without flash-attn
|
154
|
+
let approx = false;
|
155
|
+
let model = QuantizedPhi3Model::from_gguf(approx, content, &mut file, &device)?;
|
156
|
+
ModelType::Phi3(model)
|
157
|
+
}
|
115
158
|
_ => {
|
116
159
|
return Err(candle_core::Error::Msg(format!(
|
117
|
-
"Unsupported architecture: {}. Supported: llama, mistral, gemma",
|
160
|
+
"Unsupported architecture: {}. Supported: llama, mistral, gemma, qwen, qwen2, qwen3, phi, phi2, phi3",
|
118
161
|
architecture
|
119
162
|
)));
|
120
163
|
}
|
@@ -149,6 +192,14 @@ impl QuantizedGGUF {
|
|
149
192
|
Ok("mistral".to_string())
|
150
193
|
} else if model_lower.contains("gemma") {
|
151
194
|
Ok("gemma".to_string())
|
195
|
+
} else if model_lower.contains("qwen") {
|
196
|
+
Ok("qwen".to_string())
|
197
|
+
} else if model_lower.contains("phi-3") || model_lower.contains("phi3") {
|
198
|
+
Ok("phi3".to_string())
|
199
|
+
} else if model_lower.contains("phi-2") || model_lower.contains("phi2") {
|
200
|
+
Ok("phi2".to_string())
|
201
|
+
} else if model_lower.contains("phi") {
|
202
|
+
Ok("phi".to_string())
|
152
203
|
} else {
|
153
204
|
Err(candle_core::Error::Msg(
|
154
205
|
"Could not determine model architecture from metadata or name".to_string()
|
@@ -235,6 +286,20 @@ impl QuantizedGGUF {
|
|
235
286
|
.copied()
|
236
287
|
.unwrap_or(1)
|
237
288
|
}
|
289
|
+
"qwen" | "qwen2" | "qwen3" => {
|
290
|
+
vocab.get("<|endoftext|>")
|
291
|
+
.or_else(|| vocab.get("<|im_end|>"))
|
292
|
+
.or_else(|| vocab.get("</s>"))
|
293
|
+
.copied()
|
294
|
+
.unwrap_or(151643) // Default Qwen3 EOS token
|
295
|
+
}
|
296
|
+
"phi" | "phi2" | "phi3" => {
|
297
|
+
vocab.get("<|endoftext|>")
|
298
|
+
.or_else(|| vocab.get("<|end|>"))
|
299
|
+
.or_else(|| vocab.get("</s>"))
|
300
|
+
.copied()
|
301
|
+
.unwrap_or(50256) // Default GPT-2 style EOS token
|
302
|
+
}
|
238
303
|
_ => 2, // Default
|
239
304
|
}
|
240
305
|
}
|
@@ -256,6 +321,10 @@ impl QuantizedGGUF {
|
|
256
321
|
} else if model_lower.contains("gemma") {
|
257
322
|
// Always use Gemma template for Gemma models, regardless of loader used
|
258
323
|
self.apply_gemma_template(messages)
|
324
|
+
} else if model_lower.contains("qwen") {
|
325
|
+
self.apply_qwen_template(messages)
|
326
|
+
} else if model_lower.contains("phi") {
|
327
|
+
self.apply_phi_template(messages)
|
259
328
|
} else {
|
260
329
|
match self.architecture.as_str() {
|
261
330
|
"llama" => {
|
@@ -268,6 +337,12 @@ impl QuantizedGGUF {
|
|
268
337
|
"gemma" => {
|
269
338
|
self.apply_gemma_template(messages)
|
270
339
|
}
|
340
|
+
"qwen" | "qwen2" | "qwen3" => {
|
341
|
+
self.apply_qwen_template(messages)
|
342
|
+
}
|
343
|
+
"phi" | "phi2" | "phi3" => {
|
344
|
+
self.apply_phi_template(messages)
|
345
|
+
}
|
271
346
|
_ => Ok(self.apply_generic_template(messages))
|
272
347
|
}
|
273
348
|
}
|
@@ -366,6 +441,77 @@ impl QuantizedGGUF {
|
|
366
441
|
Ok(prompt)
|
367
442
|
}
|
368
443
|
|
444
|
+
/// Render `messages` as a ChatML prompt for Qwen models.
///
/// Each `system`, `user` or `assistant` message becomes one
/// `<|im_start|>{role}\n{content}<|im_end|>\n` turn. Messages with any other role
/// are dropped, and a missing or non-string `content` renders as empty. An open
/// assistant turn is appended so the model answers next.
fn apply_qwen_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
    let mut prompt: String = messages
        .iter()
        .filter_map(|message| {
            let content = message["content"].as_str().unwrap_or("");
            match message["role"].as_str().unwrap_or("") {
                "system" => Some(format!("<|im_start|>system\n{}<|im_end|>\n", content)),
                "user" => Some(format!("<|im_start|>user\n{}<|im_end|>\n", content)),
                "assistant" => Some(format!("<|im_start|>assistant\n{}<|im_end|>\n", content)),
                _ => None,
            }
        })
        .collect();

    // Generation prompt: leave an assistant turn open for the model to fill.
    prompt.push_str("<|im_start|>assistant\n");
    Ok(prompt)
}
|
469
|
+
|
470
|
+
/// Render `messages` as a chat prompt for Phi models.
///
/// Phi-3 models use `<|system|>` / `<|user|>` / `<|assistant|>` turns, each
/// terminated by `<|end|>`. Other Phi models (Phi / Phi-2) get a plain
/// `System:` / `User:` / `Assistant:` transcript. Messages whose `role` is not
/// one of those three are skipped, and a missing or non-string `content` is
/// treated as empty. The prompt always ends with an open assistant turn.
fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
    let mut prompt = String::new();

    // Phi-3 detection: compare the model id case-insensitively and accept both
    // the "phi-3" and "phi3" spellings (e.g. "PHI-3-...", "phi3-mini-..."), in
    // addition to the detected architecture.
    let model_id = self.model_id.to_lowercase();
    let is_phi3 = model_id.contains("phi-3")
        || model_id.contains("phi3")
        || self.architecture == "phi3";

    if is_phi3 {
        // Phi-3 format
        for message in messages {
            let role = message["role"].as_str().unwrap_or("");
            let content = message["content"].as_str().unwrap_or("");

            match role {
                "system" => {
                    prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
                }
                "user" => {
                    prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
                }
                "assistant" => {
                    prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
                }
                _ => {}
            }
        }
        prompt.push_str("<|assistant|>\n");
    } else {
        // Phi-2 format
        for message in messages {
            let role = message["role"].as_str().unwrap_or("");
            let content = message["content"].as_str().unwrap_or("");

            match role {
                "system" => prompt.push_str(&format!("System: {}\n", content)),
                "user" => prompt.push_str(&format!("User: {}\n", content)),
                "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
                _ => {}
            }
        }
        prompt.push_str("Assistant: ");
    }

    Ok(prompt)
}
|
514
|
+
|
369
515
|
fn apply_generic_template(&self, messages: &[serde_json::Value]) -> String {
|
370
516
|
let mut prompt = String::new();
|
371
517
|
|
@@ -381,7 +527,9 @@ impl QuantizedGGUF {
|
|
381
527
|
|
382
528
|
/// Clear the KV cache between generations.
///
/// Intentionally a no-op for the GGUF-backed model: the body does nothing.
/// The reasons below are the original author's notes, not something this
/// function verifies.
///
/// NOTE(review): because nothing is reset here, any cache reset presumably has
/// to happen inside each model's `forward` (e.g. when called at position 0).
/// Confirm per loader, especially Phi3, before relying on repeated generations.
pub fn clear_kv_cache(&mut self) {
    // Quantized models don't expose cache clearing methods
    // Phi3 GGUF models have a known issue where the KV cache
    // cannot be cleared, leading to errors on subsequent generations
}
|
386
534
|
|
387
535
|
fn generate_tokens(
|
@@ -408,6 +556,9 @@ impl QuantizedGGUF {
|
|
408
556
|
let logits = match &mut self.model {
|
409
557
|
ModelType::Llama(model) => model.forward(&input, start_pos)?,
|
410
558
|
ModelType::Gemma(model) => model.forward(&input, start_pos)?,
|
559
|
+
ModelType::Qwen(model) => model.forward(&input, start_pos)?,
|
560
|
+
ModelType::Phi(model) => model.forward(&input, start_pos)?,
|
561
|
+
ModelType::Phi3(model) => model.forward(&input, start_pos)?,
|
411
562
|
};
|
412
563
|
|
413
564
|
let logits = logits.squeeze(0)?;
|
@@ -0,0 +1,229 @@
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
2
|
+
use candle_transformers::models::qwen2::{Config, Model as QwenModel};
|
3
|
+
use hf_hub::api::tokio::Api;
|
4
|
+
use tokenizers::Tokenizer;
|
5
|
+
|
6
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
7
|
+
|
8
|
+
/// Qwen model wrapper for text generation
|
9
|
+
#[derive(Debug)]
|
10
|
+
pub struct Qwen {
|
11
|
+
model: QwenModel,
|
12
|
+
tokenizer: TokenizerWrapper,
|
13
|
+
device: Device,
|
14
|
+
model_id: String,
|
15
|
+
eos_token_id: u32,
|
16
|
+
}
|
17
|
+
|
18
|
+
impl Qwen {
|
19
|
+
/// Get the tokenizer
|
20
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
21
|
+
&self.tokenizer
|
22
|
+
}
|
23
|
+
|
24
|
+
/// Clear the KV cache between generations
|
25
|
+
pub fn clear_kv_cache(&mut self) {
|
26
|
+
self.model.clear_kv_cache();
|
27
|
+
}
|
28
|
+
|
29
|
+
/// Load a Qwen model from HuggingFace
|
30
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
31
|
+
let api = Api::new()
|
32
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
33
|
+
|
34
|
+
let repo = api.model(model_id.to_string());
|
35
|
+
|
36
|
+
// Download configuration
|
37
|
+
let config_filename = repo.get("config.json").await
|
38
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
39
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
40
|
+
let config: Config = serde_json::from_str(&config_str)
|
41
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
42
|
+
|
43
|
+
// Download tokenizer
|
44
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
46
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)
|
47
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
|
48
|
+
|
49
|
+
// Determine EOS token
|
50
|
+
let vocab = tokenizer.get_vocab(true);
|
51
|
+
let eos_token_id = vocab.get("<|endoftext|>")
|
52
|
+
.or_else(|| vocab.get("<|im_end|>"))
|
53
|
+
.or_else(|| vocab.get("</s>"))
|
54
|
+
.copied()
|
55
|
+
.unwrap_or(151643); // Default Qwen3 EOS token
|
56
|
+
|
57
|
+
// Download model weights
|
58
|
+
let mut filenames = vec![];
|
59
|
+
let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
|
60
|
+
else if model_id.contains("14b") || model_id.contains("14B") { 3 }
|
61
|
+
else { 1 };
|
62
|
+
|
63
|
+
if num_shards == 1 {
|
64
|
+
// Single file model
|
65
|
+
let filename = repo.get("model.safetensors").await
|
66
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
|
67
|
+
filenames.push(filename);
|
68
|
+
} else {
|
69
|
+
// Sharded model
|
70
|
+
for shard_idx in 1..=num_shards {
|
71
|
+
let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
|
72
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
|
73
|
+
filenames.push(filename);
|
74
|
+
}
|
75
|
+
}
|
76
|
+
|
77
|
+
// Load the model
|
78
|
+
let vb = unsafe {
|
79
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
|
80
|
+
};
|
81
|
+
|
82
|
+
let model = QwenModel::new(&config, vb)?;
|
83
|
+
|
84
|
+
Ok(Self {
|
85
|
+
model,
|
86
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
87
|
+
device,
|
88
|
+
model_id: model_id.to_string(),
|
89
|
+
eos_token_id,
|
90
|
+
})
|
91
|
+
}
|
92
|
+
|
93
|
+
/// Apply Qwen chat template to messages
|
94
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
95
|
+
let mut prompt = String::new();
|
96
|
+
|
97
|
+
for message in messages {
|
98
|
+
let role = message["role"].as_str().unwrap_or("");
|
99
|
+
let content = message["content"].as_str().unwrap_or("");
|
100
|
+
|
101
|
+
match role {
|
102
|
+
"system" => {
|
103
|
+
prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
|
104
|
+
}
|
105
|
+
"user" => {
|
106
|
+
prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
|
107
|
+
}
|
108
|
+
"assistant" => {
|
109
|
+
prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
|
110
|
+
}
|
111
|
+
_ => {}
|
112
|
+
}
|
113
|
+
}
|
114
|
+
|
115
|
+
// Add generation prompt
|
116
|
+
prompt.push_str("<|im_start|>assistant\n");
|
117
|
+
|
118
|
+
Ok(prompt)
|
119
|
+
}
|
120
|
+
|
121
|
+
fn generate_tokens(
|
122
|
+
&mut self,
|
123
|
+
prompt_tokens: Vec<u32>,
|
124
|
+
config: &GenerationConfig,
|
125
|
+
mut callback: Option<impl FnMut(&str)>,
|
126
|
+
) -> CandleResult<Vec<u32>> {
|
127
|
+
let mut text_gen = TextGeneration::from_config(config);
|
128
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
129
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
130
|
+
|
131
|
+
let mut all_tokens = prompt_tokens.clone();
|
132
|
+
let start_gen = all_tokens.len();
|
133
|
+
|
134
|
+
for index in 0..config.max_length {
|
135
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
136
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
137
|
+
let ctxt = &all_tokens[start_pos..];
|
138
|
+
|
139
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
140
|
+
let logits = self.model.forward(&input, start_pos, None)?;
|
141
|
+
let logits = logits.squeeze(0)?;
|
142
|
+
|
143
|
+
// Handle different output shapes
|
144
|
+
let logits = if logits.dims().len() == 2 {
|
145
|
+
let seq_len = logits.dim(0)?;
|
146
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
147
|
+
} else {
|
148
|
+
logits
|
149
|
+
};
|
150
|
+
|
151
|
+
let logits = logits.to_dtype(DType::F32)?;
|
152
|
+
|
153
|
+
let next_token = text_gen.sample_next_token(
|
154
|
+
&logits,
|
155
|
+
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
156
|
+
)?;
|
157
|
+
|
158
|
+
all_tokens.push(next_token);
|
159
|
+
|
160
|
+
// Stream callback
|
161
|
+
if let Some(ref mut cb) = callback {
|
162
|
+
if config.debug_tokens {
|
163
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
164
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
165
|
+
} else {
|
166
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
167
|
+
cb(&decoded_text);
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
// Check stop conditions
|
172
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
173
|
+
break;
|
174
|
+
}
|
175
|
+
|
176
|
+
// Check stop sequences
|
177
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
178
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
179
|
+
break;
|
180
|
+
}
|
181
|
+
}
|
182
|
+
|
183
|
+
Ok(if config.include_prompt {
|
184
|
+
all_tokens
|
185
|
+
} else {
|
186
|
+
all_tokens[start_gen..].to_vec()
|
187
|
+
})
|
188
|
+
}
|
189
|
+
}
|
190
|
+
|
191
|
+
impl TextGenerator for Qwen {
|
192
|
+
fn generate(
|
193
|
+
&mut self,
|
194
|
+
prompt: &str,
|
195
|
+
config: &GenerationConfig,
|
196
|
+
) -> CandleResult<String> {
|
197
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
198
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
199
|
+
|
200
|
+
if config.debug_tokens {
|
201
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
202
|
+
} else {
|
203
|
+
self.tokenizer.decode(&output_tokens, true)
|
204
|
+
}
|
205
|
+
}
|
206
|
+
|
207
|
+
fn generate_stream(
|
208
|
+
&mut self,
|
209
|
+
prompt: &str,
|
210
|
+
config: &GenerationConfig,
|
211
|
+
mut callback: impl FnMut(&str),
|
212
|
+
) -> CandleResult<String> {
|
213
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
214
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
215
|
+
self.tokenizer.decode(&output_tokens, true)
|
216
|
+
}
|
217
|
+
|
218
|
+
fn model_name(&self) -> &str {
|
219
|
+
&self.model_id
|
220
|
+
}
|
221
|
+
|
222
|
+
fn device(&self) -> &Device {
|
223
|
+
&self.device
|
224
|
+
}
|
225
|
+
|
226
|
+
fn clear_cache(&mut self) {
|
227
|
+
self.clear_kv_cache();
|
228
|
+
}
|
229
|
+
}
|
@@ -1,13 +1,17 @@
|
|
1
1
|
use candle_core::{Result as CandleResult, Tensor};
|
2
2
|
use candle_transformers::generation::LogitsProcessor;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
5
|
use super::GenerationConfig;
|
6
|
+
use crate::structured::Index;
|
5
7
|
|
6
8
|
/// Helper struct for text generation process.
///
/// Holds the sampler, the running token history, the optional EOS id and an
/// optional structured-generation constraint together with its current state.
pub struct TextGeneration {
    /// Samples the next token id from a (possibly penalised/masked) logits tensor.
    logits_processor: LogitsProcessor,
    /// Token ids seen so far; every sampled token is appended here.
    tokens: Vec<u32>,
    /// Token id treated as end-of-sequence, once one has been set (`None` initially).
    eos_token_id: Option<u32>,
    /// Optional constraint index used to mask disallowed tokens before sampling.
    constraint: Option<Arc<Index>>,
    /// Current state within `constraint`. It starts at the index's initial state
    /// and is advanced after each sampled token. It becomes `None` when no
    /// constraint is set or when the index reports no next state.
    constraint_state: Option<u32>,
}
|
12
16
|
|
13
17
|
impl TextGeneration {
|
@@ -25,18 +29,27 @@ impl TextGeneration {
|
|
25
29
|
logits_processor,
|
26
30
|
tokens: Vec::new(),
|
27
31
|
eos_token_id: None,
|
32
|
+
constraint: None,
|
33
|
+
constraint_state: None,
|
28
34
|
}
|
29
35
|
}
|
30
36
|
|
31
37
|
pub fn from_config(config: &GenerationConfig) -> Self {
|
32
|
-
Self::new(
|
38
|
+
let mut text_gen = Self::new(
|
33
39
|
config.seed,
|
34
40
|
Some(config.temperature),
|
35
41
|
config.top_p,
|
36
42
|
config.top_k,
|
37
43
|
config.repetition_penalty,
|
38
44
|
config.repetition_penalty_last_n,
|
39
|
-
)
|
45
|
+
);
|
46
|
+
|
47
|
+
// Set constraint if provided
|
48
|
+
if let Some(ref constraint) = config.constraint {
|
49
|
+
text_gen.set_constraint(Arc::clone(constraint));
|
50
|
+
}
|
51
|
+
|
52
|
+
text_gen
|
40
53
|
}
|
41
54
|
|
42
55
|
pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
|
@@ -55,6 +68,36 @@ impl TextGeneration {
|
|
55
68
|
self.tokens.push(token);
|
56
69
|
}
|
57
70
|
|
71
|
+
/// Install a structured-generation constraint.
///
/// Replaces any previous constraint and resets the tracked state to the new
/// index's initial state.
pub fn set_constraint(&mut self, constraint: Arc<Index>) {
    let initial = constraint.initial_state();
    self.constraint = Some(constraint);
    self.constraint_state = Some(initial);
}
|
76
|
+
|
77
|
+
/// Apply constraints to logits by masking disallowed tokens
|
78
|
+
fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
|
79
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
80
|
+
let device = logits.device();
|
81
|
+
let vocab_size = logits.dims1()?;
|
82
|
+
|
83
|
+
// Get allowed tokens from the constraint index for current state
|
84
|
+
if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
|
85
|
+
// Create a mask where allowed tokens have value 0 and others have -inf
|
86
|
+
let mut mask = vec![f32::NEG_INFINITY; vocab_size];
|
87
|
+
for &token_id in &allowed_tokens {
|
88
|
+
if (token_id as usize) < vocab_size {
|
89
|
+
mask[token_id as usize] = 0.0;
|
90
|
+
}
|
91
|
+
}
|
92
|
+
|
93
|
+
// Apply mask to logits
|
94
|
+
let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
|
95
|
+
*logits = logits.add(&mask_tensor)?;
|
96
|
+
}
|
97
|
+
}
|
98
|
+
Ok(())
|
99
|
+
}
|
100
|
+
|
58
101
|
/// Apply repetition penalty to logits
|
59
102
|
pub fn apply_repetition_penalty(
|
60
103
|
&self,
|
@@ -103,10 +146,18 @@ impl TextGeneration {
|
|
103
146
|
self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
|
104
147
|
}
|
105
148
|
|
149
|
+
// Apply constraints if active
|
150
|
+
self.apply_constraints(&mut logits)?;
|
151
|
+
|
106
152
|
// Sample token
|
107
153
|
let next_token = self.logits_processor.sample(&logits)?;
|
108
154
|
self.tokens.push(next_token);
|
109
155
|
|
156
|
+
// Update constraint state if active
|
157
|
+
if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
|
158
|
+
self.constraint_state = constraint_index.next_state(¤t_state, &next_token);
|
159
|
+
}
|
160
|
+
|
110
161
|
Ok(next_token)
|
111
162
|
}
|
112
163
|
|
@@ -122,6 +173,19 @@ impl TextGeneration {
|
|
122
173
|
}
|
123
174
|
}
|
124
175
|
|
176
|
+
// Check if we've reached a final state in constraint
|
177
|
+
// A state is considered final if it has no allowed tokens
|
178
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
179
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
180
|
+
if allowed.is_empty() {
|
181
|
+
return true;
|
182
|
+
}
|
183
|
+
} else {
|
184
|
+
// None means no tokens allowed - we're done
|
185
|
+
return true;
|
186
|
+
}
|
187
|
+
}
|
188
|
+
|
125
189
|
false
|
126
190
|
}
|
127
191
|
|
@@ -162,6 +162,10 @@ impl Device {
|
|
162
162
|
pub fn __str__(&self) -> String {
|
163
163
|
self.__repr__()
|
164
164
|
}
|
165
|
+
|
166
|
+
/// Ruby `==`: true when both devices compare equal under `PartialEq`.
pub fn __eq__(&self, other: &Device) -> bool {
    PartialEq::eq(self, other)
}
|
165
169
|
}
|
166
170
|
|
167
171
|
impl magnus::TryConvert for Device {
|
@@ -193,5 +197,6 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
193
197
|
rb_device.define_singleton_method("default", function!(default_device, 0))?;
|
194
198
|
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
195
199
|
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
200
|
+
rb_device.define_method("==", method!(Device::__eq__, 1))?;
|
196
201
|
Ok(())
|
197
202
|
}
|