red-candle 1.8.0-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
|
2
|
+
use candle_transformers::models::phi::{Config, Model as PhiModel};
|
|
3
|
+
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
|
|
4
|
+
use hf_hub::api::tokio::Api;
|
|
5
|
+
use tokenizers::Tokenizer;
|
|
6
|
+
|
|
7
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
|
8
|
+
|
|
9
|
+
/// Phi model wrapper for text generation
|
|
10
|
+
pub struct Phi {
|
|
11
|
+
model: PhiVariant,
|
|
12
|
+
tokenizer: TokenizerWrapper,
|
|
13
|
+
device: Device,
|
|
14
|
+
model_id: String,
|
|
15
|
+
eos_token_id: u32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
enum PhiVariant {
|
|
19
|
+
Phi2(PhiModel),
|
|
20
|
+
Phi3(Phi3Model),
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
impl Phi {
|
|
24
|
+
pub fn eos_token_id(&self) -> u32 {
|
|
25
|
+
self.eos_token_id
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Get the tokenizer
|
|
29
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
|
30
|
+
&self.tokenizer
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/// Clear the KV cache between generations
|
|
34
|
+
pub fn clear_kv_cache(&mut self) {
|
|
35
|
+
match &mut self.model {
|
|
36
|
+
PhiVariant::Phi2(model) => model.clear_kv_cache(),
|
|
37
|
+
PhiVariant::Phi3(model) => model.clear_kv_cache(),
|
|
38
|
+
}
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/// Load a Phi model from HuggingFace with optional custom tokenizer
|
|
42
|
+
pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
|
|
43
|
+
let api = Api::new()
|
|
44
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
|
45
|
+
|
|
46
|
+
let repo = api.model(model_id.to_string());
|
|
47
|
+
|
|
48
|
+
// Download configuration
|
|
49
|
+
let config_filename = repo.get("config.json").await
|
|
50
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
|
51
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
|
52
|
+
|
|
53
|
+
// Download tokenizer from custom source if provided, otherwise from model repo
|
|
54
|
+
let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
|
|
55
|
+
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
|
56
|
+
let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
|
|
57
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
|
|
58
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
59
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
60
|
+
} else {
|
|
61
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
|
62
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
|
63
|
+
Tokenizer::from_file(tokenizer_filename)
|
|
64
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
|
|
65
|
+
};
|
|
66
|
+
|
|
67
|
+
// Determine EOS token
|
|
68
|
+
let vocab = tokenizer.get_vocab(true);
|
|
69
|
+
let eos_token_id = vocab.get("<|endoftext|>")
|
|
70
|
+
.or_else(|| vocab.get("<|end|>"))
|
|
71
|
+
.or_else(|| vocab.get("</s>"))
|
|
72
|
+
.copied()
|
|
73
|
+
.unwrap_or(50256); // Default GPT-2 style EOS token
|
|
74
|
+
|
|
75
|
+
// Determine model variant based on model_id or config
|
|
76
|
+
let is_phi3 = model_id.contains("phi-3") || model_id.contains("Phi-3");
|
|
77
|
+
|
|
78
|
+
// Download model weights (handle both single and sharded files)
|
|
79
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
|
80
|
+
vec![single_file]
|
|
81
|
+
} else {
|
|
82
|
+
// Try to find sharded model files
|
|
83
|
+
// NOTE: This uses a brute-force approach, trying common shard counts.
|
|
84
|
+
// A better approach would be to read model.safetensors.index.json which
|
|
85
|
+
// contains the exact file list, but this works for most models (≤30 shards).
|
|
86
|
+
let mut sharded_files = Vec::new();
|
|
87
|
+
let mut index = 1;
|
|
88
|
+
loop {
|
|
89
|
+
// Try common shard counts
|
|
90
|
+
let mut found = false;
|
|
91
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
|
|
92
|
+
let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
|
|
93
|
+
if let Ok(file) = repo.get(&filename).await {
|
|
94
|
+
sharded_files.push(file);
|
|
95
|
+
found = true;
|
|
96
|
+
break;
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
if !found {
|
|
100
|
+
break;
|
|
101
|
+
}
|
|
102
|
+
index += 1;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
if sharded_files.is_empty() {
|
|
106
|
+
return Err(candle_core::Error::Msg(
|
|
107
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
|
|
108
|
+
));
|
|
109
|
+
}
|
|
110
|
+
sharded_files
|
|
111
|
+
};
|
|
112
|
+
|
|
113
|
+
let model = if is_phi3 {
|
|
114
|
+
// Load Phi3 model
|
|
115
|
+
// Handle config differences between Phi-3-small and Phi-3-mini
|
|
116
|
+
let mut config_str_fixed;
|
|
117
|
+
|
|
118
|
+
// Parse config as JSON for modifications
|
|
119
|
+
let mut config_json: serde_json::Value = serde_json::from_str(&config_str)
|
|
120
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
|
|
121
|
+
|
|
122
|
+
// Phi-3-small uses ff_intermediate_size instead of intermediate_size
|
|
123
|
+
if config_json.get("ff_intermediate_size").is_some() && config_json.get("intermediate_size").is_none() {
|
|
124
|
+
if let Some(ff_size) = config_json.get("ff_intermediate_size").cloned() {
|
|
125
|
+
config_json["intermediate_size"] = ff_size;
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
// Phi-3-small uses layer_norm_epsilon instead of rms_norm_eps
|
|
130
|
+
if config_json.get("layer_norm_epsilon").is_some() && config_json.get("rms_norm_eps").is_none() {
|
|
131
|
+
if let Some(eps) = config_json.get("layer_norm_epsilon").cloned() {
|
|
132
|
+
config_json["rms_norm_eps"] = eps;
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
// Handle rope_scaling for long context models (Phi-3-mini-128k)
|
|
137
|
+
// Candle expects rope_scaling to be a string, but newer configs have it as an object
|
|
138
|
+
if let Some(rope_scaling) = config_json.get("rope_scaling") {
|
|
139
|
+
if rope_scaling.is_object() {
|
|
140
|
+
// For now, just convert to the type string - candle will use default scaling
|
|
141
|
+
if let Some(scaling_type) = rope_scaling.get("type").and_then(|v| v.as_str()) {
|
|
142
|
+
config_json["rope_scaling"] = serde_json::Value::String(scaling_type.to_string());
|
|
143
|
+
} else {
|
|
144
|
+
// Remove it if we can't determine the type
|
|
145
|
+
config_json.as_object_mut().unwrap().remove("rope_scaling");
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
// Phi-3-small uses rope_embedding_base instead of rope_theta
|
|
151
|
+
if config_json.get("rope_embedding_base").is_some() && config_json.get("rope_theta").is_none() {
|
|
152
|
+
if let Some(rope_base) = config_json.get("rope_embedding_base").cloned() {
|
|
153
|
+
config_json["rope_theta"] = rope_base;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
config_str_fixed = serde_json::to_string(&config_json)
|
|
158
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to serialize config: {}", e)))?;
|
|
159
|
+
|
|
160
|
+
// Check for unsupported gegelu activation
|
|
161
|
+
if config_str_fixed.contains("\"gegelu\"") {
|
|
162
|
+
// For now, map gegelu to gelu_pytorch_tanh with a warning
|
|
163
|
+
// This is not ideal but allows the model to at least load
|
|
164
|
+
eprintln!("WARNING: This model uses 'gegelu' activation which is not fully supported.");
|
|
165
|
+
eprintln!(" Mapping to 'gelu_pytorch_tanh' - results may be degraded.");
|
|
166
|
+
eprintln!(" For best results, use Phi-3-mini models instead.");
|
|
167
|
+
config_str_fixed = config_str_fixed.replace("\"gegelu\"", "\"gelu_pytorch_tanh\"");
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
let config: Phi3Config = serde_json::from_str(&config_str_fixed)
|
|
171
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
|
|
172
|
+
|
|
173
|
+
let vb = unsafe {
|
|
174
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
|
175
|
+
};
|
|
176
|
+
|
|
177
|
+
let model = Phi3Model::new(&config, vb)?;
|
|
178
|
+
PhiVariant::Phi3(model)
|
|
179
|
+
} else {
|
|
180
|
+
// Load Phi2 model
|
|
181
|
+
let config: Config = serde_json::from_str(&config_str)
|
|
182
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi config: {}", e)))?;
|
|
183
|
+
|
|
184
|
+
let vb = unsafe {
|
|
185
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
|
186
|
+
};
|
|
187
|
+
|
|
188
|
+
let model = PhiModel::new(&config, vb)?;
|
|
189
|
+
PhiVariant::Phi2(model)
|
|
190
|
+
};
|
|
191
|
+
|
|
192
|
+
Ok(Self {
|
|
193
|
+
model,
|
|
194
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
|
195
|
+
device,
|
|
196
|
+
model_id: model_id.to_string(),
|
|
197
|
+
eos_token_id,
|
|
198
|
+
})
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/// Load a Phi model from HuggingFace (backwards compatibility)
|
|
202
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
|
203
|
+
Self::from_pretrained_with_tokenizer(model_id, device, None).await
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
/// Apply Phi chat template to messages
|
|
207
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
|
208
|
+
let mut prompt = String::new();
|
|
209
|
+
|
|
210
|
+
// Phi-3 uses a specific format
|
|
211
|
+
if matches!(self.model, PhiVariant::Phi3(_)) {
|
|
212
|
+
for message in messages {
|
|
213
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
214
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
215
|
+
|
|
216
|
+
match role {
|
|
217
|
+
"system" => {
|
|
218
|
+
prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
|
|
219
|
+
}
|
|
220
|
+
"user" => {
|
|
221
|
+
prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
|
|
222
|
+
}
|
|
223
|
+
"assistant" => {
|
|
224
|
+
prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
|
|
225
|
+
}
|
|
226
|
+
_ => {}
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
prompt.push_str("<|assistant|>\n");
|
|
230
|
+
} else {
|
|
231
|
+
// Phi-2 uses a simpler format
|
|
232
|
+
for message in messages {
|
|
233
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
234
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
235
|
+
|
|
236
|
+
match role {
|
|
237
|
+
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
|
238
|
+
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
|
239
|
+
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
|
240
|
+
_ => {}
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
prompt.push_str("Assistant: ");
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
Ok(prompt)
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
fn generate_tokens(
|
|
250
|
+
&mut self,
|
|
251
|
+
prompt_tokens: Vec<u32>,
|
|
252
|
+
config: &GenerationConfig,
|
|
253
|
+
mut callback: Option<impl FnMut(&str)>,
|
|
254
|
+
) -> CandleResult<Vec<u32>> {
|
|
255
|
+
let mut text_gen = TextGeneration::new(config);
|
|
256
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
|
257
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
|
258
|
+
|
|
259
|
+
let mut all_tokens = prompt_tokens.clone();
|
|
260
|
+
let start_gen = all_tokens.len();
|
|
261
|
+
|
|
262
|
+
for index in 0..config.max_length {
|
|
263
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
|
264
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
|
265
|
+
let ctxt = &all_tokens[start_pos..];
|
|
266
|
+
|
|
267
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
268
|
+
let logits = match &mut self.model {
|
|
269
|
+
PhiVariant::Phi2(model) => model.forward(&input)?,
|
|
270
|
+
PhiVariant::Phi3(model) => model.forward(&input, start_pos)?,
|
|
271
|
+
};
|
|
272
|
+
let logits = logits.squeeze(0)?;
|
|
273
|
+
|
|
274
|
+
// Handle different output shapes
|
|
275
|
+
let logits = if logits.dims().len() == 2 {
|
|
276
|
+
let seq_len = logits.dim(0)?;
|
|
277
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
|
278
|
+
} else {
|
|
279
|
+
logits
|
|
280
|
+
};
|
|
281
|
+
|
|
282
|
+
let logits = logits.to_dtype(DType::F32)?;
|
|
283
|
+
|
|
284
|
+
let next_token = text_gen.sample_next_token(&logits)?;
|
|
285
|
+
|
|
286
|
+
all_tokens.push(next_token);
|
|
287
|
+
|
|
288
|
+
// Stream callback
|
|
289
|
+
if let Some(ref mut cb) = callback {
|
|
290
|
+
if config.debug_tokens {
|
|
291
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
|
292
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
|
293
|
+
} else {
|
|
294
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
|
295
|
+
cb(&decoded_text);
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
// Check stop conditions
|
|
300
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
|
301
|
+
break;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
// Check if constraint is satisfied (early stopping)
|
|
305
|
+
if config.stop_on_constraint_satisfaction {
|
|
306
|
+
let satisfied = if config.stop_on_match {
|
|
307
|
+
text_gen.is_constraint_satisfied_stop_on_match()
|
|
308
|
+
} else {
|
|
309
|
+
text_gen.is_constraint_satisfied()
|
|
310
|
+
};
|
|
311
|
+
if satisfied {
|
|
312
|
+
break;
|
|
313
|
+
}
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
// Check stop sequences
|
|
317
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
|
318
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
|
319
|
+
break;
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
Ok(if config.include_prompt {
|
|
324
|
+
all_tokens
|
|
325
|
+
} else {
|
|
326
|
+
all_tokens[start_gen..].to_vec()
|
|
327
|
+
})
|
|
328
|
+
}
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
impl TextGenerator for Phi {
|
|
332
|
+
fn generate(
|
|
333
|
+
&mut self,
|
|
334
|
+
prompt: &str,
|
|
335
|
+
config: &GenerationConfig,
|
|
336
|
+
) -> CandleResult<String> {
|
|
337
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
338
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
|
339
|
+
|
|
340
|
+
if config.debug_tokens {
|
|
341
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
|
342
|
+
} else {
|
|
343
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
fn generate_stream(
|
|
348
|
+
&mut self,
|
|
349
|
+
prompt: &str,
|
|
350
|
+
config: &GenerationConfig,
|
|
351
|
+
mut callback: impl FnMut(&str),
|
|
352
|
+
) -> CandleResult<String> {
|
|
353
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
|
354
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
|
355
|
+
self.tokenizer.decode(&output_tokens, true)
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
fn model_name(&self) -> &str {
|
|
359
|
+
&self.model_id
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
fn device(&self) -> &Device {
|
|
363
|
+
&self.device
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
fn clear_cache(&mut self) {
|
|
367
|
+
self.clear_kv_cache();
|
|
368
|
+
}
|
|
369
|
+
}
|