red-candle 1.0.0.pre.7 → 1.0.0

This diff shows the content of the publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (37)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +322 -4
  4. data/ext/candle/src/lib.rs +6 -3
  5. data/ext/candle/src/llm/gemma.rs +5 -0
  6. data/ext/candle/src/llm/llama.rs +5 -0
  7. data/ext/candle/src/llm/mistral.rs +5 -0
  8. data/ext/candle/src/llm/mod.rs +1 -89
  9. data/ext/candle/src/llm/quantized_gguf.rs +5 -0
  10. data/ext/candle/src/ner.rs +423 -0
  11. data/ext/candle/src/reranker.rs +24 -21
  12. data/ext/candle/src/ruby/device.rs +6 -6
  13. data/ext/candle/src/ruby/dtype.rs +4 -4
  14. data/ext/candle/src/ruby/embedding_model.rs +36 -33
  15. data/ext/candle/src/ruby/llm.rs +31 -13
  16. data/ext/candle/src/ruby/mod.rs +1 -2
  17. data/ext/candle/src/ruby/tensor.rs +66 -66
  18. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  19. data/ext/candle/src/ruby/utils.rs +6 -24
  20. data/ext/candle/src/tokenizer/loader.rs +108 -0
  21. data/ext/candle/src/tokenizer/mod.rs +103 -0
  22. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  23. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  24. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  25. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  26. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  27. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  28. data/lib/candle/build_info.rb +2 -0
  29. data/lib/candle/device_utils.rb +2 -0
  30. data/lib/candle/ner.rb +345 -0
  31. data/lib/candle/reranker.rb +1 -1
  32. data/lib/candle/tensor.rb +2 -0
  33. data/lib/candle/tokenizer.rb +139 -0
  34. data/lib/candle/version.rb +4 -2
  35. data/lib/candle.rb +2 -0
  36. metadata +126 -3
  37. data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/ruby/tokenizer.rs
@@ -0,0 +1,269 @@
+ use magnus::{class, function, method, prelude::*, Error, Module, RArray, RHash, RModule, TryConvert};
+ use crate::tokenizer::{TokenizerWrapper as InnerTokenizer, loader::TokenizerLoader};
+ use crate::ruby::Result;
+
+ #[derive(Clone, Debug)]
+ #[magnus::wrap(class = "Candle::Tokenizer", free_immediately, size)]
+ pub struct Tokenizer(pub InnerTokenizer);
+
+ impl Tokenizer {
+     /// Create a new tokenizer from a file path
+     pub fn from_file(path: String) -> Result<Self> {
+         let tokenizer = TokenizerLoader::from_file(&path)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+         Ok(Self(InnerTokenizer::new(tokenizer)))
+     }
+
+     /// Create a new tokenizer from HuggingFace model ID
+     pub fn from_pretrained(model_id: String) -> Result<Self> {
+         // Use tokio runtime for async operations
+         let rt = tokio::runtime::Runtime::new()
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
+
+         let tokenizer = rt.block_on(async {
+             TokenizerLoader::from_hf_hub(&model_id, None).await
+         })
+         .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         Ok(Self(InnerTokenizer::new(tokenizer)))
+     }
+
+     /// Encode text into token IDs
+     pub fn encode(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
+         let add_special = add_special_tokens.unwrap_or(true);
+         let token_ids = self.0.encode(&text, add_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         Ok(RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))
+     }
+
+     /// Encode text into token strings (words/subwords)
+     pub fn encode_to_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
+         let add_special = add_special_tokens.unwrap_or(true);
+         let token_ids = self.0.encode(&text, add_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         let mut tokens = Vec::new();
+         for id in token_ids {
+             let token = self.0.token_to_piece(id)
+                 .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+             tokens.push(token);
+         }
+
+         Ok(RArray::from_vec(tokens))
+     }
+
+     /// Encode multiple texts in batch
+     pub fn encode_batch(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
+         let texts: Vec<String> = texts.to_vec()?;
+         let add_special = add_special_tokens.unwrap_or(true);
+
+         let token_ids_batch = self.0.encode_batch(texts, add_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         let result = RArray::new();
+         for token_ids in token_ids_batch {
+             result.push(RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+         }
+
+         Ok(result)
+     }
+
+     /// Encode multiple texts in batch, returning token strings
+     pub fn encode_batch_to_tokens(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
+         let texts: Vec<String> = texts.to_vec()?;
+         let add_special = add_special_tokens.unwrap_or(true);
+
+         let token_ids_batch = self.0.encode_batch(texts, add_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         let result = RArray::new();
+         for token_ids in token_ids_batch {
+             let mut tokens = Vec::new();
+             for id in token_ids {
+                 let token = self.0.token_to_piece(id)
+                     .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+                 tokens.push(token);
+             }
+             result.push(RArray::from_vec(tokens))?;
+         }
+
+         Ok(result)
+     }
+
+     /// Encode text and return both token IDs and token strings
+     pub fn encode_with_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RHash> {
+         let add_special = add_special_tokens.unwrap_or(true);
+         let token_ids = self.0.encode(&text, add_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+         let mut tokens = Vec::new();
+         for &id in &token_ids {
+             let token = self.0.token_to_piece(id)
+                 .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+             tokens.push(token);
+         }
+
+         let hash = RHash::new();
+         hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+         hash.aset("tokens", RArray::from_vec(tokens))?;
+
+         Ok(hash)
+     }
+
+     /// Decode token IDs back to text
+     pub fn decode(&self, token_ids: RArray, skip_special_tokens: Option<bool>) -> Result<String> {
+         let token_ids: Vec<i64> = token_ids.to_vec()?;
+         let token_ids: Vec<u32> = token_ids.into_iter()
+             .map(|id| id as u32)
+             .collect();
+         let skip_special = skip_special_tokens.unwrap_or(true);
+
+         self.0.decode(&token_ids, skip_special)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))
+     }
+
+     /// Get the string representation of a single token ID
+     pub fn id_to_token(&self, token_id: i64) -> Result<String> {
+         self.0.token_to_piece(token_id as u32)
+             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))
+     }
+
+     /// Get the vocabulary as a hash of token string to ID
+     pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> Result<RHash> {
+         let with_added = with_added_tokens.unwrap_or(true);
+         let vocab = self.0.inner().get_vocab(with_added);
+
+         let hash = RHash::new();
+         for (token, id) in vocab {
+             hash.aset(token, id as i64)?;
+         }
+
+         Ok(hash)
+     }
+
+     /// Get vocabulary size
+     pub fn vocab_size(&self, with_added_tokens: Option<bool>) -> usize {
+         let with_added = with_added_tokens.unwrap_or(true);
+         self.0.inner().get_vocab_size(with_added)
+     }
+
+     /// Enable padding - returns a new tokenizer with padding enabled
+     pub fn with_padding(&self, kwargs: RHash) -> Result<Self> {
+         use tokenizers::{PaddingParams, PaddingStrategy, PaddingDirection};
+
+         let mut params = PaddingParams::default();
+
+         // Extract parameters from kwargs
+         if let Some(length) = kwargs.get(magnus::Symbol::new("length")) {
+             if let Ok(len) = usize::try_convert(length) {
+                 params.strategy = PaddingStrategy::Fixed(len);
+             }
+         }
+
+         if let Some(max_length) = kwargs.get(magnus::Symbol::new("max_length")) {
+             if let Ok(_) = usize::try_convert(max_length) {
+                 params.strategy = PaddingStrategy::BatchLongest;
+             }
+         }
+
+         if let Some(direction) = kwargs.get(magnus::Symbol::new("direction")) {
+             if let Ok(dir) = String::try_convert(direction) {
+                 params.direction = match dir.as_str() {
+                     "right" => PaddingDirection::Right,
+                     "left" => PaddingDirection::Left,
+                     _ => PaddingDirection::Right,
+                 };
+             }
+         }
+
+         if let Some(pad_id) = kwargs.get(magnus::Symbol::new("pad_id")) {
+             if let Ok(id) = u32::try_convert(pad_id) {
+                 params.pad_id = id;
+             }
+         }
+
+         if let Some(pad_token) = kwargs.get(magnus::Symbol::new("pad_token")) {
+             if let Ok(token) = String::try_convert(pad_token) {
+                 params.pad_token = token;
+             }
+         }
+
+         let mut new_tokenizer = self.0.clone();
+         let _ = new_tokenizer.inner_mut().with_padding(Some(params));
+         Ok(Self(new_tokenizer))
+     }
+
+     /// Enable truncation - returns a new tokenizer with truncation enabled
+     pub fn with_truncation(&self, max_length: usize) -> Result<Self> {
+         use tokenizers::{TruncationParams, TruncationStrategy, TruncationDirection};
+
+         let params = TruncationParams {
+             max_length,
+             strategy: TruncationStrategy::LongestFirst,
+             stride: 0,
+             direction: TruncationDirection::Right,
+         };
+
+         let mut new_tokenizer = self.0.clone();
+         let _ = new_tokenizer.inner_mut().with_truncation(Some(params));
+         Ok(Self(new_tokenizer))
+     }
+
+     /// Get special tokens information
+     pub fn get_special_tokens(&self) -> Result<RHash> {
+         let hash = RHash::new();
+
+         // Common special tokens
+         let special_tokens = vec![
+             ("[CLS]", "cls_token"),
+             ("[SEP]", "sep_token"),
+             ("[PAD]", "pad_token"),
+             ("[UNK]", "unk_token"),
+             ("[MASK]", "mask_token"),
+             ("<s>", "bos_token"),
+             ("</s>", "eos_token"),
+         ];
+
+         let vocab = self.0.inner().get_vocab(true);
+
+         for (token, name) in special_tokens {
+             if let Some(id) = vocab.get(token) {
+                 hash.aset(name, *id as i64)?;
+             }
+         }
+
+         Ok(hash)
+     }
+
+     /// String representation
+     pub fn inspect(&self) -> String {
+         format!("#<Candle::Tokenizer vocab_size={}>", self.vocab_size(Some(true)))
+     }
+ }
+
+ pub fn init(rb_candle: RModule) -> Result<()> {
+     let tokenizer_class = rb_candle.define_class("Tokenizer", class::object())?;
+
+     // Class methods
+     tokenizer_class.define_singleton_method("from_file", function!(Tokenizer::from_file, 1))?;
+     tokenizer_class.define_singleton_method("from_pretrained", function!(Tokenizer::from_pretrained, 1))?;
+
+     // Instance methods
+     tokenizer_class.define_method("encode", method!(Tokenizer::encode, 2))?;
+     tokenizer_class.define_method("encode_to_tokens", method!(Tokenizer::encode_to_tokens, 2))?;
+     tokenizer_class.define_method("encode_with_tokens", method!(Tokenizer::encode_with_tokens, 2))?;
+     tokenizer_class.define_method("encode_batch", method!(Tokenizer::encode_batch, 2))?;
+     tokenizer_class.define_method("encode_batch_to_tokens", method!(Tokenizer::encode_batch_to_tokens, 2))?;
+     tokenizer_class.define_method("decode", method!(Tokenizer::decode, 2))?;
+     tokenizer_class.define_method("id_to_token", method!(Tokenizer::id_to_token, 1))?;
+     tokenizer_class.define_method("get_vocab", method!(Tokenizer::get_vocab, 1))?;
+     tokenizer_class.define_method("vocab_size", method!(Tokenizer::vocab_size, 1))?;
+     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
+     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
+     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
+     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
+
+     Ok(())
+ }
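
The init function above is the entire Ruby surface of the new tokenizer class. A minimal usage sketch, calling the raw arities as bound (the wrapper in data/lib/candle/tokenizer.rb may layer defaults on top, and the model id here is only a placeholder):

  require "candle"

  # Placeholder model id; from_pretrained downloads the tokenizer via the HF Hub.
  tok = Candle::Tokenizer.from_pretrained("bert-base-uncased")

  ids    = tok.encode("Hello world", true)              # Array of Integer token ids
  pieces = tok.encode_to_tokens("Hello world", true)    # Array of token strings
  both   = tok.encode_with_tokens("Hello world", true)  # {"ids" => [...], "tokens" => [...]}
  text   = tok.decode(ids, true)                        # special tokens skipped

  # Configuration methods return new tokenizers rather than mutating in place.
  padded = tok.with_padding({ length: 128 })
  capped = tok.with_truncation(512)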
data/ext/candle/src/ruby/utils.rs
@@ -1,11 +1,10 @@
- use magnus::{function, Error, Module, Object};
+ use magnus::{function, Module, Object};

- use ::candle_core::Tensor;
+ use ::candle_core::Tensor as CoreTensor;

- use crate::ruby::errors::wrap_candle_err;
- use crate::ruby::{Result as RbResult, Tensor as RbTensor};
+ use crate::ruby::Result;

- pub fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
+ pub fn actual_index(t: &CoreTensor, dim: usize, index: i64) -> candle_core::Result<usize> {
      let dim = t.dim(dim)?;
      if 0 <= index {
          let index = index as usize;
@@ -21,7 +20,7 @@ pub fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
      }
  }

- pub fn actual_dim(t: &Tensor, dim: i64) -> candle_core::Result<usize> {
+ pub fn actual_dim(t: &CoreTensor, dim: i64) -> candle_core::Result<usize> {
      let rank = t.rank();
      if 0 <= dim {
          let dim = dim as usize;
@@ -61,7 +60,7 @@ fn get_num_threads() -> usize {
      candle_core::utils::get_num_threads()
  }

- pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
+ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<()> {
      let rb_utils = rb_candle.define_module("Utils")?;
      rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
      rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
@@ -69,20 +68,3 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
      rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
      Ok(())
  }
-
- /// Applies the Softmax function to a given tensor.#
- /// &RETURNS&: Tensor
- #[allow(dead_code)]
- fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
-     let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?;
-     let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?;
-     Ok(RbTensor(sm))
- }
-
- /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
- /// &RETURNS&: Tensor
- #[allow(dead_code)]
- fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
-     let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?;
-     Ok(RbTensor(s))
- }
data/ext/candle/src/tokenizer/loader.rs
@@ -0,0 +1,108 @@
+ use candle_core::Result as CandleResult;
+ use hf_hub::api::tokio::{Api, ApiRepo};
+ use tokenizers::Tokenizer;
+ use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
+ use std::path::PathBuf;
+
+ /// Standard padding configuration for models
+ pub fn standard_padding_params() -> PaddingParams {
+     PaddingParams {
+         strategy: PaddingStrategy::BatchLongest,
+         direction: tokenizers::PaddingDirection::Left,
+         pad_to_multiple_of: None,
+         pad_id: 0,
+         pad_type_id: 0,
+         pad_token: "[PAD]".to_string(),
+     }
+ }
+
+ /// Unified tokenizer loader with common download logic
+ pub struct TokenizerLoader;
+
+ impl TokenizerLoader {
+     /// Load tokenizer from a local file path
+     pub fn from_file(path: &str) -> CandleResult<Tokenizer> {
+         Tokenizer::from_file(path)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer from file: {}", e)))
+     }
+
+     /// Download and load tokenizer from HuggingFace
+     pub async fn from_hf_hub(repo_id: &str, filename: Option<&str>) -> CandleResult<Tokenizer> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.model(repo_id.to_string());
+         let tokenizer_path = Self::download_tokenizer_file(&repo, filename).await?;
+
+         Self::from_file(tokenizer_path.to_str()
+             .ok_or_else(|| candle_core::Error::Msg("Invalid tokenizer path".to_string()))?)
+     }
+
+     /// Download tokenizer file from repository
+     async fn download_tokenizer_file(repo: &ApiRepo, filename: Option<&str>) -> CandleResult<PathBuf> {
+         if let Some(file) = filename {
+             // Try specific filename
+             repo.get(file).await
+                 .map_err(|e| candle_core::Error::Msg(
+                     format!("Failed to download tokenizer file '{}': {}", file, e)
+                 ))
+         } else {
+             // Try common tokenizer filenames in order
+             let filenames = ["tokenizer.json", "tokenizer.model"];
+
+             for file in filenames {
+                 if let Ok(path) = repo.get(file).await {
+                     return Ok(path);
+                 }
+             }
+
+             Err(candle_core::Error::Msg(
+                 "No tokenizer file found. Tried: tokenizer.json, tokenizer.model".to_string()
+             ))
+         }
+     }
+
+     /// Configure tokenizer with standard padding for batch processing
+     pub fn with_padding(mut tokenizer: Tokenizer, padding_params: Option<PaddingParams>) -> Tokenizer {
+         let params = padding_params.unwrap_or_else(standard_padding_params);
+         let _ = tokenizer.with_padding(Some(params));
+         tokenizer
+     }
+
+     /// Configure tokenizer with truncation
+     pub fn with_truncation(mut tokenizer: Tokenizer, max_length: usize) -> Tokenizer {
+         let _ = tokenizer.with_truncation(Some(TruncationParams {
+             max_length,
+             strategy: tokenizers::TruncationStrategy::LongestFirst,
+             stride: 0,
+             direction: tokenizers::TruncationDirection::Right,
+         }));
+         tokenizer
+     }
+
+     /// Download tokenizer from a specific source (for GGUF models)
+     pub async fn from_source(api: &Api, source: &str) -> CandleResult<PathBuf> {
+         // Check if it's a local file path
+         if source.ends_with(".json") && std::path::Path::new(source).exists() {
+             return Ok(PathBuf::from(source));
+         }
+
+         // Otherwise treat it as a HuggingFace repo
+         let repo = api.model(source.to_string());
+
+         // Try tokenizer.json first
+         if let Ok(path) = repo.get("tokenizer.json").await {
+             return Ok(path);
+         }
+
+         // Try tokenizer.model (for models that use sentencepiece)
+         if let Ok(path) = repo.get("tokenizer.model").await {
+             return Ok(path);
+         }
+
+         Err(candle_core::Error::Msg(format!(
+             "Failed to find tokenizer in specified source: {}. Please check network connectivity and that the model repository exists.",
+             source
+         )))
+     }
+ }
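
From Ruby, this fallback order (tokenizer.json, then tokenizer.model) is what Candle::Tokenizer.from_pretrained ultimately exercises. A small sketch of the failure path, with a deliberately hypothetical repo id:

  begin
    tok = Candle::Tokenizer.from_pretrained("some-org/repo-without-tokenizer")
  rescue RuntimeError => e
    # The loader's error message lists what it tried:
    # "No tokenizer file found. Tried: tokenizer.json, tokenizer.model"
    warn e.message
  end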
data/ext/candle/src/tokenizer/mod.rs
@@ -0,0 +1,103 @@
+ use candle_core::Result as CandleResult;
+ use tokenizers::Tokenizer;
+
+ pub mod loader;
+
+ /// Common structure for managing tokenizer
+ #[derive(Debug, Clone)]
+ pub struct TokenizerWrapper {
+     tokenizer: Tokenizer,
+ }
+
+ impl TokenizerWrapper {
+     pub fn new(tokenizer: Tokenizer) -> Self {
+         Self { tokenizer }
+     }
+
+     pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
+         let encoding = self.tokenizer
+             .encode(text, add_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
+         Ok(encoding.get_ids().to_vec())
+     }
+
+     pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
+         self.tokenizer
+             .decode(tokens, skip_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
+     }
+
+     pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
+         self.tokenizer
+             .id_to_token(token)
+             .map(|s| s.to_string())
+             .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
+     }
+
+     /// Decode a single token for streaming output
+     pub fn decode_token(&self, token: u32) -> CandleResult<String> {
+         // Decode the single token properly
+         self.decode(&[token], true)
+     }
+
+     /// Decode tokens incrementally for streaming
+     /// This is more efficient than decoding single tokens
+     pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
+         if new_tokens_start >= all_tokens.len() {
+             return Ok(String::new());
+         }
+
+         // Decode all tokens up to this point
+         let full_text = self.decode(all_tokens, true)?;
+
+         // If we're at the start, return everything
+         if new_tokens_start == 0 {
+             return Ok(full_text);
+         }
+
+         // Otherwise, decode up to the previous token and return the difference
+         let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
+
+         // Find the common prefix between the two strings to handle cases where
+         // the tokenizer might produce slightly different text when decoding
+         // different token sequences
+         let common_len = full_text
+             .chars()
+             .zip(previous_text.chars())
+             .take_while(|(a, b)| a == b)
+             .count();
+
+         Ok(full_text.chars().skip(common_len).collect())
+     }
+
+     /// Format tokens with debug information
+     pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
+         let mut result = String::new();
+         for &token in tokens {
+             let piece = self.token_to_piece(token)?;
+             result.push_str(&format!("[{}:{}]", token, piece));
+         }
+         Ok(result)
+     }
+
+     /// Encode a batch of texts (needed for reranker)
+     pub fn encode_batch(&self, texts: Vec<String>, add_special_tokens: bool) -> CandleResult<Vec<Vec<u32>>> {
+         let encodings = self.tokenizer
+             .encode_batch(texts, add_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer batch error: {}", e)))?;
+
+         Ok(encodings.into_iter()
+             .map(|encoding| encoding.get_ids().to_vec())
+             .collect())
+     }
+
+     /// Get the underlying tokenizer (for advanced use cases)
+     pub fn inner(&self) -> &Tokenizer {
+         &self.tokenizer
+     }
+
+     /// Get a mutable reference to the underlying tokenizer (for configuration)
+     pub fn inner_mut(&mut self) -> &mut Tokenizer {
+         &mut self.tokenizer
+     }
+ }
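
decode_incremental is internal to the extension (it is not bound into Candle::Tokenizer), but its common-prefix trick is the interesting part: decode the whole sequence, decode the already-emitted prefix, and return only the characters that are new, which absorbs tokenizers whose piece boundaries shift as context grows. A hypothetical Ruby mirror of the same logic, built only on the public decode binding:

  # Hypothetical helper, not part of the gem's API.
  def decode_incremental(tok, all_ids, new_start)
    return "" if new_start >= all_ids.length
    full = tok.decode(all_ids, true)
    return full if new_start.zero?
    prev = tok.decode(all_ids[0...new_start], true)
    # Length of the longest common prefix; zip stops matching at the first difference.
    common = full.chars.zip(prev.chars).take_while { |a, b| a == b }.size
    full.chars.drop(common).join
  end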