red-candle 1.8.0.pre2-x86_64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5193 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +33 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_nn::VarBuilder;
|
|
3
|
+
use candle_transformers::models::gemma::{Config, Model as GemmaModel};
|
|
4
|
+
use hf_hub::{api::tokio::Api, Repo};
|
|
5
|
+
use tokenizers::Tokenizer;
|
|
6
|
+
|
|
7
|
+
use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
8
|
+
|
|
9
|
+
#[derive(Debug)]
|
|
10
|
+
pub struct Gemma {
|
|
11
|
+
model: GemmaModel,
|
|
12
|
+
tokenizer: TokenizerWrapper,
|
|
13
|
+
device: Device,
|
|
14
|
+
model_id: String,
|
|
15
|
+
eos_token_id: u32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
impl Gemma {
|
|
19
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
20
|
+
self.eos_token_id
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
/// Clear the KV cache between generations
|
|
24
|
+
pub fn clear_kv_cache(&mut self) {
|
|
25
|
+
self.model.clear_kv_cache();
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Get the tokenizer
|
|
29
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
30
|
+
&self.tokenizer
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
|
|
34
|
+
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
35
|
+
let api = Api::new()
|
|
36
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
37
|
+
|
|
38
|
+
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
39
|
+
|
|
40
|
+
// Download model files
|
|
41
|
+
let config_filename = repo
|
|
42
|
+
.get("config.json")
|
|
43
|
+
.await
|
|
44
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
45
|
+
|
|
46
|
+
// Download tokenizer from custom source if provided, otherwise from model repo
|
|
47
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
48
|
+
let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
|
|
49
|
+
let tokenizer_filename = tokenizer_repo
|
|
50
|
+
.get("tokenizer.json")
|
|
51
|
+
.await
|
|
52
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
|
|
53
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
54
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
55
|
+
} else {
|
|
56
|
+
let tokenizer_filename = repo
|
|
57
|
+
.get("tokenizer.json")
|
|
58
|
+
.await
|
|
59
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
|
60
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
61
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
62
|
+
};
|
|
63
|
+
|
|
64
|
+
// Try different file patterns for model weights
|
|
65
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
|
66
|
+
vec![single_file]
|
|
67
|
+
} else {
|
|
68
|
+
// Try to find sharded model files
|
|
69
|
+
// NOTE: This uses a brute-force approach, trying common shard counts.
|
|
70
|
+
// A better approach would be to read model.safetensors.index.json which
|
|
71
|
+
// contains the exact file list, but this works for most models (≤12 shards).
|
|
72
|
+
let mut sharded_files = Vec::new();
|
|
73
|
+
let mut index = 1;
|
|
74
|
+
loop {
|
|
75
|
+
// Try common shard counts for Gemma models
|
|
76
|
+
let mut found = false;
|
|
77
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 12] {
|
|
78
|
+
let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
|
|
79
|
+
if let Ok(file) = repo.get(&filename).await {
|
|
80
|
+
sharded_files.push(file);
|
|
81
|
+
found = true;
|
|
82
|
+
break;
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
if !found {
|
|
86
|
+
break;
|
|
87
|
+
}
|
|
88
|
+
index += 1;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
if sharded_files.is_empty() {
|
|
92
|
+
return Err(candle_core::Error::Msg(
|
|
93
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
|
|
94
|
+
));
|
|
95
|
+
}
|
|
96
|
+
sharded_files
|
|
97
|
+
};
|
|
98
|
+
|
|
99
|
+
// Load config
|
|
100
|
+
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
|
|
101
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
// Gemma uses specific tokens
|
|
105
|
+
let eos_token_id = {
|
|
106
|
+
let vocab = tokenizer.get_vocab(true);
|
|
107
|
+
vocab.get("<eos>")
|
|
108
|
+
.or_else(|| vocab.get("<end_of_turn>"))
|
|
109
|
+
.copied()
|
|
110
|
+
.unwrap_or(1) // Default Gemma EOS
|
|
111
|
+
};
|
|
112
|
+
|
|
113
|
+
// Load model weights
|
|
114
|
+
let vb = unsafe {
|
|
115
|
+
VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
let model = GemmaModel::new(false, &config, vb)?; // Don't use flash attention for now
|
|
119
|
+
|
|
120
|
+
Ok(Self {
|
|
121
|
+
model,
|
|
122
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
123
|
+
device,
|
|
124
|
+
model_id: model_id.to_string(),
|
|
125
|
+
eos_token_id,
|
|
126
|
+
})
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/// Load a Gemma model from HuggingFace Hub (backwards compatibility)
|
|
130
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
131
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
/// Create from existing components (useful for testing)
|
|
135
|
+
pub fn new(
|
|
136
|
+
model: GemmaModel,
|
|
137
|
+
tokenizer: Tokenizer,
|
|
138
|
+
device: Device,
|
|
139
|
+
model_id: String,
|
|
140
|
+
) -> Self {
|
|
141
|
+
let eos_token_id = {
|
|
142
|
+
let vocab = tokenizer.get_vocab(true);
|
|
143
|
+
vocab.get("<eos>")
|
|
144
|
+
.or_else(|| vocab.get("<end_of_turn>"))
|
|
145
|
+
.copied()
|
|
146
|
+
.unwrap_or(1)
|
|
147
|
+
};
|
|
148
|
+
|
|
149
|
+
Self {
|
|
150
|
+
model,
|
|
151
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
152
|
+
device,
|
|
153
|
+
model_id,
|
|
154
|
+
eos_token_id,
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
fn generate_tokens(
|
|
159
|
+
&mut self,
|
|
160
|
+
prompt_tokens: Vec<u32>,
|
|
161
|
+
config: &GenerationConfig,
|
|
162
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
163
|
+
) -> CandleResult<Vec<u32>> {
|
|
164
|
+
let mut text_gen = TextGeneration::new(config);
|
|
165
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
166
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
167
|
+
|
|
168
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
169
|
+
let start_gen = all_tokens.len();
|
|
170
|
+
|
|
171
|
+
for index in 0..config.max_length {
|
|
172
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
173
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
174
|
+
let ctxt = &all_tokens[start_pos..];
|
|
175
|
+
|
|
176
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
177
|
+
let input = input.contiguous()?;
|
|
178
|
+
let logits = self.model.forward(&input, start_pos)?;
|
|
179
|
+
|
|
180
|
+
let logits = logits.squeeze(0)?;
|
|
181
|
+
let logits = if logits.dims().len() == 2 {
|
|
182
|
+
let seq_len = logits.dim(0)?;
|
|
183
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
184
|
+
} else {
|
|
185
|
+
logits
|
|
186
|
+
};
|
|
187
|
+
|
|
188
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
189
|
+
|
|
190
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
191
|
+
|
|
192
|
+
all_tokens.push(next_token);
|
|
193
|
+
|
|
194
|
+
// Stream callback
|
|
195
|
+
if let Some(ref mut cb) = callback {
|
|
196
|
+
if config.debug_tokens {
|
|
197
|
+
// In debug mode, only show debug tokens
|
|
198
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
199
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
200
|
+
} else {
|
|
201
|
+
// Normal mode: use incremental decoding for proper text
|
|
202
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
203
|
+
cb(&decoded_text);
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
// Check stop conditions
|
|
208
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
209
|
+
break;
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
// Check if constraint is satisfied (early stopping)
|
|
213
|
+
if config.stop_on_constraint_satisfaction {
|
|
214
|
+
let satisfied = if config.stop_on_match {
|
|
215
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
216
|
+
} else {
|
|
217
|
+
text_gen.is_constraint_satisfied()
|
|
218
|
+
};
|
|
219
|
+
if satisfied {
|
|
220
|
+
break;
|
|
221
|
+
}
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
// Check stop sequences
|
|
225
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
226
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
227
|
+
break;
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
Ok(if config.include_prompt {
|
|
232
|
+
all_tokens
|
|
233
|
+
} else {
|
|
234
|
+
all_tokens[start_gen..].to_vec()
|
|
235
|
+
})
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
/// Apply Gemma chat template
|
|
239
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
240
|
+
let mut prompt = String::new();
|
|
241
|
+
|
|
242
|
+
// Gemma uses a specific format:
|
|
243
|
+
// <start_of_turn>user\n{user_message}<end_of_turn>
|
|
244
|
+
// <start_of_turn>model\n{model_message}<end_of_turn>
|
|
245
|
+
|
|
246
|
+
for message in messages {
|
|
247
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
248
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
249
|
+
|
|
250
|
+
match role {
|
|
251
|
+
"system" => {
|
|
252
|
+
// Gemma doesn't have explicit system messages, prepend to first user message
|
|
253
|
+
prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
|
|
254
|
+
}
|
|
255
|
+
"user" => {
|
|
256
|
+
if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
|
|
257
|
+
prompt.push_str("<start_of_turn>user\n");
|
|
258
|
+
}
|
|
259
|
+
prompt.push_str(&format!("{}<end_of_turn>\n", content));
|
|
260
|
+
}
|
|
261
|
+
"assistant" | "model" => {
|
|
262
|
+
prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
|
|
263
|
+
}
|
|
264
|
+
_ => {}
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
// Add the model prompt
|
|
269
|
+
prompt.push_str("<start_of_turn>model\n");
|
|
270
|
+
|
|
271
|
+
Ok(prompt)
|
|
272
|
+
}
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
impl TextGenerator for Gemma {
|
|
276
|
+
fn generate(
|
|
277
|
+
&mut self,
|
|
278
|
+
prompt: &str,
|
|
279
|
+
config: &GenerationConfig,
|
|
280
|
+
) -> CandleResult<String> {
|
|
281
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
282
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
283
|
+
|
|
284
|
+
if config.debug_tokens {
|
|
285
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
286
|
+
} else {
|
|
287
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
288
|
+
}
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
fn generate_stream(
|
|
292
|
+
&mut self,
|
|
293
|
+
prompt: &str,
|
|
294
|
+
config: &GenerationConfig,
|
|
295
|
+
mut callback: impl FnMut(&str),
|
|
296
|
+
) -> CandleResult<String> {
|
|
297
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
298
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
299
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
fn model_name(&self) -> &str {
|
|
303
|
+
&self.model_id
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
fn device(&self) -> &Device {
|
|
307
|
+
&self.device
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
fn clear_cache(&mut self) {
|
|
311
|
+
self.clear_kv_cache();
|
|
312
|
+
}
|
|
313
|
+
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
use std::time::{SystemTime, UNIX_EPOCH};
|
|
2
|
+
use std::sync::Arc;
|
|
3
|
+
use crate::structured::Index;
|
|
4
|
+
|
|
5
|
+
/// Configuration for text generation
|
|
6
|
+
#[derive(Debug, Clone)]
|
|
7
|
+
pub struct GenerationConfig {
|
|
8
|
+
/// The maximum number of tokens to generate
|
|
9
|
+
pub max_length: usize,
|
|
10
|
+
/// The temperature for sampling
|
|
11
|
+
pub temperature: f64,
|
|
12
|
+
/// The top-p value for nucleus sampling
|
|
13
|
+
pub top_p: Option<f64>,
|
|
14
|
+
/// The top-k value for top-k sampling
|
|
15
|
+
pub top_k: Option<usize>,
|
|
16
|
+
/// The repetition penalty
|
|
17
|
+
pub repetition_penalty: f32,
|
|
18
|
+
/// The repetition penalty range
|
|
19
|
+
pub repetition_penalty_last_n: usize,
|
|
20
|
+
/// Random seed for sampling
|
|
21
|
+
pub seed: u64,
|
|
22
|
+
/// Stop sequences
|
|
23
|
+
pub stop_sequences: Vec<String>,
|
|
24
|
+
/// Whether to return the prompt in the output
|
|
25
|
+
pub include_prompt: bool,
|
|
26
|
+
/// Whether to show raw tokens during generation (for debugging)
|
|
27
|
+
pub debug_tokens: bool,
|
|
28
|
+
/// Optional constraint index for structured generation
|
|
29
|
+
pub constraint: Option<Arc<Index>>,
|
|
30
|
+
/// Stop immediately when constraint is satisfied
|
|
31
|
+
pub stop_on_constraint_satisfaction: bool,
|
|
32
|
+
/// Whether to stop immediately when pattern is matched (vs allowing continuation)
|
|
33
|
+
pub stop_on_match: bool,
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
/// Derive a sampling seed from the wall clock.
///
/// Uses the nanoseconds elapsed since the Unix epoch, keeping the low 64 bits. If the
/// system clock reads earlier than the epoch, the fixed value `42` is returned instead.
fn random_seed() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 42,
    }
}
|
|
43
|
+
|
|
44
|
+
impl Default for GenerationConfig {
|
|
45
|
+
fn default() -> Self {
|
|
46
|
+
Self {
|
|
47
|
+
max_length: 512,
|
|
48
|
+
temperature: 0.7,
|
|
49
|
+
top_p: None,
|
|
50
|
+
top_k: None,
|
|
51
|
+
repetition_penalty: 1.1,
|
|
52
|
+
repetition_penalty_last_n: 64,
|
|
53
|
+
seed: random_seed(),
|
|
54
|
+
stop_sequences: vec![],
|
|
55
|
+
include_prompt: false,
|
|
56
|
+
debug_tokens: false,
|
|
57
|
+
constraint: None,
|
|
58
|
+
stop_on_constraint_satisfaction: true,
|
|
59
|
+
stop_on_match: true,
|
|
60
|
+
}
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_transformers::models::glm4_new::{Config, ModelForCausalLM as Glm4Model};
|
|
3
|
+
use hf_hub::api::tokio::Api;
|
|
4
|
+
use tokenizers::Tokenizer;
|
|
5
|
+
|
|
6
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
7
|
+
|
|
8
|
+
#[derive(Debug)]
|
|
9
|
+
pub struct Glm4 {
|
|
10
|
+
model: Glm4Model,
|
|
11
|
+
tokenizer: TokenizerWrapper,
|
|
12
|
+
device: Device,
|
|
13
|
+
model_id: String,
|
|
14
|
+
eos_token_id: u32,
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
impl Glm4 {
|
|
18
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
19
|
+
self.eos_token_id
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
23
|
+
&self.tokenizer
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
pub fn clear_kv_cache(&mut self) {
|
|
27
|
+
self.model.clear_kv_cache();
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
31
|
+
let api = Api::new()
|
|
32
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
33
|
+
|
|
34
|
+
let repo = api.model(model_id.to_string());
|
|
35
|
+
|
|
36
|
+
let config_filename = repo.get("config.json").await
|
|
37
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
38
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
39
|
+
let config: Config = serde_json::from_str(&config_str)
|
|
40
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
41
|
+
|
|
42
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
43
|
+
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
|
44
|
+
let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
|
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
|
|
46
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
47
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
48
|
+
} else {
|
|
49
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
|
50
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
|
51
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
52
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
53
|
+
};
|
|
54
|
+
|
|
55
|
+
let vocab = tokenizer.get_vocab(true);
|
|
56
|
+
let eos_token_id = vocab.get("<|endoftext|>")
|
|
57
|
+
.or_else(|| vocab.get("<|user|>"))
|
|
58
|
+
.or_else(|| vocab.get("</s>"))
|
|
59
|
+
.copied()
|
|
60
|
+
.unwrap_or(151329);
|
|
61
|
+
|
|
62
|
+
let mut filenames = vec![];
|
|
63
|
+
let num_shards = if model_id.contains("9b") || model_id.contains("9B") { 4 } else { 1 };
|
|
64
|
+
|
|
65
|
+
if num_shards == 1 {
|
|
66
|
+
let filename = repo.get("model.safetensors").await
|
|
67
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
|
|
68
|
+
filenames.push(filename);
|
|
69
|
+
} else {
|
|
70
|
+
for shard_idx in 1..=num_shards {
|
|
71
|
+
let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
|
|
72
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
|
|
73
|
+
filenames.push(filename);
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
let vb = unsafe {
|
|
78
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
let model = Glm4Model::new(&config, vb)?;
|
|
82
|
+
|
|
83
|
+
Ok(Self {
|
|
84
|
+
model,
|
|
85
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
86
|
+
device,
|
|
87
|
+
model_id: model_id.to_string(),
|
|
88
|
+
eos_token_id,
|
|
89
|
+
})
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
93
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
97
|
+
let mut prompt = String::new();
|
|
98
|
+
|
|
99
|
+
prompt.push_str("[gMASK]<sop>");
|
|
100
|
+
|
|
101
|
+
for message in messages {
|
|
102
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
103
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
104
|
+
|
|
105
|
+
match role {
|
|
106
|
+
"system" => {
|
|
107
|
+
prompt.push_str(&format!("<|system|>\n{}", content));
|
|
108
|
+
}
|
|
109
|
+
"user" => {
|
|
110
|
+
prompt.push_str(&format!("<|user|>\n{}", content));
|
|
111
|
+
}
|
|
112
|
+
"assistant" => {
|
|
113
|
+
prompt.push_str(&format!("<|assistant|>\n{}", content));
|
|
114
|
+
}
|
|
115
|
+
_ => {}
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
prompt.push_str("<|assistant|>\n");
|
|
120
|
+
|
|
121
|
+
Ok(prompt)
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
fn generate_tokens(
|
|
125
|
+
&mut self,
|
|
126
|
+
prompt_tokens: Vec<u32>,
|
|
127
|
+
config: &GenerationConfig,
|
|
128
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
129
|
+
) -> CandleResult<Vec<u32>> {
|
|
130
|
+
let mut text_gen = TextGeneration::new(config);
|
|
131
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
132
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
133
|
+
|
|
134
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
135
|
+
let start_gen = all_tokens.len();
|
|
136
|
+
|
|
137
|
+
for index in 0..config.max_length {
|
|
138
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
139
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
140
|
+
let ctxt = &all_tokens[start_pos..];
|
|
141
|
+
|
|
142
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
143
|
+
let logits = self.model.forward(&input, start_pos)?;
|
|
144
|
+
let logits = logits.squeeze(0)?;
|
|
145
|
+
|
|
146
|
+
let logits = if logits.dims().len() == 2 {
|
|
147
|
+
let seq_len = logits.dim(0)?;
|
|
148
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
149
|
+
} else {
|
|
150
|
+
logits
|
|
151
|
+
};
|
|
152
|
+
|
|
153
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
154
|
+
|
|
155
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
156
|
+
|
|
157
|
+
all_tokens.push(next_token);
|
|
158
|
+
|
|
159
|
+
if let Some(ref mut cb) = callback {
|
|
160
|
+
if config.debug_tokens {
|
|
161
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
162
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
163
|
+
} else {
|
|
164
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
165
|
+
cb(&decoded_text);
|
|
166
|
+
}
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
170
|
+
break;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
if config.stop_on_constraint_satisfaction {
|
|
174
|
+
let satisfied = if config.stop_on_match {
|
|
175
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
176
|
+
} else {
|
|
177
|
+
text_gen.is_constraint_satisfied()
|
|
178
|
+
};
|
|
179
|
+
if satisfied {
|
|
180
|
+
break;
|
|
181
|
+
}
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
185
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
186
|
+
break;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
Ok(if config.include_prompt {
|
|
191
|
+
all_tokens
|
|
192
|
+
} else {
|
|
193
|
+
all_tokens[start_gen..].to_vec()
|
|
194
|
+
})
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
impl TextGenerator for Glm4 {
|
|
199
|
+
fn generate(
|
|
200
|
+
&mut self,
|
|
201
|
+
prompt: &str,
|
|
202
|
+
config: &GenerationConfig,
|
|
203
|
+
) -> CandleResult<String> {
|
|
204
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
205
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
206
|
+
|
|
207
|
+
if config.debug_tokens {
|
|
208
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
209
|
+
} else {
|
|
210
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
211
|
+
}
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
fn generate_stream(
|
|
215
|
+
&mut self,
|
|
216
|
+
prompt: &str,
|
|
217
|
+
config: &GenerationConfig,
|
|
218
|
+
mut callback: impl FnMut(&str),
|
|
219
|
+
) -> CandleResult<String> {
|
|
220
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
221
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
222
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
fn model_name(&self) -> &str {
|
|
226
|
+
&self.model_id
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
fn device(&self) -> &Device {
|
|
230
|
+
&self.device
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
fn clear_cache(&mut self) {
|
|
234
|
+
self.clear_kv_cache();
|
|
235
|
+
}
|
|
236
|
+
}
|