red-candle 1.0.0.pre.7 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +399 -18
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +5 -0
- data/ext/candle/src/llm/llama.rs +5 -0
- data/ext/candle/src/llm/mistral.rs +5 -0
- data/ext/candle/src/llm/mod.rs +1 -89
- data/ext/candle/src/llm/quantized_gguf.rs +5 -0
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -33
- data/ext/candle/src/ruby/llm.rs +31 -13
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +128 -5
- data/ext/candle/src/ruby/qtensor.rs +0 -69
@@ -0,0 +1,269 @@
|
|
1
|
+
use magnus::{class, function, method, prelude::*, Error, Module, RArray, RHash, RModule, TryConvert};
|
2
|
+
use crate::tokenizer::{TokenizerWrapper as InnerTokenizer, loader::TokenizerLoader};
|
3
|
+
use crate::ruby::Result;
|
4
|
+
|
5
|
+
/// Ruby-facing tokenizer handle: a newtype around the crate's `TokenizerWrapper`
/// (imported here as `InnerTokenizer`), bound to the Ruby class
/// `Candle::Tokenizer` through `magnus::wrap`.
#[derive(Clone, Debug)]
|
6
|
+
#[magnus::wrap(class = "Candle::Tokenizer", free_immediately, size)]
|
7
|
+
pub struct Tokenizer(pub InnerTokenizer);
|
8
|
+
|
9
|
+
impl Tokenizer {
|
10
|
+
/// Create a new tokenizer from a file path
|
11
|
+
pub fn from_file(path: String) -> Result<Self> {
|
12
|
+
let tokenizer = TokenizerLoader::from_file(&path)
|
13
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
14
|
+
Ok(Self(InnerTokenizer::new(tokenizer)))
|
15
|
+
}
|
16
|
+
|
17
|
+
/// Create a new tokenizer from HuggingFace model ID
|
18
|
+
pub fn from_pretrained(model_id: String) -> Result<Self> {
|
19
|
+
// Use tokio runtime for async operations
|
20
|
+
let rt = tokio::runtime::Runtime::new()
|
21
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
|
22
|
+
|
23
|
+
let tokenizer = rt.block_on(async {
|
24
|
+
TokenizerLoader::from_hf_hub(&model_id, None).await
|
25
|
+
})
|
26
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
27
|
+
|
28
|
+
Ok(Self(InnerTokenizer::new(tokenizer)))
|
29
|
+
}
|
30
|
+
|
31
|
+
/// Encode text into token IDs
|
32
|
+
pub fn encode(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
|
33
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
34
|
+
let token_ids = self.0.encode(&text, add_special)
|
35
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
36
|
+
|
37
|
+
Ok(RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))
|
38
|
+
}
|
39
|
+
|
40
|
+
/// Encode text into token strings (words/subwords)
|
41
|
+
pub fn encode_to_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
|
42
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
43
|
+
let token_ids = self.0.encode(&text, add_special)
|
44
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
45
|
+
|
46
|
+
let mut tokens = Vec::new();
|
47
|
+
for id in token_ids {
|
48
|
+
let token = self.0.token_to_piece(id)
|
49
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
50
|
+
tokens.push(token);
|
51
|
+
}
|
52
|
+
|
53
|
+
Ok(RArray::from_vec(tokens))
|
54
|
+
}
|
55
|
+
|
56
|
+
/// Encode multiple texts in batch
|
57
|
+
pub fn encode_batch(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
|
58
|
+
let texts: Vec<String> = texts.to_vec()?;
|
59
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
60
|
+
|
61
|
+
let token_ids_batch = self.0.encode_batch(texts, add_special)
|
62
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
63
|
+
|
64
|
+
let result = RArray::new();
|
65
|
+
for token_ids in token_ids_batch {
|
66
|
+
result.push(RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
67
|
+
}
|
68
|
+
|
69
|
+
Ok(result)
|
70
|
+
}
|
71
|
+
|
72
|
+
/// Encode multiple texts in batch, returning token strings
|
73
|
+
pub fn encode_batch_to_tokens(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
|
74
|
+
let texts: Vec<String> = texts.to_vec()?;
|
75
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
76
|
+
|
77
|
+
let token_ids_batch = self.0.encode_batch(texts, add_special)
|
78
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
79
|
+
|
80
|
+
let result = RArray::new();
|
81
|
+
for token_ids in token_ids_batch {
|
82
|
+
let mut tokens = Vec::new();
|
83
|
+
for id in token_ids {
|
84
|
+
let token = self.0.token_to_piece(id)
|
85
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
86
|
+
tokens.push(token);
|
87
|
+
}
|
88
|
+
result.push(RArray::from_vec(tokens))?;
|
89
|
+
}
|
90
|
+
|
91
|
+
Ok(result)
|
92
|
+
}
|
93
|
+
|
94
|
+
/// Encode text and return both token IDs and token strings
|
95
|
+
pub fn encode_with_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RHash> {
|
96
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
97
|
+
let token_ids = self.0.encode(&text, add_special)
|
98
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
99
|
+
|
100
|
+
let mut tokens = Vec::new();
|
101
|
+
for &id in &token_ids {
|
102
|
+
let token = self.0.token_to_piece(id)
|
103
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
104
|
+
tokens.push(token);
|
105
|
+
}
|
106
|
+
|
107
|
+
let hash = RHash::new();
|
108
|
+
hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
109
|
+
hash.aset("tokens", RArray::from_vec(tokens))?;
|
110
|
+
|
111
|
+
Ok(hash)
|
112
|
+
}
|
113
|
+
|
114
|
+
/// Decode token IDs back to text
|
115
|
+
pub fn decode(&self, token_ids: RArray, skip_special_tokens: Option<bool>) -> Result<String> {
|
116
|
+
let token_ids: Vec<i64> = token_ids.to_vec()?;
|
117
|
+
let token_ids: Vec<u32> = token_ids.into_iter()
|
118
|
+
.map(|id| id as u32)
|
119
|
+
.collect();
|
120
|
+
let skip_special = skip_special_tokens.unwrap_or(true);
|
121
|
+
|
122
|
+
self.0.decode(&token_ids, skip_special)
|
123
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))
|
124
|
+
}
|
125
|
+
|
126
|
+
/// Get the string representation of a single token ID
|
127
|
+
pub fn id_to_token(&self, token_id: i64) -> Result<String> {
|
128
|
+
self.0.token_to_piece(token_id as u32)
|
129
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))
|
130
|
+
}
|
131
|
+
|
132
|
+
/// Get the vocabulary as a hash of token string to ID
|
133
|
+
pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> Result<RHash> {
|
134
|
+
let with_added = with_added_tokens.unwrap_or(true);
|
135
|
+
let vocab = self.0.inner().get_vocab(with_added);
|
136
|
+
|
137
|
+
let hash = RHash::new();
|
138
|
+
for (token, id) in vocab {
|
139
|
+
hash.aset(token, id as i64)?;
|
140
|
+
}
|
141
|
+
|
142
|
+
Ok(hash)
|
143
|
+
}
|
144
|
+
|
145
|
+
/// Get vocabulary size
|
146
|
+
pub fn vocab_size(&self, with_added_tokens: Option<bool>) -> usize {
|
147
|
+
let with_added = with_added_tokens.unwrap_or(true);
|
148
|
+
self.0.inner().get_vocab_size(with_added)
|
149
|
+
}
|
150
|
+
|
151
|
+
/// Enable padding - returns a new tokenizer with padding enabled
|
152
|
+
pub fn with_padding(&self, kwargs: RHash) -> Result<Self> {
|
153
|
+
use tokenizers::{PaddingParams, PaddingStrategy, PaddingDirection};
|
154
|
+
|
155
|
+
let mut params = PaddingParams::default();
|
156
|
+
|
157
|
+
// Extract parameters from kwargs
|
158
|
+
if let Some(length) = kwargs.get(magnus::Symbol::new("length")) {
|
159
|
+
if let Ok(len) = usize::try_convert(length) {
|
160
|
+
params.strategy = PaddingStrategy::Fixed(len);
|
161
|
+
}
|
162
|
+
}
|
163
|
+
|
164
|
+
if let Some(max_length) = kwargs.get(magnus::Symbol::new("max_length")) {
|
165
|
+
if let Ok(_) = usize::try_convert(max_length) {
|
166
|
+
params.strategy = PaddingStrategy::BatchLongest;
|
167
|
+
}
|
168
|
+
}
|
169
|
+
|
170
|
+
if let Some(direction) = kwargs.get(magnus::Symbol::new("direction")) {
|
171
|
+
if let Ok(dir) = String::try_convert(direction) {
|
172
|
+
params.direction = match dir.as_str() {
|
173
|
+
"right" => PaddingDirection::Right,
|
174
|
+
"left" => PaddingDirection::Left,
|
175
|
+
_ => PaddingDirection::Right,
|
176
|
+
};
|
177
|
+
}
|
178
|
+
}
|
179
|
+
|
180
|
+
if let Some(pad_id) = kwargs.get(magnus::Symbol::new("pad_id")) {
|
181
|
+
if let Ok(id) = u32::try_convert(pad_id) {
|
182
|
+
params.pad_id = id;
|
183
|
+
}
|
184
|
+
}
|
185
|
+
|
186
|
+
if let Some(pad_token) = kwargs.get(magnus::Symbol::new("pad_token")) {
|
187
|
+
if let Ok(token) = String::try_convert(pad_token) {
|
188
|
+
params.pad_token = token;
|
189
|
+
}
|
190
|
+
}
|
191
|
+
|
192
|
+
let mut new_tokenizer = self.0.clone();
|
193
|
+
let _ = new_tokenizer.inner_mut().with_padding(Some(params));
|
194
|
+
Ok(Self(new_tokenizer))
|
195
|
+
}
|
196
|
+
|
197
|
+
/// Enable truncation - returns a new tokenizer with truncation enabled
|
198
|
+
pub fn with_truncation(&self, max_length: usize) -> Result<Self> {
|
199
|
+
use tokenizers::{TruncationParams, TruncationStrategy, TruncationDirection};
|
200
|
+
|
201
|
+
let params = TruncationParams {
|
202
|
+
max_length,
|
203
|
+
strategy: TruncationStrategy::LongestFirst,
|
204
|
+
stride: 0,
|
205
|
+
direction: TruncationDirection::Right,
|
206
|
+
};
|
207
|
+
|
208
|
+
let mut new_tokenizer = self.0.clone();
|
209
|
+
let _ = new_tokenizer.inner_mut().with_truncation(Some(params));
|
210
|
+
Ok(Self(new_tokenizer))
|
211
|
+
}
|
212
|
+
|
213
|
+
/// Get special tokens information
|
214
|
+
pub fn get_special_tokens(&self) -> Result<RHash> {
|
215
|
+
let hash = RHash::new();
|
216
|
+
|
217
|
+
// Common special tokens
|
218
|
+
let special_tokens = vec![
|
219
|
+
("[CLS]", "cls_token"),
|
220
|
+
("[SEP]", "sep_token"),
|
221
|
+
("[PAD]", "pad_token"),
|
222
|
+
("[UNK]", "unk_token"),
|
223
|
+
("[MASK]", "mask_token"),
|
224
|
+
("<s>", "bos_token"),
|
225
|
+
("</s>", "eos_token"),
|
226
|
+
];
|
227
|
+
|
228
|
+
let vocab = self.0.inner().get_vocab(true);
|
229
|
+
|
230
|
+
for (token, name) in special_tokens {
|
231
|
+
if let Some(id) = vocab.get(token) {
|
232
|
+
hash.aset(name, *id as i64)?;
|
233
|
+
}
|
234
|
+
}
|
235
|
+
|
236
|
+
Ok(hash)
|
237
|
+
}
|
238
|
+
|
239
|
+
/// String representation
|
240
|
+
pub fn inspect(&self) -> String {
|
241
|
+
format!("#<Candle::Tokenizer vocab_size={}>", self.vocab_size(Some(true)))
|
242
|
+
}
|
243
|
+
}
|
244
|
+
|
245
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
246
|
+
let tokenizer_class = rb_candle.define_class("Tokenizer", class::object())?;
|
247
|
+
|
248
|
+
// Class methods
|
249
|
+
tokenizer_class.define_singleton_method("from_file", function!(Tokenizer::from_file, 1))?;
|
250
|
+
tokenizer_class.define_singleton_method("from_pretrained", function!(Tokenizer::from_pretrained, 1))?;
|
251
|
+
|
252
|
+
// Instance methods
|
253
|
+
tokenizer_class.define_method("encode", method!(Tokenizer::encode, 2))?;
|
254
|
+
tokenizer_class.define_method("encode_to_tokens", method!(Tokenizer::encode_to_tokens, 2))?;
|
255
|
+
tokenizer_class.define_method("encode_with_tokens", method!(Tokenizer::encode_with_tokens, 2))?;
|
256
|
+
tokenizer_class.define_method("encode_batch", method!(Tokenizer::encode_batch, 2))?;
|
257
|
+
tokenizer_class.define_method("encode_batch_to_tokens", method!(Tokenizer::encode_batch_to_tokens, 2))?;
|
258
|
+
tokenizer_class.define_method("decode", method!(Tokenizer::decode, 2))?;
|
259
|
+
tokenizer_class.define_method("id_to_token", method!(Tokenizer::id_to_token, 1))?;
|
260
|
+
tokenizer_class.define_method("get_vocab", method!(Tokenizer::get_vocab, 1))?;
|
261
|
+
tokenizer_class.define_method("vocab_size", method!(Tokenizer::vocab_size, 1))?;
|
262
|
+
tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
|
263
|
+
tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
|
264
|
+
tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
|
265
|
+
tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
|
266
|
+
tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
|
267
|
+
|
268
|
+
Ok(())
|
269
|
+
}
|
@@ -1,11 +1,10 @@
|
|
1
|
-
use magnus::{function,
|
1
|
+
use magnus::{function, Module, Object};
|
2
2
|
|
3
|
-
use ::candle_core::Tensor;
|
3
|
+
use ::candle_core::Tensor as CoreTensor;
|
4
4
|
|
5
|
-
use crate::ruby::
|
6
|
-
use crate::ruby::{Result as RbResult, Tensor as RbTensor};
|
5
|
+
use crate::ruby::Result;
|
7
6
|
|
8
|
-
pub fn actual_index(t: &
|
7
|
+
pub fn actual_index(t: &CoreTensor, dim: usize, index: i64) -> candle_core::Result<usize> {
|
9
8
|
let dim = t.dim(dim)?;
|
10
9
|
if 0 <= index {
|
11
10
|
let index = index as usize;
|
@@ -21,7 +20,7 @@ pub fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<u
|
|
21
20
|
}
|
22
21
|
}
|
23
22
|
|
24
|
-
pub fn actual_dim(t: &
|
23
|
+
pub fn actual_dim(t: &CoreTensor, dim: i64) -> candle_core::Result<usize> {
|
25
24
|
let rank = t.rank();
|
26
25
|
if 0 <= dim {
|
27
26
|
let dim = dim as usize;
|
@@ -61,7 +60,7 @@ fn get_num_threads() -> usize {
|
|
61
60
|
candle_core::utils::get_num_threads()
|
62
61
|
}
|
63
62
|
|
64
|
-
pub fn candle_utils(rb_candle: magnus::RModule) -> Result<()
|
63
|
+
pub fn candle_utils(rb_candle: magnus::RModule) -> Result<()> {
|
65
64
|
let rb_utils = rb_candle.define_module("Utils")?;
|
66
65
|
rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
|
67
66
|
rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
|
@@ -69,20 +68,3 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
|
|
69
68
|
rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
|
70
69
|
Ok(())
|
71
70
|
}
|
72
|
-
|
73
|
-
/// Applies the Softmax function to a given tensor.#
|
74
|
-
/// &RETURNS&: Tensor
|
75
|
-
#[allow(dead_code)]
|
76
|
-
fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
|
77
|
-
let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?;
|
78
|
-
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?;
|
79
|
-
Ok(RbTensor(sm))
|
80
|
-
}
|
81
|
-
|
82
|
-
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
83
|
-
/// &RETURNS&: Tensor
|
84
|
-
#[allow(dead_code)]
|
85
|
-
fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
|
86
|
-
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?;
|
87
|
-
Ok(RbTensor(s))
|
88
|
-
}
|
@@ -0,0 +1,108 @@
|
|
1
|
+
use candle_core::Result as CandleResult;
|
2
|
+
use hf_hub::api::tokio::{Api, ApiRepo};
|
3
|
+
use tokenizers::Tokenizer;
|
4
|
+
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
5
|
+
use std::path::PathBuf;
|
6
|
+
|
7
|
+
/// Standard padding configuration for models
|
8
|
+
pub fn standard_padding_params() -> PaddingParams {
|
9
|
+
PaddingParams {
|
10
|
+
strategy: PaddingStrategy::BatchLongest,
|
11
|
+
direction: tokenizers::PaddingDirection::Left,
|
12
|
+
pad_to_multiple_of: None,
|
13
|
+
pad_id: 0,
|
14
|
+
pad_type_id: 0,
|
15
|
+
pad_token: "[PAD]".to_string(),
|
16
|
+
}
|
17
|
+
}
|
18
|
+
|
19
|
+
/// Unified tokenizer loader with common download logic.
///
/// A field-less type used purely as a namespace: all of its functions are
/// associated functions, so no instance is ever constructed.
|
20
|
+
pub struct TokenizerLoader;
|
21
|
+
|
22
|
+
impl TokenizerLoader {
|
23
|
+
/// Load tokenizer from a local file path
|
24
|
+
pub fn from_file(path: &str) -> CandleResult<Tokenizer> {
|
25
|
+
Tokenizer::from_file(path)
|
26
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer from file: {}", e)))
|
27
|
+
}
|
28
|
+
|
29
|
+
/// Download and load tokenizer from HuggingFace
|
30
|
+
pub async fn from_hf_hub(repo_id: &str, filename: Option<&str>) -> CandleResult<Tokenizer> {
|
31
|
+
let api = Api::new()
|
32
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
33
|
+
|
34
|
+
let repo = api.model(repo_id.to_string());
|
35
|
+
let tokenizer_path = Self::download_tokenizer_file(&repo, filename).await?;
|
36
|
+
|
37
|
+
Self::from_file(tokenizer_path.to_str()
|
38
|
+
.ok_or_else(|| candle_core::Error::Msg("Invalid tokenizer path".to_string()))?)
|
39
|
+
}
|
40
|
+
|
41
|
+
/// Download tokenizer file from repository
|
42
|
+
async fn download_tokenizer_file(repo: &ApiRepo, filename: Option<&str>) -> CandleResult<PathBuf> {
|
43
|
+
if let Some(file) = filename {
|
44
|
+
// Try specific filename
|
45
|
+
repo.get(file).await
|
46
|
+
.map_err(|e| candle_core::Error::Msg(
|
47
|
+
format!("Failed to download tokenizer file '{}': {}", file, e)
|
48
|
+
))
|
49
|
+
} else {
|
50
|
+
// Try common tokenizer filenames in order
|
51
|
+
let filenames = ["tokenizer.json", "tokenizer.model"];
|
52
|
+
|
53
|
+
for file in filenames {
|
54
|
+
if let Ok(path) = repo.get(file).await {
|
55
|
+
return Ok(path);
|
56
|
+
}
|
57
|
+
}
|
58
|
+
|
59
|
+
Err(candle_core::Error::Msg(
|
60
|
+
"No tokenizer file found. Tried: tokenizer.json, tokenizer.model".to_string()
|
61
|
+
))
|
62
|
+
}
|
63
|
+
}
|
64
|
+
|
65
|
+
/// Configure tokenizer with standard padding for batch processing
|
66
|
+
pub fn with_padding(mut tokenizer: Tokenizer, padding_params: Option<PaddingParams>) -> Tokenizer {
|
67
|
+
let params = padding_params.unwrap_or_else(standard_padding_params);
|
68
|
+
let _ = tokenizer.with_padding(Some(params));
|
69
|
+
tokenizer
|
70
|
+
}
|
71
|
+
|
72
|
+
/// Configure tokenizer with truncation
|
73
|
+
pub fn with_truncation(mut tokenizer: Tokenizer, max_length: usize) -> Tokenizer {
|
74
|
+
let _ = tokenizer.with_truncation(Some(TruncationParams {
|
75
|
+
max_length,
|
76
|
+
strategy: tokenizers::TruncationStrategy::LongestFirst,
|
77
|
+
stride: 0,
|
78
|
+
direction: tokenizers::TruncationDirection::Right,
|
79
|
+
}));
|
80
|
+
tokenizer
|
81
|
+
}
|
82
|
+
|
83
|
+
/// Download tokenizer from a specific source (for GGUF models)
|
84
|
+
pub async fn from_source(api: &Api, source: &str) -> CandleResult<PathBuf> {
|
85
|
+
// Check if it's a local file path
|
86
|
+
if source.ends_with(".json") && std::path::Path::new(source).exists() {
|
87
|
+
return Ok(PathBuf::from(source));
|
88
|
+
}
|
89
|
+
|
90
|
+
// Otherwise treat it as a HuggingFace repo
|
91
|
+
let repo = api.model(source.to_string());
|
92
|
+
|
93
|
+
// Try tokenizer.json first
|
94
|
+
if let Ok(path) = repo.get("tokenizer.json").await {
|
95
|
+
return Ok(path);
|
96
|
+
}
|
97
|
+
|
98
|
+
// Try tokenizer.model (for models that use sentencepiece)
|
99
|
+
if let Ok(path) = repo.get("tokenizer.model").await {
|
100
|
+
return Ok(path);
|
101
|
+
}
|
102
|
+
|
103
|
+
Err(candle_core::Error::Msg(format!(
|
104
|
+
"Failed to find tokenizer in specified source: {}. Please check network connectivity and that the model repository exists.",
|
105
|
+
source
|
106
|
+
)))
|
107
|
+
}
|
108
|
+
}
|
@@ -0,0 +1,103 @@
|
|
1
|
+
use candle_core::Result as CandleResult;
|
2
|
+
use tokenizers::Tokenizer;
|
3
|
+
|
4
|
+
pub mod loader;
|
5
|
+
|
6
|
+
/// Thin wrapper that owns a `tokenizers::Tokenizer` and is the common entry
/// point for tokenizer operations in this crate.
|
7
|
+
#[derive(Debug, Clone)]
|
8
|
+
pub struct TokenizerWrapper {
|
9
|
+
// The wrapped tokenizer; reachable from outside via `inner()` / `inner_mut()`.
tokenizer: Tokenizer,
|
10
|
+
}
|
11
|
+
|
12
|
+
impl TokenizerWrapper {
|
13
|
+
pub fn new(tokenizer: Tokenizer) -> Self {
|
14
|
+
Self { tokenizer }
|
15
|
+
}
|
16
|
+
|
17
|
+
pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
|
18
|
+
let encoding = self.tokenizer
|
19
|
+
.encode(text, add_special_tokens)
|
20
|
+
.map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
|
21
|
+
Ok(encoding.get_ids().to_vec())
|
22
|
+
}
|
23
|
+
|
24
|
+
pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
|
25
|
+
self.tokenizer
|
26
|
+
.decode(tokens, skip_special_tokens)
|
27
|
+
.map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
|
28
|
+
}
|
29
|
+
|
30
|
+
pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
|
31
|
+
self.tokenizer
|
32
|
+
.id_to_token(token)
|
33
|
+
.map(|s| s.to_string())
|
34
|
+
.ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
|
35
|
+
}
|
36
|
+
|
37
|
+
/// Decode a single token for streaming output
|
38
|
+
pub fn decode_token(&self, token: u32) -> CandleResult<String> {
|
39
|
+
// Decode the single token properly
|
40
|
+
self.decode(&[token], true)
|
41
|
+
}
|
42
|
+
|
43
|
+
/// Decode tokens incrementally for streaming
|
44
|
+
/// This is more efficient than decoding single tokens
|
45
|
+
pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
|
46
|
+
if new_tokens_start >= all_tokens.len() {
|
47
|
+
return Ok(String::new());
|
48
|
+
}
|
49
|
+
|
50
|
+
// Decode all tokens up to this point
|
51
|
+
let full_text = self.decode(all_tokens, true)?;
|
52
|
+
|
53
|
+
// If we're at the start, return everything
|
54
|
+
if new_tokens_start == 0 {
|
55
|
+
return Ok(full_text);
|
56
|
+
}
|
57
|
+
|
58
|
+
// Otherwise, decode up to the previous token and return the difference
|
59
|
+
let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
|
60
|
+
|
61
|
+
// Find the common prefix between the two strings to handle cases where
|
62
|
+
// the tokenizer might produce slightly different text when decoding
|
63
|
+
// different token sequences
|
64
|
+
let common_len = full_text
|
65
|
+
.chars()
|
66
|
+
.zip(previous_text.chars())
|
67
|
+
.take_while(|(a, b)| a == b)
|
68
|
+
.count();
|
69
|
+
|
70
|
+
Ok(full_text.chars().skip(common_len).collect())
|
71
|
+
}
|
72
|
+
|
73
|
+
/// Format tokens with debug information
|
74
|
+
pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
|
75
|
+
let mut result = String::new();
|
76
|
+
for &token in tokens {
|
77
|
+
let piece = self.token_to_piece(token)?;
|
78
|
+
result.push_str(&format!("[{}:{}]", token, piece));
|
79
|
+
}
|
80
|
+
Ok(result)
|
81
|
+
}
|
82
|
+
|
83
|
+
/// Encode a batch of texts (needed for reranker)
|
84
|
+
pub fn encode_batch(&self, texts: Vec<String>, add_special_tokens: bool) -> CandleResult<Vec<Vec<u32>>> {
|
85
|
+
let encodings = self.tokenizer
|
86
|
+
.encode_batch(texts, add_special_tokens)
|
87
|
+
.map_err(|e| candle_core::Error::Msg(format!("Tokenizer batch error: {}", e)))?;
|
88
|
+
|
89
|
+
Ok(encodings.into_iter()
|
90
|
+
.map(|encoding| encoding.get_ids().to_vec())
|
91
|
+
.collect())
|
92
|
+
}
|
93
|
+
|
94
|
+
/// Get the underlying tokenizer (for advanced use cases)
|
95
|
+
pub fn inner(&self) -> &Tokenizer {
|
96
|
+
&self.tokenizer
|
97
|
+
}
|
98
|
+
|
99
|
+
/// Get a mutable reference to the underlying tokenizer (for configuration)
|
100
|
+
pub fn inner_mut(&mut self) -> &mut Tokenizer {
|
101
|
+
&mut self.tokenizer
|
102
|
+
}
|
103
|
+
}
|
@@ -0,0 +1 @@
|
|
1
|
+
aarch64-apple-darwin
|