red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_nn::VarBuilder;
|
|
3
|
+
use candle_transformers::models::granite::{
|
|
4
|
+
Cache, Config, Granite as GraniteModel, GraniteConfig,
|
|
5
|
+
};
|
|
6
|
+
use hf_hub::{api::tokio::Api, Repo};
|
|
7
|
+
use tokenizers::Tokenizer;
|
|
8
|
+
|
|
9
|
+
use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
10
|
+
|
|
11
|
+
/// A Granite causal language model together with the state needed to run text generation:
/// tokenizer, target device, end-of-sequence id and attention cache.
#[derive(Debug)]
pub struct Granite {
    /// The candle-transformers Granite network.
    model: GraniteModel,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: TokenizerWrapper,
    /// Device that holds the model weights and the cache.
    device: Device,
    /// Repo id the model was loaded from (reported by `model_name`).
    model_id: String,
    /// Token id that terminates generation.
    eos_token_id: u32,
    /// Key/value cache used across the forward passes of one generation.
    cache: Cache,
    /// Model configuration, kept so a fresh cache can be built on demand.
    config: Config,
}
|
|
21
|
+
|
|
22
|
+
impl Granite {
|
|
23
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
24
|
+
self.eos_token_id
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
pub fn clear_kv_cache(&mut self) {
|
|
28
|
+
if let Ok(new_cache) =
|
|
29
|
+
Cache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device)
|
|
30
|
+
{
|
|
31
|
+
self.cache = new_cache;
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
36
|
+
&self.tokenizer
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
pub async fn from_pretrained_with_tokenizer(
|
|
40
|
+
model_id: &str,
|
|
41
|
+
device: Device,
|
|
42
|
+
tokenizer_source: Option<&str>,
|
|
43
|
+
) -> CandleResult<Self> {
|
|
44
|
+
let api = Api::new()
|
|
45
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
46
|
+
|
|
47
|
+
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
48
|
+
|
|
49
|
+
let config_filename = repo
|
|
50
|
+
.get("config.json")
|
|
51
|
+
.await
|
|
52
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
53
|
+
|
|
54
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
55
|
+
let granite_config: GraniteConfig = serde_json::from_str(&config_str)
|
|
56
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
57
|
+
let config = granite_config.into_config(false);
|
|
58
|
+
|
|
59
|
+
let config_json: serde_json::Value = serde_json::from_str(&config_str)
|
|
60
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
|
|
61
|
+
let tie_word_embeddings = config_json
|
|
62
|
+
.get("tie_word_embeddings")
|
|
63
|
+
.and_then(|v| v.as_bool())
|
|
64
|
+
.unwrap_or(false);
|
|
65
|
+
|
|
66
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
67
|
+
let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
|
|
68
|
+
let tokenizer_filename = tokenizer_repo
|
|
69
|
+
.get("tokenizer.json")
|
|
70
|
+
.await
|
|
71
|
+
.map_err(|e| {
|
|
72
|
+
candle_core::Error::Msg(format!(
|
|
73
|
+
"Failed to download tokenizer from {}: {}",
|
|
74
|
+
tokenizer_id, e
|
|
75
|
+
))
|
|
76
|
+
})?;
|
|
77
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
78
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
79
|
+
} else {
|
|
80
|
+
let tokenizer_filename = repo.get("tokenizer.json").await.map_err(|e| {
|
|
81
|
+
candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e))
|
|
82
|
+
})?;
|
|
83
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
84
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
85
|
+
};
|
|
86
|
+
|
|
87
|
+
let vocab = tokenizer.get_vocab(true);
|
|
88
|
+
let eos_token_id = vocab
|
|
89
|
+
.get("<|end_of_text|>")
|
|
90
|
+
.or_else(|| vocab.get("<|endoftext|>"))
|
|
91
|
+
.or_else(|| vocab.get("</s>"))
|
|
92
|
+
.copied()
|
|
93
|
+
.unwrap_or(0);
|
|
94
|
+
|
|
95
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
|
96
|
+
vec![single_file]
|
|
97
|
+
} else {
|
|
98
|
+
let mut sharded_files = Vec::new();
|
|
99
|
+
let mut index = 1;
|
|
100
|
+
loop {
|
|
101
|
+
let mut found = false;
|
|
102
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
|
|
103
|
+
let filename =
|
|
104
|
+
format!("model-{:05}-of-{:05}.safetensors", index, total);
|
|
105
|
+
if let Ok(file) = repo.get(&filename).await {
|
|
106
|
+
sharded_files.push(file);
|
|
107
|
+
found = true;
|
|
108
|
+
break;
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
if !found {
|
|
112
|
+
break;
|
|
113
|
+
}
|
|
114
|
+
index += 1;
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
if sharded_files.is_empty() {
|
|
118
|
+
return Err(candle_core::Error::Msg(
|
|
119
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string(),
|
|
120
|
+
));
|
|
121
|
+
}
|
|
122
|
+
sharded_files
|
|
123
|
+
};
|
|
124
|
+
|
|
125
|
+
let vb = unsafe {
|
|
126
|
+
VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
|
127
|
+
};
|
|
128
|
+
|
|
129
|
+
let vb = if tie_word_embeddings {
|
|
130
|
+
vb.rename_f(|name: &str| {
|
|
131
|
+
if name == "lm_head.weight" {
|
|
132
|
+
"model.embed_tokens.weight".to_string()
|
|
133
|
+
} else {
|
|
134
|
+
name.to_string()
|
|
135
|
+
}
|
|
136
|
+
})
|
|
137
|
+
} else {
|
|
138
|
+
vb
|
|
139
|
+
};
|
|
140
|
+
|
|
141
|
+
let model = GraniteModel::load(vb, &config)?;
|
|
142
|
+
let cache = Cache::new(true, DType::F32, &config, &device)?;
|
|
143
|
+
|
|
144
|
+
Ok(Self {
|
|
145
|
+
model,
|
|
146
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
147
|
+
device,
|
|
148
|
+
model_id: model_id.to_string(),
|
|
149
|
+
eos_token_id,
|
|
150
|
+
cache,
|
|
151
|
+
config,
|
|
152
|
+
})
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
156
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
pub fn apply_chat_template(
|
|
160
|
+
&self,
|
|
161
|
+
messages: &[serde_json::Value],
|
|
162
|
+
) -> CandleResult<String> {
|
|
163
|
+
let mut prompt = String::new();
|
|
164
|
+
|
|
165
|
+
for message in messages {
|
|
166
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
167
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
168
|
+
|
|
169
|
+
match role {
|
|
170
|
+
"system" => {
|
|
171
|
+
prompt.push_str(&format!(
|
|
172
|
+
"<|start_of_role|>system<|end_of_role|>{}<|end_of_text|>\n",
|
|
173
|
+
content
|
|
174
|
+
));
|
|
175
|
+
}
|
|
176
|
+
"user" => {
|
|
177
|
+
prompt.push_str(&format!(
|
|
178
|
+
"<|start_of_role|>user<|end_of_role|>{}<|end_of_text|>\n",
|
|
179
|
+
content
|
|
180
|
+
));
|
|
181
|
+
}
|
|
182
|
+
"assistant" => {
|
|
183
|
+
prompt.push_str(&format!(
|
|
184
|
+
"<|start_of_role|>assistant<|end_of_role|>{}<|end_of_text|>\n",
|
|
185
|
+
content
|
|
186
|
+
));
|
|
187
|
+
}
|
|
188
|
+
_ => {}
|
|
189
|
+
}
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
prompt.push_str("<|start_of_role|>assistant<|end_of_role|>");
|
|
193
|
+
|
|
194
|
+
Ok(prompt)
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
fn generate_tokens(
|
|
198
|
+
&mut self,
|
|
199
|
+
prompt_tokens: Vec<u32>,
|
|
200
|
+
config: &GenerationConfig,
|
|
201
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
202
|
+
) -> CandleResult<Vec<u32>> {
|
|
203
|
+
let mut text_gen = TextGeneration::new(config);
|
|
204
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
205
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
206
|
+
|
|
207
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
208
|
+
let start_gen = all_tokens.len();
|
|
209
|
+
|
|
210
|
+
for index in 0..config.max_length {
|
|
211
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
212
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
213
|
+
let ctxt = &all_tokens[start_pos..];
|
|
214
|
+
|
|
215
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
216
|
+
let input = input.contiguous()?;
|
|
217
|
+
let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
|
|
218
|
+
|
|
219
|
+
let logits = logits.squeeze(0)?;
|
|
220
|
+
let logits = if logits.dims().len() == 2 {
|
|
221
|
+
let seq_len = logits.dim(0)?;
|
|
222
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
223
|
+
} else {
|
|
224
|
+
logits
|
|
225
|
+
};
|
|
226
|
+
|
|
227
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
228
|
+
|
|
229
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
230
|
+
|
|
231
|
+
all_tokens.push(next_token);
|
|
232
|
+
|
|
233
|
+
if let Some(ref mut cb) = callback {
|
|
234
|
+
if config.debug_tokens {
|
|
235
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
236
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
237
|
+
} else {
|
|
238
|
+
let decoded_text =
|
|
239
|
+
self.tokenizer
|
|
240
|
+
.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
241
|
+
cb(&decoded_text);
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
246
|
+
break;
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
if config.stop_on_constraint_satisfaction {
|
|
250
|
+
let satisfied = if config.stop_on_match {
|
|
251
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
252
|
+
} else {
|
|
253
|
+
text_gen.is_constraint_satisfied()
|
|
254
|
+
};
|
|
255
|
+
if satisfied {
|
|
256
|
+
break;
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
261
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
262
|
+
break;
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
Ok(if config.include_prompt {
|
|
267
|
+
all_tokens
|
|
268
|
+
} else {
|
|
269
|
+
all_tokens[start_gen..].to_vec()
|
|
270
|
+
})
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
impl TextGenerator for Granite {
|
|
275
|
+
fn generate(&mut self, prompt: &str, config: &GenerationConfig) -> CandleResult<String> {
|
|
276
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
277
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
278
|
+
|
|
279
|
+
if config.debug_tokens {
|
|
280
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
281
|
+
} else {
|
|
282
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
fn generate_stream(
|
|
287
|
+
&mut self,
|
|
288
|
+
prompt: &str,
|
|
289
|
+
config: &GenerationConfig,
|
|
290
|
+
mut callback: impl FnMut(&str),
|
|
291
|
+
) -> CandleResult<String> {
|
|
292
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
293
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
294
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
fn model_name(&self) -> &str {
|
|
298
|
+
&self.model_id
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
fn device(&self) -> &Device {
|
|
302
|
+
&self.device
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
fn clear_cache(&mut self) {
|
|
306
|
+
self.clear_kv_cache();
|
|
307
|
+
}
|
|
308
|
+
}
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_nn::VarBuilder;
|
|
3
|
+
use candle_transformers::models::granitemoehybrid::{
|
|
4
|
+
GraniteMoeHybrid as GraniteMoeHybridModel, GraniteMoeHybridCache,
|
|
5
|
+
GraniteMoeHybridConfig, GraniteMoeHybridInternalConfig,
|
|
6
|
+
};
|
|
7
|
+
use hf_hub::{api::tokio::Api, Repo};
|
|
8
|
+
use tokenizers::Tokenizer;
|
|
9
|
+
|
|
10
|
+
use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
11
|
+
|
|
12
|
+
/// A Granite MoE-hybrid causal language model together with the state needed to run text
/// generation: tokenizer, target device, end-of-sequence id and model cache.
#[derive(Debug)]
pub struct GraniteMoeHybrid {
    /// The candle-transformers GraniteMoeHybrid network.
    model: GraniteMoeHybridModel,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: TokenizerWrapper,
    /// Device that holds the model weights and the cache.
    device: Device,
    /// Repo id the model was loaded from (reported by `model_name`).
    model_id: String,
    /// Token id that terminates generation.
    eos_token_id: u32,
    /// Cache used across the forward passes of one generation.
    cache: GraniteMoeHybridCache,
    /// Internal model configuration, kept so a fresh cache can be built on demand.
    config: GraniteMoeHybridInternalConfig,
}
|
|
22
|
+
|
|
23
|
+
impl GraniteMoeHybrid {
|
|
24
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
25
|
+
self.eos_token_id
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
pub fn clear_kv_cache(&mut self) {
|
|
29
|
+
if let Ok(new_cache) =
|
|
30
|
+
GraniteMoeHybridCache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device)
|
|
31
|
+
{
|
|
32
|
+
self.cache = new_cache;
|
|
33
|
+
}
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
37
|
+
&self.tokenizer
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
pub async fn from_pretrained_with_tokenizer(
|
|
41
|
+
model_id: &str,
|
|
42
|
+
device: Device,
|
|
43
|
+
tokenizer_source: Option<&str>,
|
|
44
|
+
) -> CandleResult<Self> {
|
|
45
|
+
let api = Api::new()
|
|
46
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
47
|
+
|
|
48
|
+
let repo = api.repo(Repo::model(model_id.to_string()));
|
|
49
|
+
|
|
50
|
+
let config_filename = repo
|
|
51
|
+
.get("config.json")
|
|
52
|
+
.await
|
|
53
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
54
|
+
|
|
55
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
56
|
+
let granite_config: GraniteMoeHybridConfig = serde_json::from_str(&config_str)
|
|
57
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
|
58
|
+
let config = granite_config.into_config(false);
|
|
59
|
+
|
|
60
|
+
let config_json: serde_json::Value = serde_json::from_str(&config_str)
|
|
61
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
|
|
62
|
+
let tie_word_embeddings = config_json
|
|
63
|
+
.get("tie_word_embeddings")
|
|
64
|
+
.and_then(|v| v.as_bool())
|
|
65
|
+
.unwrap_or(false);
|
|
66
|
+
|
|
67
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
68
|
+
let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
|
|
69
|
+
let tokenizer_filename = tokenizer_repo
|
|
70
|
+
.get("tokenizer.json")
|
|
71
|
+
.await
|
|
72
|
+
.map_err(|e| {
|
|
73
|
+
candle_core::Error::Msg(format!(
|
|
74
|
+
"Failed to download tokenizer from {}: {}",
|
|
75
|
+
tokenizer_id, e
|
|
76
|
+
))
|
|
77
|
+
})?;
|
|
78
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
79
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
80
|
+
} else {
|
|
81
|
+
let tokenizer_filename = repo.get("tokenizer.json").await.map_err(|e| {
|
|
82
|
+
candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e))
|
|
83
|
+
})?;
|
|
84
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
85
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
86
|
+
};
|
|
87
|
+
|
|
88
|
+
let vocab = tokenizer.get_vocab(true);
|
|
89
|
+
let eos_token_id = vocab
|
|
90
|
+
.get("<|end_of_text|>")
|
|
91
|
+
.or_else(|| vocab.get("<|endoftext|>"))
|
|
92
|
+
.or_else(|| vocab.get("</s>"))
|
|
93
|
+
.copied()
|
|
94
|
+
.unwrap_or(0);
|
|
95
|
+
|
|
96
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
|
97
|
+
vec![single_file]
|
|
98
|
+
} else {
|
|
99
|
+
let mut sharded_files = Vec::new();
|
|
100
|
+
let mut index = 1;
|
|
101
|
+
loop {
|
|
102
|
+
let mut found = false;
|
|
103
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
|
|
104
|
+
let filename =
|
|
105
|
+
format!("model-{:05}-of-{:05}.safetensors", index, total);
|
|
106
|
+
if let Ok(file) = repo.get(&filename).await {
|
|
107
|
+
sharded_files.push(file);
|
|
108
|
+
found = true;
|
|
109
|
+
break;
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
if !found {
|
|
113
|
+
break;
|
|
114
|
+
}
|
|
115
|
+
index += 1;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
if sharded_files.is_empty() {
|
|
119
|
+
return Err(candle_core::Error::Msg(
|
|
120
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string(),
|
|
121
|
+
));
|
|
122
|
+
}
|
|
123
|
+
sharded_files
|
|
124
|
+
};
|
|
125
|
+
|
|
126
|
+
let vb = unsafe {
|
|
127
|
+
VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
|
128
|
+
};
|
|
129
|
+
|
|
130
|
+
let vb = if tie_word_embeddings {
|
|
131
|
+
vb.rename_f(|name: &str| {
|
|
132
|
+
if name == "lm_head.weight" {
|
|
133
|
+
"model.embed_tokens.weight".to_string()
|
|
134
|
+
} else {
|
|
135
|
+
name.to_string()
|
|
136
|
+
}
|
|
137
|
+
})
|
|
138
|
+
} else {
|
|
139
|
+
vb
|
|
140
|
+
};
|
|
141
|
+
|
|
142
|
+
let model = GraniteMoeHybridModel::load(vb, &config)?;
|
|
143
|
+
let cache = GraniteMoeHybridCache::new(true, DType::F32, &config, &device)?;
|
|
144
|
+
|
|
145
|
+
Ok(Self {
|
|
146
|
+
model,
|
|
147
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
148
|
+
device,
|
|
149
|
+
model_id: model_id.to_string(),
|
|
150
|
+
eos_token_id,
|
|
151
|
+
cache,
|
|
152
|
+
config,
|
|
153
|
+
})
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
157
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
pub fn apply_chat_template(
|
|
161
|
+
&self,
|
|
162
|
+
messages: &[serde_json::Value],
|
|
163
|
+
) -> CandleResult<String> {
|
|
164
|
+
let mut prompt = String::new();
|
|
165
|
+
|
|
166
|
+
for message in messages {
|
|
167
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
168
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
169
|
+
|
|
170
|
+
match role {
|
|
171
|
+
"system" => {
|
|
172
|
+
prompt.push_str(&format!(
|
|
173
|
+
"<|start_of_role|>system<|end_of_role|>{}<|end_of_text|>\n",
|
|
174
|
+
content
|
|
175
|
+
));
|
|
176
|
+
}
|
|
177
|
+
"user" => {
|
|
178
|
+
prompt.push_str(&format!(
|
|
179
|
+
"<|start_of_role|>user<|end_of_role|>{}<|end_of_text|>\n",
|
|
180
|
+
content
|
|
181
|
+
));
|
|
182
|
+
}
|
|
183
|
+
"assistant" => {
|
|
184
|
+
prompt.push_str(&format!(
|
|
185
|
+
"<|start_of_role|>assistant<|end_of_role|>{}<|end_of_text|>\n",
|
|
186
|
+
content
|
|
187
|
+
));
|
|
188
|
+
}
|
|
189
|
+
"tool" => {
|
|
190
|
+
prompt.push_str(&format!(
|
|
191
|
+
"<|start_of_role|>tool<|end_of_role|>{}<|end_of_text|>\n",
|
|
192
|
+
content
|
|
193
|
+
));
|
|
194
|
+
}
|
|
195
|
+
_ => {}
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
prompt.push_str("<|start_of_role|>assistant<|end_of_role|>");
|
|
200
|
+
|
|
201
|
+
Ok(prompt)
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
fn generate_tokens(
|
|
205
|
+
&mut self,
|
|
206
|
+
prompt_tokens: Vec<u32>,
|
|
207
|
+
config: &GenerationConfig,
|
|
208
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
209
|
+
) -> CandleResult<Vec<u32>> {
|
|
210
|
+
let mut text_gen = TextGeneration::new(config);
|
|
211
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
212
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
213
|
+
|
|
214
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
215
|
+
let start_gen = all_tokens.len();
|
|
216
|
+
|
|
217
|
+
for index in 0..config.max_length {
|
|
218
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
219
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
220
|
+
let ctxt = &all_tokens[start_pos..];
|
|
221
|
+
|
|
222
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
223
|
+
let input = input.contiguous()?;
|
|
224
|
+
let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
|
|
225
|
+
|
|
226
|
+
let logits = logits.squeeze(0)?;
|
|
227
|
+
let logits = if logits.dims().len() == 2 {
|
|
228
|
+
let seq_len = logits.dim(0)?;
|
|
229
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
230
|
+
} else {
|
|
231
|
+
logits
|
|
232
|
+
};
|
|
233
|
+
|
|
234
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
235
|
+
|
|
236
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
237
|
+
|
|
238
|
+
all_tokens.push(next_token);
|
|
239
|
+
|
|
240
|
+
if let Some(ref mut cb) = callback {
|
|
241
|
+
if config.debug_tokens {
|
|
242
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
243
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
244
|
+
} else {
|
|
245
|
+
let decoded_text =
|
|
246
|
+
self.tokenizer
|
|
247
|
+
.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
248
|
+
cb(&decoded_text);
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
253
|
+
break;
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
if config.stop_on_constraint_satisfaction {
|
|
257
|
+
let satisfied = if config.stop_on_match {
|
|
258
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
259
|
+
} else {
|
|
260
|
+
text_gen.is_constraint_satisfied()
|
|
261
|
+
};
|
|
262
|
+
if satisfied {
|
|
263
|
+
break;
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
268
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
269
|
+
break;
|
|
270
|
+
}
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
Ok(if config.include_prompt {
|
|
274
|
+
all_tokens
|
|
275
|
+
} else {
|
|
276
|
+
all_tokens[start_gen..].to_vec()
|
|
277
|
+
})
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
impl TextGenerator for GraniteMoeHybrid {
|
|
282
|
+
fn generate(&mut self, prompt: &str, config: &GenerationConfig) -> CandleResult<String> {
|
|
283
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
284
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
285
|
+
|
|
286
|
+
if config.debug_tokens {
|
|
287
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
288
|
+
} else {
|
|
289
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
fn generate_stream(
|
|
294
|
+
&mut self,
|
|
295
|
+
prompt: &str,
|
|
296
|
+
config: &GenerationConfig,
|
|
297
|
+
mut callback: impl FnMut(&str),
|
|
298
|
+
) -> CandleResult<String> {
|
|
299
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
300
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
301
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
fn model_name(&self) -> &str {
|
|
305
|
+
&self.model_id
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
fn device(&self) -> &Device {
|
|
309
|
+
&self.device
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
fn clear_cache(&mut self) {
|
|
313
|
+
self.clear_kv_cache();
|
|
314
|
+
}
|
|
315
|
+
}
|