red-candle 1.0.0.pre.1 → 1.0.0.pre.3
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.
- checksums.yaml +4 -4
- data/Gemfile +12 -0
- data/LICENSE +22 -0
- data/Rakefile +95 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/lib.rs +6 -96
- data/ext/candle/src/llm/generation_config.rs +49 -0
- data/ext/candle/src/llm/mistral.rs +325 -0
- data/ext/candle/src/llm/mod.rs +68 -0
- data/ext/candle/src/llm/text_generation.rs +141 -0
- data/ext/candle/src/reranker.rs +267 -0
- data/ext/candle/src/ruby/device.rs +197 -0
- data/ext/candle/src/ruby/dtype.rs +37 -0
- data/ext/candle/src/ruby/embedding_model.rs +410 -0
- data/ext/candle/src/ruby/errors.rs +13 -0
- data/ext/candle/src/ruby/llm.rs +295 -0
- data/ext/candle/src/ruby/mod.rs +21 -0
- data/ext/candle/src/ruby/qtensor.rs +69 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/tensor.rs +654 -0
- data/ext/candle/src/ruby/utils.rs +88 -0
- data/lib/candle/version.rb +1 -1
- metadata +22 -1
data/ext/candle/src/llm/mistral.rs
@@ -0,0 +1,325 @@
+use candle_core::{DType, Device, Result as CandleResult, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::mistral::{Config, Model as MistralModel};
+use hf_hub::{api::tokio::Api, Repo};
+use tokenizers::Tokenizer;
+
+use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+#[derive(Debug)]
+pub struct Mistral {
+    model: MistralModel,
+    tokenizer: TokenizerWrapper,
+    device: Device,
+    model_id: String,
+    eos_token_id: u32,
+}
+
+impl Mistral {
+    /// Clear the KV cache between generations
+    pub fn clear_kv_cache(&mut self) {
+        self.model.clear_kv_cache();
+    }
+
+    /// Load a Mistral model from HuggingFace Hub
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        let api = Api::new()
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+        let repo = api.repo(Repo::model(model_id.to_string()));
+
+        // Download model files
+        let config_filename = repo
+            .get("config.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+        let tokenizer_filename = repo
+            .get("tokenizer.json")
+            .await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+        // Try different file patterns for model weights
+        let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+            vec![single_file]
+        } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
+            // Some Mistral models use consolidated.safetensors
+            vec![consolidated_file]
+        } else {
+            // Try to find sharded model files
+            let mut sharded_files = Vec::new();
+            let mut index = 1;
+            loop {
+                // Try common shard counts
+                let mut found = false;
+                for total in [2, 3, 4, 5, 6, 7, 8] {
+                    let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                    if let Ok(file) = repo.get(&filename).await {
+                        sharded_files.push(file);
+                        found = true;
+                        break;
+                    }
+                }
+                if !found {
+                    break;
+                }
+                index += 1;
+            }
+
+            if sharded_files.is_empty() {
+                // Try single pytorch_model.bin as last resort (though we prefer safetensors)
+                if let Ok(_pytorch_file) = repo.get("pytorch_model.bin").await {
+                    return Err(candle_core::Error::Msg(
+                        "Only safetensors format is supported. This model uses pytorch_model.bin format.".to_string()
+                    ));
+                } else {
+                    return Err(candle_core::Error::Msg(
+                        "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
+                    ));
+                }
+            }
+            sharded_files
+        };
+
+        // Load config
+        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+        // Load tokenizer
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        let eos_token_id = tokenizer
+            .get_vocab(true)
+            .get("</s>")
+            .copied()
+            .unwrap_or(2);
+
+        // Load model weights
+        let vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+        };
+
+        let model = MistralModel::new(&config, vb)?;
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id: model_id.to_string(),
+            eos_token_id,
+        })
+    }
+
+    /// Create from existing components (useful for testing)
+    pub fn new(
+        model: MistralModel,
+        tokenizer: Tokenizer,
+        device: Device,
+        model_id: String,
+    ) -> Self {
+        let eos_token_id = tokenizer
+            .get_vocab(true)
+            .get("</s>")
+            .copied()
+            .unwrap_or(2);
+
+        Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id,
+            eos_token_id,
+        }
+    }
+
+    fn generate_tokens(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            // Ensure input tensor is contiguous for Metal backend
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos)?;
+
+            // The model returns logits of shape [batch_size, seq_len, vocab_size]
+            // We need to get the logits for the last token only
+            let logits = logits.squeeze(0)?; // Remove batch dimension
+            let logits = if logits.dims().len() == 2 {
+                // If we still have [seq_len, vocab_size], take the last token
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                // Already [vocab_size]
+                logits
+            };
+
+            // Convert to F32 for sampling if needed
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback
+            if let Some(ref mut cb) = callback {
+                let token_text = self.tokenizer.token_to_piece(next_token)?;
+                cb(&token_text);
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+
+    fn generate_tokens_decoded(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::from_config(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        // For incremental decoding
+        let mut previously_decoded = String::new();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            // Ensure input tensor is contiguous for Metal backend
+            let input = input.contiguous()?;
+            let logits = self.model.forward(&input, start_pos)?;
+
+            // The model returns logits of shape [batch_size, seq_len, vocab_size]
+            // We need to get the logits for the last token only
+            let logits = logits.squeeze(0)?; // Remove batch dimension
+            let logits = if logits.dims().len() == 2 {
+                // If we still have [seq_len, vocab_size], take the last token
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                // Already [vocab_size]
+                logits
+            };
+
+            // Convert to F32 for sampling if needed
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(
+                &logits,
+                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+            )?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback with incremental decoding
+            if let Some(ref mut cb) = callback {
+                // Decode all generated tokens so far
+                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                // Only emit the new text since last callback
+                if current_decoded.len() > previously_decoded.len() {
+                    let new_text = &current_decoded[previously_decoded.len()..];
+                    cb(new_text);
+                    previously_decoded = current_decoded;
+                }
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check stop sequences
+            let generated_text = if callback.is_some() {
+                previously_decoded.clone()
+            } else {
+                self.tokenizer.decode(&all_tokens[start_gen..], true)?
+            };
+
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+}
+
+impl TextGenerator for Mistral {
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        mut callback: impl FnMut(&str),
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_id
+    }
+
+    fn device(&self) -> &Device {
+        &self.device
+    }
+
+    fn clear_cache(&mut self) {
+        self.clear_kv_cache();
+    }
+}
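
The hunk above adds an async HuggingFace Hub loader plus a token-by-token sampling loop behind the TextGenerator trait. The following sketch (not part of the diff) shows how these pieces might be driven from Rust, written as if it lived in a sibling module under llm/; the model id, the tokio runtime, the demo function name, and the GenerationConfig::default() call are all assumptions for illustration (generation_config.rs is not shown in this section).

use candle_core::Device;

use super::mistral::Mistral;
use super::{GenerationConfig, TextGenerator};

async fn demo() -> candle_core::Result<()> {
    // from_pretrained downloads config.json, tokenizer.json and safetensors
    // weights from the HuggingFace Hub, as implemented above.
    let mut model = Mistral::from_pretrained("mistralai/Mistral-7B-v0.1", Device::Cpu).await?;

    // Assumes GenerationConfig provides a Default impl; its definition lives
    // in generation_config.rs, which this section does not include.
    let config = GenerationConfig::default();

    // Blocking generation through the TextGenerator trait.
    let text = model.generate("The capital of France is", &config)?;
    println!("{}", text);

    // Streaming generation: the callback receives incremental text pieces.
    model.clear_cache();
    model.generate_stream("Write one sentence about Rust: ", &config, |piece| {
        print!("{}", piece);
    })?;

    Ok(())
}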
data/ext/candle/src/llm/mod.rs
@@ -0,0 +1,68 @@
+use candle_core::{Device, Result as CandleResult};
+use tokenizers::Tokenizer;
+
+pub mod mistral;
+pub mod generation_config;
+pub mod text_generation;
+
+pub use generation_config::GenerationConfig;
+pub use text_generation::TextGeneration;
+
+/// Trait for text generation models
+pub trait TextGenerator: Send + Sync {
+    /// Generate text from a prompt
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String>;
+
+    /// Generate text with streaming callback
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        callback: impl FnMut(&str),
+    ) -> CandleResult<String>;
+
+    /// Get the model's name
+    fn model_name(&self) -> &str;
+
+    /// Get the device the model is running on
+    fn device(&self) -> &Device;
+
+    /// Clear any cached state (like KV cache)
+    fn clear_cache(&mut self);
+}
+
+/// Common structure for managing tokenizer
+#[derive(Debug)]
+pub struct TokenizerWrapper {
+    tokenizer: Tokenizer,
+}
+
+impl TokenizerWrapper {
+    pub fn new(tokenizer: Tokenizer) -> Self {
+        Self { tokenizer }
+    }
+
+    pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
+        let encoding = self.tokenizer
+            .encode(text, add_special_tokens)
+            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
+        Ok(encoding.get_ids().to_vec())
+    }
+
+    pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
+        self.tokenizer
+            .decode(tokens, skip_special_tokens)
+            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
+    }
+
+    pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
+        self.tokenizer
+            .id_to_token(token)
+            .map(|s| s.to_string())
+            .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
+    }
+}
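
The mod.rs hunk above defines the TextGenerator trait and a thin TokenizerWrapper around tokenizers::Tokenizer that maps tokenizer errors into candle_core::Error::Msg. Below is a minimal round-trip sketch of the wrapper, not part of the diff; the local tokenizer.json path and the tokenizer_roundtrip function name are assumptions, and it is written as if it lived in a submodule under llm/.

use tokenizers::Tokenizer;

use super::TokenizerWrapper;

fn tokenizer_roundtrip() -> candle_core::Result<()> {
    // Loading from a local tokenizer.json is an assumption for illustration;
    // the Mistral loader above fetches this file from the HuggingFace Hub.
    let tokenizer = Tokenizer::from_file("tokenizer.json")
        .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
    let wrapper = TokenizerWrapper::new(tokenizer);

    // encode/decode mirror the methods above.
    let ids = wrapper.encode("Hello from red-candle!", true)?;
    let text = wrapper.decode(&ids, true)?;
    let first_piece = wrapper.token_to_piece(ids[0])?;

    println!("{} tokens, first piece {:?}, decoded: {}", ids.len(), first_piece, text);
    Ok(())
}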
data/ext/candle/src/llm/text_generation.rs
@@ -0,0 +1,141 @@
+use candle_core::{Result as CandleResult, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use rand::{rngs::StdRng, SeedableRng};
+
+use super::GenerationConfig;
+
+/// Helper struct for text generation process
+pub struct TextGeneration {
+    #[allow(dead_code)]
+    rng: StdRng,
+    logits_processor: LogitsProcessor,
+    tokens: Vec<u32>,
+    eos_token_id: Option<u32>,
+}
+
+impl TextGeneration {
+    pub fn new(
+        seed: u64,
+        temperature: Option<f64>,
+        top_p: Option<f64>,
+        _top_k: Option<usize>,
+        _repetition_penalty: f32,
+        _repetition_penalty_last_n: usize,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+
+        Self {
+            rng: StdRng::seed_from_u64(seed),
+            logits_processor,
+            tokens: Vec::new(),
+            eos_token_id: None,
+        }
+    }
+
+    pub fn from_config(config: &GenerationConfig) -> Self {
+        Self::new(
+            config.seed,
+            Some(config.temperature),
+            config.top_p,
+            config.top_k,
+            config.repetition_penalty,
+            config.repetition_penalty_last_n,
+        )
+    }
+
+    pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
+        self.eos_token_id = Some(eos_token_id);
+    }
+
+    pub fn set_tokens(&mut self, tokens: Vec<u32>) {
+        self.tokens = tokens;
+    }
+
+    pub fn get_tokens(&self) -> &[u32] {
+        &self.tokens
+    }
+
+    pub fn push_token(&mut self, token: u32) {
+        self.tokens.push(token);
+    }
+
+    /// Apply repetition penalty to logits
+    pub fn apply_repetition_penalty(
+        &self,
+        logits: &mut Tensor,
+        penalty: f32,
+        context_size: usize,
+    ) -> CandleResult<()> {
+        if penalty == 1.0 {
+            return Ok(());
+        }
+
+        let device = logits.device();
+        let vocab_size = logits.dims1()?;
+
+        // Get the context tokens to apply penalty to
+        let start = self.tokens.len().saturating_sub(context_size);
+        let context_tokens = &self.tokens[start..];
+
+        // Apply penalty to tokens that appear in the context
+        let mut logits_vec = logits.to_vec1::<f32>()?;
+        for &token in context_tokens {
+            if (token as usize) < vocab_size {
+                let idx = token as usize;
+                if logits_vec[idx] > 0.0 {
+                    logits_vec[idx] /= penalty;
+                } else {
+                    logits_vec[idx] *= penalty;
+                }
+            }
+        }
+
+        *logits = Tensor::from_vec(logits_vec, vocab_size, device)?;
+        Ok(())
+    }
+
+    /// Sample next token from logits
+    pub fn sample_next_token(
+        &mut self,
+        logits: &Tensor,
+        repetition_penalty: Option<(f32, usize)>,
+    ) -> CandleResult<u32> {
+        let mut logits = logits.clone();
+
+        // Apply repetition penalty if specified
+        if let Some((penalty, last_n)) = repetition_penalty {
+            self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+        }
+
+        // Sample token
+        let next_token = self.logits_processor.sample(&logits)?;
+        self.tokens.push(next_token);
+
+        Ok(next_token)
+    }
+
+    /// Check if we should stop generation
+    pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
+        if self.tokens.len() >= max_length {
+            return true;
+        }
+
+        if let Some(eos) = self.eos_token_id {
+            if token == eos {
+                return true;
+            }
+        }
+
+        false
+    }
+
+    /// Check if the generated text ends with any stop sequence
+    pub fn check_stop_sequences(&self, text: &str, stop_sequences: &[String]) -> bool {
+        for seq in stop_sequences {
+            if text.ends_with(seq) {
+                return true;
+            }
+        }
+        false
+    }
+}
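
The apply_repetition_penalty method above uses the common rule of dividing positive logits by the penalty and multiplying negative ones, so tokens already present in the recent context become less likely either way. A standalone plain-slice sketch of that rule follows; it is not part of the diff, skips the Tensor round trip, and the penalize_repeats name is hypothetical.

fn penalize_repeats(logits: &mut [f32], recent_tokens: &[u32], penalty: f32) {
    if penalty == 1.0 {
        return; // 1.0 means "no penalty", matching the early return above
    }
    for &token in recent_tokens {
        if let Some(logit) = logits.get_mut(token as usize) {
            if *logit > 0.0 {
                *logit /= penalty; // positive logits shrink toward zero
            } else {
                *logit *= penalty; // negative logits become more negative
            }
        }
    }
}

For example, with penalty = 1.3 a previously seen token with logit 2.6 drops to 2.0, while one with logit -1.0 falls to -1.3.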