red-candle 1.0.1 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +57 -4
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/build.rs +6 -5
- data/ext/candle/extconf.rb +5 -6
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +123 -0
- data/ext/candle/src/llm/generation_config.rs +5 -0
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +285 -0
- data/ext/candle/src/llm/quantized_gguf.rs +155 -4
- data/ext/candle/src/llm/qwen.rs +229 -0
- data/ext/candle/src/llm/text_generation.rs +66 -2
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +42 -4
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/build_info.rb +2 -2
- data/lib/candle/llm.rb +109 -3
- data/lib/candle/version.rb +1 -1
- data/lib/red-candle.rb +1 -0
- metadata +15 -4
@@ -0,0 +1,123 @@
|
|
1
|
+
#[cfg(test)]
mod constrained_generation_tests {
    use super::super::*;
    use crate::structured::{SchemaProcessor, VocabularyAdapter};
    use crate::tokenizer::{loader::TokenizerLoader, TokenizerWrapper};

    /// Builds a yes/no JSON-schema constraint from a real tokenizer and wires it into a
    /// `TextGeneration`, alongside an unconstrained twin.
    ///
    /// Requires fetching `bert-base-uncased` from the HuggingFace hub; if that fails the
    /// test returns early without checking anything.
    #[tokio::test]
    async fn test_constrained_vs_unconstrained_generation() {
        let tokenizer = match TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
            Ok(tokenizer) => tokenizer,
            // Hub unreachable (e.g. offline CI): nothing to exercise.
            Err(_) => return,
        };
        let wrapper = TokenizerWrapper::new(tokenizer);

        let vocabulary =
            VocabularyAdapter::from_tokenizer(&wrapper).expect("Should create vocabulary");
        let processor = SchemaProcessor::new();

        // One required "answer" property restricted to the strings "yes" and "no".
        let schema = r#"{
"type": "object",
"properties": {
"answer": {
"type": "string",
"enum": ["yes", "no"]
}
},
"required": ["answer"]
}"#;
        let index = processor
            .process_schema(schema, &vocabulary)
            .expect("Should process schema");

        let mut constrained_config = GenerationConfig::default();
        constrained_config.constraint = Some(index.clone());
        constrained_config.max_length = 50;

        let unconstrained_config = GenerationConfig::default();

        let mut constrained = TextGeneration::from_config(&constrained_config);
        let mut unconstrained = TextGeneration::from_config(&unconstrained_config);

        // 102 is BERT's [SEP] token, used here as the end-of-sequence marker.
        let bert_sep = 102;
        constrained.set_eos_token_id(bert_sep);
        unconstrained.set_eos_token_id(bert_sep);

        // The constraint lives in private state, so its effect can only be observed
        // through an actual generation run.
    }

    /// A `TextGeneration` can be built from a default (unconstrained) configuration.
    #[test]
    fn test_constraint_configuration() {
        let config = GenerationConfig::default();
        // Constraint handling is a private detail; constructing is all we can check.
        let _text_gen = TextGeneration::from_config(&config);
    }

    /// Tokens already in the history get their positive logits pushed down, while
    /// tokens that never appeared keep their logits untouched.
    #[test]
    fn test_repetition_penalty() {
        use candle_core::{Device, Tensor};

        let original: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
        let mut logits = Tensor::from_vec(original.clone(), 10, &Device::Cpu).unwrap();

        let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
        // History: the tokens whose logits are 1.0, 2.0 and 3.0.
        for token in [0, 2, 5] {
            text_gen.push_token(token);
        }

        text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();
        let penalized = logits.to_vec1::<f32>().unwrap();

        for idx in [0, 2, 5] {
            assert!(penalized[idx] < original[idx], "Positive logit should be reduced");
        }
        for idx in [1, 3] {
            assert_eq!(penalized[idx], original[idx], "Unsampled token should be unchanged");
        }
    }

    /// Length limit, EOS token and stop-sequence checks each trigger (or not) as expected.
    #[test]
    fn test_stop_conditions() {
        let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
        text_gen.set_eos_token_id(50256); // Common EOS token

        // Ten tokens of history: a limit of 10 is reached, a limit of 20 is not.
        for token in 0..10 {
            text_gen.push_token(token);
        }
        assert!(text_gen.should_stop(100, 10), "Should stop at max length");
        assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");

        assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
        assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");

        let stop_seqs: Vec<String> = ["STOP", "END"].iter().map(|s| s.to_string()).collect();
        assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
        assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
        assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
    }
}
|
@@ -1,4 +1,6 @@
|
|
1
1
|
use std::time::{SystemTime, UNIX_EPOCH};
|
2
|
+
use std::sync::Arc;
|
3
|
+
use crate::structured::Index;
|
2
4
|
|
3
5
|
/// Configuration for text generation
|
4
6
|
#[derive(Debug, Clone)]
|
@@ -23,6 +25,8 @@ pub struct GenerationConfig {
|
|
23
25
|
pub include_prompt: bool,
|
24
26
|
/// Whether to show raw tokens during generation (for debugging)
|
25
27
|
pub debug_tokens: bool,
|
28
|
+
/// Optional constraint index for structured generation
|
29
|
+
pub constraint: Option<Arc<Index>>,
|
26
30
|
}
|
27
31
|
|
28
32
|
/// Generate a random seed based on current time
|
@@ -46,6 +50,7 @@ impl Default for GenerationConfig {
|
|
46
50
|
stop_sequences: vec![],
|
47
51
|
include_prompt: false,
|
48
52
|
debug_tokens: false,
|
53
|
+
constraint: None,
|
49
54
|
}
|
50
55
|
}
|
51
56
|
}
|
data/ext/candle/src/llm/mod.rs
CHANGED
@@ -3,6 +3,8 @@ use candle_core::{Device, Result as CandleResult};
|
|
3
3
|
pub mod mistral;
|
4
4
|
pub mod llama;
|
5
5
|
pub mod gemma;
|
6
|
+
pub mod qwen;
|
7
|
+
pub mod phi;
|
6
8
|
pub mod generation_config;
|
7
9
|
pub mod text_generation;
|
8
10
|
pub mod quantized_gguf;
|
@@ -12,6 +14,9 @@ pub use text_generation::TextGeneration;
|
|
12
14
|
pub use quantized_gguf::QuantizedGGUF;
|
13
15
|
pub use crate::tokenizer::TokenizerWrapper;
|
14
16
|
|
17
|
+
#[cfg(test)]
|
18
|
+
mod constrained_generation_test;
|
19
|
+
|
15
20
|
/// Trait for text generation models
|
16
21
|
pub trait TextGenerator: Send + Sync {
|
17
22
|
/// Generate text from a prompt
|
@@ -0,0 +1,285 @@
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
2
|
+
use candle_transformers::models::phi::{Config, Model as PhiModel};
|
3
|
+
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
|
4
|
+
use hf_hub::api::tokio::Api;
|
5
|
+
use tokenizers::Tokenizer;
|
6
|
+
|
7
|
+
use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
8
|
+
|
9
|
+
/// Phi model wrapper for text generation
|
10
|
+
/// Phi model wrapper for text generation
///
/// Bundles a loaded Phi-2 or Phi-3 model (see `PhiVariant`) with the tokenizer,
/// device and end-of-sequence token it was loaded with.
pub struct Phi {
    /// Loaded model; the variant selects the forward-call signature and chat template.
    model: PhiVariant,
    /// Tokenizer built from the repository's `tokenizer.json`.
    tokenizer: TokenizerWrapper,
    /// Device on which input tensors are created for each forward pass.
    device: Device,
    /// HuggingFace model id this instance was loaded from (returned by `model_name`).
    model_id: String,
    /// Token id handed to `TextGeneration::set_eos_token_id` for every generation run.
    eos_token_id: u32,
}
|
17
|
+
|
18
|
+
/// Concrete Phi architecture backing a `Phi` instance.
enum PhiVariant {
    /// Model from `candle_transformers::models::phi`; forward takes only the input ids.
    Phi2(PhiModel),
    /// Model from `candle_transformers::models::phi3`; forward also takes the position offset.
    Phi3(Phi3Model),
}
|
22
|
+
|
23
|
+
impl Phi {
|
24
|
+
/// Get the tokenizer
|
25
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
26
|
+
&self.tokenizer
|
27
|
+
}
|
28
|
+
|
29
|
+
/// Clear the KV cache between generations
|
30
|
+
pub fn clear_kv_cache(&mut self) {
|
31
|
+
match &mut self.model {
|
32
|
+
PhiVariant::Phi2(model) => model.clear_kv_cache(),
|
33
|
+
PhiVariant::Phi3(model) => model.clear_kv_cache(),
|
34
|
+
}
|
35
|
+
}
|
36
|
+
|
37
|
+
/// Load a Phi model from HuggingFace
|
38
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
39
|
+
let api = Api::new()
|
40
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
41
|
+
|
42
|
+
let repo = api.model(model_id.to_string());
|
43
|
+
|
44
|
+
// Download configuration
|
45
|
+
let config_filename = repo.get("config.json").await
|
46
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
47
|
+
let config_str = std::fs::read_to_string(config_filename)?;
|
48
|
+
|
49
|
+
// Download tokenizer
|
50
|
+
let tokenizer_filename = repo.get("tokenizer.json").await
|
51
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
52
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)
|
53
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
|
54
|
+
|
55
|
+
// Determine EOS token
|
56
|
+
let vocab = tokenizer.get_vocab(true);
|
57
|
+
let eos_token_id = vocab.get("<|endoftext|>")
|
58
|
+
.or_else(|| vocab.get("<|end|>"))
|
59
|
+
.or_else(|| vocab.get("</s>"))
|
60
|
+
.copied()
|
61
|
+
.unwrap_or(50256); // Default GPT-2 style EOS token
|
62
|
+
|
63
|
+
// Determine model variant based on model_id or config
|
64
|
+
let is_phi3 = model_id.contains("phi-3") || model_id.contains("Phi-3");
|
65
|
+
|
66
|
+
// Download model weights (handle both single and sharded files)
|
67
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
68
|
+
vec![single_file]
|
69
|
+
} else {
|
70
|
+
// Try to find sharded model files
|
71
|
+
let mut sharded_files = Vec::new();
|
72
|
+
let mut index = 1;
|
73
|
+
loop {
|
74
|
+
// Try common shard counts
|
75
|
+
let mut found = false;
|
76
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
|
77
|
+
let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
|
78
|
+
if let Ok(file) = repo.get(&filename).await {
|
79
|
+
sharded_files.push(file);
|
80
|
+
found = true;
|
81
|
+
break;
|
82
|
+
}
|
83
|
+
}
|
84
|
+
if !found {
|
85
|
+
break;
|
86
|
+
}
|
87
|
+
index += 1;
|
88
|
+
}
|
89
|
+
|
90
|
+
if sharded_files.is_empty() {
|
91
|
+
return Err(candle_core::Error::Msg(
|
92
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
|
93
|
+
));
|
94
|
+
}
|
95
|
+
sharded_files
|
96
|
+
};
|
97
|
+
|
98
|
+
let model = if is_phi3 {
|
99
|
+
// Load Phi3 model
|
100
|
+
let config: Phi3Config = serde_json::from_str(&config_str)
|
101
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi3 config: {}", e)))?;
|
102
|
+
|
103
|
+
let vb = unsafe {
|
104
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
105
|
+
};
|
106
|
+
|
107
|
+
let model = Phi3Model::new(&config, vb)?;
|
108
|
+
PhiVariant::Phi3(model)
|
109
|
+
} else {
|
110
|
+
// Load Phi2 model
|
111
|
+
let config: Config = serde_json::from_str(&config_str)
|
112
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse Phi config: {}", e)))?;
|
113
|
+
|
114
|
+
let vb = unsafe {
|
115
|
+
candle_nn::VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
116
|
+
};
|
117
|
+
|
118
|
+
let model = PhiModel::new(&config, vb)?;
|
119
|
+
PhiVariant::Phi2(model)
|
120
|
+
};
|
121
|
+
|
122
|
+
Ok(Self {
|
123
|
+
model,
|
124
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
125
|
+
device,
|
126
|
+
model_id: model_id.to_string(),
|
127
|
+
eos_token_id,
|
128
|
+
})
|
129
|
+
}
|
130
|
+
|
131
|
+
/// Apply Phi chat template to messages
|
132
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
133
|
+
let mut prompt = String::new();
|
134
|
+
|
135
|
+
// Phi-3 uses a specific format
|
136
|
+
if matches!(self.model, PhiVariant::Phi3(_)) {
|
137
|
+
for message in messages {
|
138
|
+
let role = message["role"].as_str().unwrap_or("");
|
139
|
+
let content = message["content"].as_str().unwrap_or("");
|
140
|
+
|
141
|
+
match role {
|
142
|
+
"system" => {
|
143
|
+
prompt.push_str(&format!("<|system|>\n{}<|end|>\n", content));
|
144
|
+
}
|
145
|
+
"user" => {
|
146
|
+
prompt.push_str(&format!("<|user|>\n{}<|end|>\n", content));
|
147
|
+
}
|
148
|
+
"assistant" => {
|
149
|
+
prompt.push_str(&format!("<|assistant|>\n{}<|end|>\n", content));
|
150
|
+
}
|
151
|
+
_ => {}
|
152
|
+
}
|
153
|
+
}
|
154
|
+
prompt.push_str("<|assistant|>\n");
|
155
|
+
} else {
|
156
|
+
// Phi-2 uses a simpler format
|
157
|
+
for message in messages {
|
158
|
+
let role = message["role"].as_str().unwrap_or("");
|
159
|
+
let content = message["content"].as_str().unwrap_or("");
|
160
|
+
|
161
|
+
match role {
|
162
|
+
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
163
|
+
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
164
|
+
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
165
|
+
_ => {}
|
166
|
+
}
|
167
|
+
}
|
168
|
+
prompt.push_str("Assistant: ");
|
169
|
+
}
|
170
|
+
|
171
|
+
Ok(prompt)
|
172
|
+
}
|
173
|
+
|
174
|
+
fn generate_tokens(
|
175
|
+
&mut self,
|
176
|
+
prompt_tokens: Vec<u32>,
|
177
|
+
config: &GenerationConfig,
|
178
|
+
mut callback: Option<impl FnMut(&str)>,
|
179
|
+
) -> CandleResult<Vec<u32>> {
|
180
|
+
let mut text_gen = TextGeneration::from_config(config);
|
181
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
182
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
183
|
+
|
184
|
+
let mut all_tokens = prompt_tokens.clone();
|
185
|
+
let start_gen = all_tokens.len();
|
186
|
+
|
187
|
+
for index in 0..config.max_length {
|
188
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
189
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
190
|
+
let ctxt = &all_tokens[start_pos..];
|
191
|
+
|
192
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
193
|
+
let logits = match &mut self.model {
|
194
|
+
PhiVariant::Phi2(model) => model.forward(&input)?,
|
195
|
+
PhiVariant::Phi3(model) => model.forward(&input, start_pos)?,
|
196
|
+
};
|
197
|
+
let logits = logits.squeeze(0)?;
|
198
|
+
|
199
|
+
// Handle different output shapes
|
200
|
+
let logits = if logits.dims().len() == 2 {
|
201
|
+
let seq_len = logits.dim(0)?;
|
202
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
203
|
+
} else {
|
204
|
+
logits
|
205
|
+
};
|
206
|
+
|
207
|
+
let logits = logits.to_dtype(DType::F32)?;
|
208
|
+
|
209
|
+
let next_token = text_gen.sample_next_token(
|
210
|
+
&logits,
|
211
|
+
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
212
|
+
)?;
|
213
|
+
|
214
|
+
all_tokens.push(next_token);
|
215
|
+
|
216
|
+
// Stream callback
|
217
|
+
if let Some(ref mut cb) = callback {
|
218
|
+
if config.debug_tokens {
|
219
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
220
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
221
|
+
} else {
|
222
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
223
|
+
cb(&decoded_text);
|
224
|
+
}
|
225
|
+
}
|
226
|
+
|
227
|
+
// Check stop conditions
|
228
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
229
|
+
break;
|
230
|
+
}
|
231
|
+
|
232
|
+
// Check stop sequences
|
233
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
234
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
235
|
+
break;
|
236
|
+
}
|
237
|
+
}
|
238
|
+
|
239
|
+
Ok(if config.include_prompt {
|
240
|
+
all_tokens
|
241
|
+
} else {
|
242
|
+
all_tokens[start_gen..].to_vec()
|
243
|
+
})
|
244
|
+
}
|
245
|
+
}
|
246
|
+
|
247
|
+
impl TextGenerator for Phi {
|
248
|
+
fn generate(
|
249
|
+
&mut self,
|
250
|
+
prompt: &str,
|
251
|
+
config: &GenerationConfig,
|
252
|
+
) -> CandleResult<String> {
|
253
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
254
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
255
|
+
|
256
|
+
if config.debug_tokens {
|
257
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
258
|
+
} else {
|
259
|
+
self.tokenizer.decode(&output_tokens, true)
|
260
|
+
}
|
261
|
+
}
|
262
|
+
|
263
|
+
fn generate_stream(
|
264
|
+
&mut self,
|
265
|
+
prompt: &str,
|
266
|
+
config: &GenerationConfig,
|
267
|
+
mut callback: impl FnMut(&str),
|
268
|
+
) -> CandleResult<String> {
|
269
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
270
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
271
|
+
self.tokenizer.decode(&output_tokens, true)
|
272
|
+
}
|
273
|
+
|
274
|
+
fn model_name(&self) -> &str {
|
275
|
+
&self.model_id
|
276
|
+
}
|
277
|
+
|
278
|
+
fn device(&self) -> &Device {
|
279
|
+
&self.device
|
280
|
+
}
|
281
|
+
|
282
|
+
fn clear_cache(&mut self) {
|
283
|
+
self.clear_kv_cache();
|
284
|
+
}
|
285
|
+
}
|