red-candle 1.0.0.pre.6 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +481 -4
- data/Rakefile +1 -3
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +21 -79
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +21 -79
- data/ext/candle/src/llm/mistral.rs +21 -89
- data/ext/candle/src/llm/mod.rs +3 -33
- data/ext/candle/src/llm/quantized_gguf.rs +501 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -34
- data/ext/candle/src/ruby/llm.rs +110 -49
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +127 -3
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/llm/gemma.rs
CHANGED
@@ -21,6 +21,11 @@ impl Gemma {
|
|
21
21
|
self.model.clear_kv_cache();
|
22
22
|
}
|
23
23
|
|
24
|
+
/// Get the tokenizer
|
25
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
26
|
+
&self.tokenizer
|
27
|
+
}
|
28
|
+
|
24
29
|
/// Load a Gemma model from HuggingFace Hub
|
25
30
|
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
26
31
|
let api = Api::new()
|
@@ -169,77 +174,14 @@ impl Gemma {
|
|
169
174
|
|
170
175
|
// Stream callback
|
171
176
|
if let Some(ref mut cb) = callback {
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
// Check stop sequences
|
182
|
-
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
183
|
-
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
184
|
-
break;
|
185
|
-
}
|
186
|
-
}
|
187
|
-
|
188
|
-
Ok(if config.include_prompt {
|
189
|
-
all_tokens
|
190
|
-
} else {
|
191
|
-
all_tokens[start_gen..].to_vec()
|
192
|
-
})
|
193
|
-
}
|
194
|
-
|
195
|
-
fn generate_tokens_decoded(
|
196
|
-
&mut self,
|
197
|
-
prompt_tokens: Vec<u32>,
|
198
|
-
config: &GenerationConfig,
|
199
|
-
mut callback: Option<impl FnMut(&str)>,
|
200
|
-
) -> CandleResult<Vec<u32>> {
|
201
|
-
let mut text_gen = TextGeneration::from_config(config);
|
202
|
-
text_gen.set_eos_token_id(self.eos_token_id);
|
203
|
-
text_gen.set_tokens(prompt_tokens.clone());
|
204
|
-
|
205
|
-
let mut all_tokens = prompt_tokens.clone();
|
206
|
-
let start_gen = all_tokens.len();
|
207
|
-
let mut previously_decoded = String::new();
|
208
|
-
|
209
|
-
for index in 0..config.max_length {
|
210
|
-
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
211
|
-
let start_pos = all_tokens.len().saturating_sub(context_size);
|
212
|
-
let ctxt = &all_tokens[start_pos..];
|
213
|
-
|
214
|
-
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
215
|
-
let input = input.contiguous()?;
|
216
|
-
let logits = self.model.forward(&input, start_pos)?;
|
217
|
-
|
218
|
-
let logits = logits.squeeze(0)?;
|
219
|
-
let logits = if logits.dims().len() == 2 {
|
220
|
-
let seq_len = logits.dim(0)?;
|
221
|
-
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
222
|
-
} else {
|
223
|
-
logits
|
224
|
-
};
|
225
|
-
|
226
|
-
let logits = logits.to_dtype(DType::F32)?;
|
227
|
-
|
228
|
-
let next_token = text_gen.sample_next_token(
|
229
|
-
&logits,
|
230
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
231
|
-
)?;
|
232
|
-
|
233
|
-
all_tokens.push(next_token);
|
234
|
-
|
235
|
-
// Stream callback with incremental decoding
|
236
|
-
if let Some(ref mut cb) = callback {
|
237
|
-
let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
238
|
-
|
239
|
-
if current_decoded.len() > previously_decoded.len() {
|
240
|
-
let new_text = ¤t_decoded[previously_decoded.len()..];
|
241
|
-
cb(new_text);
|
242
|
-
previously_decoded = current_decoded;
|
177
|
+
if config.debug_tokens {
|
178
|
+
// In debug mode, only show debug tokens
|
179
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
180
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
181
|
+
} else {
|
182
|
+
// Normal mode: use incremental decoding for proper text
|
183
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
184
|
+
cb(&decoded_text);
|
243
185
|
}
|
244
186
|
}
|
245
187
|
|
@@ -249,12 +191,7 @@ impl Gemma {
|
|
249
191
|
}
|
250
192
|
|
251
193
|
// Check stop sequences
|
252
|
-
let generated_text =
|
253
|
-
previously_decoded.clone()
|
254
|
-
} else {
|
255
|
-
self.tokenizer.decode(&all_tokens[start_gen..], true)?
|
256
|
-
};
|
257
|
-
|
194
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
258
195
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
259
196
|
break;
|
260
197
|
}
|
@@ -312,7 +249,12 @@ impl TextGenerator for Gemma {
|
|
312
249
|
) -> CandleResult<String> {
|
313
250
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
314
251
|
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
315
|
-
|
252
|
+
|
253
|
+
if config.debug_tokens {
|
254
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
255
|
+
} else {
|
256
|
+
self.tokenizer.decode(&output_tokens, true)
|
257
|
+
}
|
316
258
|
}
|
317
259
|
|
318
260
|
fn generate_stream(
|
@@ -322,7 +264,7 @@ impl TextGenerator for Gemma {
|
|
322
264
|
mut callback: impl FnMut(&str),
|
323
265
|
) -> CandleResult<String> {
|
324
266
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
325
|
-
let output_tokens = self.
|
267
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
326
268
|
self.tokenizer.decode(&output_tokens, true)
|
327
269
|
}
|
328
270
|
|
@@ -21,6 +21,8 @@ pub struct GenerationConfig {
|
|
21
21
|
pub stop_sequences: Vec<String>,
|
22
22
|
/// Whether to return the prompt in the output
|
23
23
|
pub include_prompt: bool,
|
24
|
+
/// Whether to show raw tokens during generation (for debugging)
|
25
|
+
pub debug_tokens: bool,
|
24
26
|
}
|
25
27
|
|
26
28
|
/// Generate a random seed based on current time
|
@@ -43,6 +45,7 @@ impl Default for GenerationConfig {
|
|
43
45
|
seed: random_seed(),
|
44
46
|
stop_sequences: vec![],
|
45
47
|
include_prompt: false,
|
48
|
+
debug_tokens: false,
|
46
49
|
}
|
47
50
|
}
|
48
51
|
}
|
data/ext/candle/src/llm/llama.rs
CHANGED
@@ -28,6 +28,11 @@ impl Llama {
|
|
28
28
|
}
|
29
29
|
}
|
30
30
|
|
31
|
+
/// Get the tokenizer
|
32
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
33
|
+
&self.tokenizer
|
34
|
+
}
|
35
|
+
|
31
36
|
/// Load a Llama model from HuggingFace Hub
|
32
37
|
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
33
38
|
let api = Api::new()
|
@@ -205,77 +210,14 @@ impl Llama {
|
|
205
210
|
|
206
211
|
// Stream callback
|
207
212
|
if let Some(ref mut cb) = callback {
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
// Check stop sequences
|
218
|
-
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
219
|
-
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
220
|
-
break;
|
221
|
-
}
|
222
|
-
}
|
223
|
-
|
224
|
-
Ok(if config.include_prompt {
|
225
|
-
all_tokens
|
226
|
-
} else {
|
227
|
-
all_tokens[start_gen..].to_vec()
|
228
|
-
})
|
229
|
-
}
|
230
|
-
|
231
|
-
fn generate_tokens_decoded(
|
232
|
-
&mut self,
|
233
|
-
prompt_tokens: Vec<u32>,
|
234
|
-
config: &GenerationConfig,
|
235
|
-
mut callback: Option<impl FnMut(&str)>,
|
236
|
-
) -> CandleResult<Vec<u32>> {
|
237
|
-
let mut text_gen = TextGeneration::from_config(config);
|
238
|
-
text_gen.set_eos_token_id(self.eos_token_id);
|
239
|
-
text_gen.set_tokens(prompt_tokens.clone());
|
240
|
-
|
241
|
-
let mut all_tokens = prompt_tokens.clone();
|
242
|
-
let start_gen = all_tokens.len();
|
243
|
-
let mut previously_decoded = String::new();
|
244
|
-
|
245
|
-
for index in 0..config.max_length {
|
246
|
-
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
247
|
-
let start_pos = all_tokens.len().saturating_sub(context_size);
|
248
|
-
let ctxt = &all_tokens[start_pos..];
|
249
|
-
|
250
|
-
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
251
|
-
let input = input.contiguous()?;
|
252
|
-
let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
|
253
|
-
|
254
|
-
let logits = logits.squeeze(0)?;
|
255
|
-
let logits = if logits.dims().len() == 2 {
|
256
|
-
let seq_len = logits.dim(0)?;
|
257
|
-
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
258
|
-
} else {
|
259
|
-
logits
|
260
|
-
};
|
261
|
-
|
262
|
-
let logits = logits.to_dtype(DType::F32)?;
|
263
|
-
|
264
|
-
let next_token = text_gen.sample_next_token(
|
265
|
-
&logits,
|
266
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
267
|
-
)?;
|
268
|
-
|
269
|
-
all_tokens.push(next_token);
|
270
|
-
|
271
|
-
// Stream callback with incremental decoding
|
272
|
-
if let Some(ref mut cb) = callback {
|
273
|
-
let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
274
|
-
|
275
|
-
if current_decoded.len() > previously_decoded.len() {
|
276
|
-
let new_text = ¤t_decoded[previously_decoded.len()..];
|
277
|
-
cb(new_text);
|
278
|
-
previously_decoded = current_decoded;
|
213
|
+
if config.debug_tokens {
|
214
|
+
// In debug mode, only show debug tokens
|
215
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
216
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
217
|
+
} else {
|
218
|
+
// Normal mode: use incremental decoding for proper text
|
219
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
220
|
+
cb(&decoded_text);
|
279
221
|
}
|
280
222
|
}
|
281
223
|
|
@@ -285,12 +227,7 @@ impl Llama {
|
|
285
227
|
}
|
286
228
|
|
287
229
|
// Check stop sequences
|
288
|
-
let generated_text =
|
289
|
-
previously_decoded.clone()
|
290
|
-
} else {
|
291
|
-
self.tokenizer.decode(&all_tokens[start_gen..], true)?
|
292
|
-
};
|
293
|
-
|
230
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
294
231
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
295
232
|
break;
|
296
233
|
}
|
@@ -374,7 +311,12 @@ impl TextGenerator for Llama {
|
|
374
311
|
) -> CandleResult<String> {
|
375
312
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
376
313
|
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
377
|
-
|
314
|
+
|
315
|
+
if config.debug_tokens {
|
316
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
317
|
+
} else {
|
318
|
+
self.tokenizer.decode(&output_tokens, true)
|
319
|
+
}
|
378
320
|
}
|
379
321
|
|
380
322
|
fn generate_stream(
|
@@ -384,7 +326,7 @@ impl TextGenerator for Llama {
|
|
384
326
|
mut callback: impl FnMut(&str),
|
385
327
|
) -> CandleResult<String> {
|
386
328
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
387
|
-
let output_tokens = self.
|
329
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
388
330
|
self.tokenizer.decode(&output_tokens, true)
|
389
331
|
}
|
390
332
|
|
@@ -21,6 +21,11 @@ impl Mistral {
|
|
21
21
|
self.model.clear_kv_cache();
|
22
22
|
}
|
23
23
|
|
24
|
+
/// Get the tokenizer
|
25
|
+
pub fn tokenizer(&self) -> &TokenizerWrapper {
|
26
|
+
&self.tokenizer
|
27
|
+
}
|
28
|
+
|
24
29
|
/// Load a Mistral model from HuggingFace Hub
|
25
30
|
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
26
31
|
let api = Api::new()
|
@@ -180,87 +185,14 @@ impl Mistral {
|
|
180
185
|
|
181
186
|
// Stream callback
|
182
187
|
if let Some(ref mut cb) = callback {
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
// Check stop sequences
|
193
|
-
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
194
|
-
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
195
|
-
break;
|
196
|
-
}
|
197
|
-
}
|
198
|
-
|
199
|
-
Ok(if config.include_prompt {
|
200
|
-
all_tokens
|
201
|
-
} else {
|
202
|
-
all_tokens[start_gen..].to_vec()
|
203
|
-
})
|
204
|
-
}
|
205
|
-
|
206
|
-
fn generate_tokens_decoded(
|
207
|
-
&mut self,
|
208
|
-
prompt_tokens: Vec<u32>,
|
209
|
-
config: &GenerationConfig,
|
210
|
-
mut callback: Option<impl FnMut(&str)>,
|
211
|
-
) -> CandleResult<Vec<u32>> {
|
212
|
-
let mut text_gen = TextGeneration::from_config(config);
|
213
|
-
text_gen.set_eos_token_id(self.eos_token_id);
|
214
|
-
text_gen.set_tokens(prompt_tokens.clone());
|
215
|
-
|
216
|
-
let mut all_tokens = prompt_tokens.clone();
|
217
|
-
let start_gen = all_tokens.len();
|
218
|
-
|
219
|
-
// For incremental decoding
|
220
|
-
let mut previously_decoded = String::new();
|
221
|
-
|
222
|
-
for index in 0..config.max_length {
|
223
|
-
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
224
|
-
let start_pos = all_tokens.len().saturating_sub(context_size);
|
225
|
-
let ctxt = &all_tokens[start_pos..];
|
226
|
-
|
227
|
-
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
228
|
-
// Ensure input tensor is contiguous for Metal backend
|
229
|
-
let input = input.contiguous()?;
|
230
|
-
let logits = self.model.forward(&input, start_pos)?;
|
231
|
-
|
232
|
-
// The model returns logits of shape [batch_size, seq_len, vocab_size]
|
233
|
-
// We need to get the logits for the last token only
|
234
|
-
let logits = logits.squeeze(0)?; // Remove batch dimension
|
235
|
-
let logits = if logits.dims().len() == 2 {
|
236
|
-
// If we still have [seq_len, vocab_size], take the last token
|
237
|
-
let seq_len = logits.dim(0)?;
|
238
|
-
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
239
|
-
} else {
|
240
|
-
// Already [vocab_size]
|
241
|
-
logits
|
242
|
-
};
|
243
|
-
|
244
|
-
// Convert to F32 for sampling if needed
|
245
|
-
let logits = logits.to_dtype(DType::F32)?;
|
246
|
-
|
247
|
-
let next_token = text_gen.sample_next_token(
|
248
|
-
&logits,
|
249
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
250
|
-
)?;
|
251
|
-
|
252
|
-
all_tokens.push(next_token);
|
253
|
-
|
254
|
-
// Stream callback with incremental decoding
|
255
|
-
if let Some(ref mut cb) = callback {
|
256
|
-
// Decode all generated tokens so far
|
257
|
-
let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
258
|
-
|
259
|
-
// Only emit the new text since last callback
|
260
|
-
if current_decoded.len() > previously_decoded.len() {
|
261
|
-
let new_text = ¤t_decoded[previously_decoded.len()..];
|
262
|
-
cb(new_text);
|
263
|
-
previously_decoded = current_decoded;
|
188
|
+
if config.debug_tokens {
|
189
|
+
// In debug mode, only show debug tokens
|
190
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
191
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
192
|
+
} else {
|
193
|
+
// Normal mode: use incremental decoding for proper text
|
194
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
195
|
+
cb(&decoded_text);
|
264
196
|
}
|
265
197
|
}
|
266
198
|
|
@@ -270,12 +202,7 @@ impl Mistral {
|
|
270
202
|
}
|
271
203
|
|
272
204
|
// Check stop sequences
|
273
|
-
let generated_text =
|
274
|
-
previously_decoded.clone()
|
275
|
-
} else {
|
276
|
-
self.tokenizer.decode(&all_tokens[start_gen..], true)?
|
277
|
-
};
|
278
|
-
|
205
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
279
206
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
280
207
|
break;
|
281
208
|
}
|
@@ -297,7 +224,12 @@ impl TextGenerator for Mistral {
|
|
297
224
|
) -> CandleResult<String> {
|
298
225
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
299
226
|
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
300
|
-
|
227
|
+
|
228
|
+
if config.debug_tokens {
|
229
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
230
|
+
} else {
|
231
|
+
self.tokenizer.decode(&output_tokens, true)
|
232
|
+
}
|
301
233
|
}
|
302
234
|
|
303
235
|
fn generate_stream(
|
@@ -307,7 +239,7 @@ impl TextGenerator for Mistral {
|
|
307
239
|
mut callback: impl FnMut(&str),
|
308
240
|
) -> CandleResult<String> {
|
309
241
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
310
|
-
let output_tokens = self.
|
242
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
311
243
|
self.tokenizer.decode(&output_tokens, true)
|
312
244
|
}
|
313
245
|
|
data/ext/candle/src/llm/mod.rs
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
use candle_core::{Device, Result as CandleResult};
|
2
|
-
use tokenizers::Tokenizer;
|
3
2
|
|
4
3
|
pub mod mistral;
|
5
4
|
pub mod llama;
|
6
5
|
pub mod gemma;
|
7
6
|
pub mod generation_config;
|
8
7
|
pub mod text_generation;
|
8
|
+
pub mod quantized_gguf;
|
9
9
|
|
10
10
|
pub use generation_config::GenerationConfig;
|
11
11
|
pub use text_generation::TextGeneration;
|
12
|
+
pub use quantized_gguf::QuantizedGGUF;
|
13
|
+
pub use crate::tokenizer::TokenizerWrapper;
|
12
14
|
|
13
15
|
/// Trait for text generation models
|
14
16
|
pub trait TextGenerator: Send + Sync {
|
@@ -35,36 +37,4 @@ pub trait TextGenerator: Send + Sync {
|
|
35
37
|
|
36
38
|
/// Clear any cached state (like KV cache)
|
37
39
|
fn clear_cache(&mut self);
|
38
|
-
}
|
39
|
-
|
40
|
-
/// Common structure for managing tokenizer
|
41
|
-
#[derive(Debug)]
|
42
|
-
pub struct TokenizerWrapper {
|
43
|
-
tokenizer: Tokenizer,
|
44
|
-
}
|
45
|
-
|
46
|
-
impl TokenizerWrapper {
|
47
|
-
pub fn new(tokenizer: Tokenizer) -> Self {
|
48
|
-
Self { tokenizer }
|
49
|
-
}
|
50
|
-
|
51
|
-
pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
|
52
|
-
let encoding = self.tokenizer
|
53
|
-
.encode(text, add_special_tokens)
|
54
|
-
.map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
|
55
|
-
Ok(encoding.get_ids().to_vec())
|
56
|
-
}
|
57
|
-
|
58
|
-
pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
|
59
|
-
self.tokenizer
|
60
|
-
.decode(tokens, skip_special_tokens)
|
61
|
-
.map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
|
62
|
-
}
|
63
|
-
|
64
|
-
pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
|
65
|
-
self.tokenizer
|
66
|
-
.id_to_token(token)
|
67
|
-
.map(|s| s.to_string())
|
68
|
-
.ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
|
69
|
-
}
|
70
40
|
}
|