red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert};
|
|
2
|
+
use crate::tokenizer::{TokenizerWrapper as InnerTokenizer, loader::TokenizerLoader};
|
|
3
|
+
use crate::ruby::Result;
|
|
4
|
+
|
|
5
|
+
#[derive(Clone, Debug)]
|
|
6
|
+
#[magnus::wrap(class = "Candle::Tokenizer", free_immediately, size)]
|
|
7
|
+
pub struct Tokenizer(pub InnerTokenizer);
|
|
8
|
+
|
|
9
|
+
impl Tokenizer {
|
|
10
|
+
/// Create a new tokenizer from a file path
|
|
11
|
+
pub fn from_file(path: String) -> Result<Self> {
|
|
12
|
+
let ruby = Ruby::get().unwrap();
|
|
13
|
+
let tokenizer = TokenizerLoader::from_file(&path)
|
|
14
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?;
|
|
15
|
+
Ok(Self(InnerTokenizer::new(tokenizer)))
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
/// Create a new tokenizer from HuggingFace model ID
|
|
19
|
+
pub fn from_pretrained(model_id: String) -> Result<Self> {
|
|
20
|
+
let ruby = Ruby::get().unwrap();
|
|
21
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
22
|
+
// Use tokio runtime for async operations
|
|
23
|
+
let rt = tokio::runtime::Runtime::new()
|
|
24
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;
|
|
25
|
+
|
|
26
|
+
let tokenizer = rt.block_on(async {
|
|
27
|
+
TokenizerLoader::from_hf_hub(&model_id, None).await
|
|
28
|
+
})
|
|
29
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
30
|
+
|
|
31
|
+
Ok(Self(InnerTokenizer::new(tokenizer)))
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
/// Encode text into token IDs
|
|
35
|
+
pub fn encode(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
|
|
36
|
+
let ruby = Ruby::get().unwrap();
|
|
37
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
|
38
|
+
let token_ids = self.0.encode(&text, add_special)
|
|
39
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?;
|
|
40
|
+
|
|
41
|
+
Ok(ruby.ary_from_vec(token_ids.into_iter().map(|id| id as i64).collect()))
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
/// Encode text into token strings (words/subwords)
|
|
45
|
+
pub fn encode_to_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RArray> {
|
|
46
|
+
let ruby = Ruby::get().unwrap();
|
|
47
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
48
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
|
49
|
+
let token_ids = self.0.encode(&text, add_special)
|
|
50
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
51
|
+
|
|
52
|
+
let mut tokens = Vec::new();
|
|
53
|
+
for id in token_ids {
|
|
54
|
+
let token = self.0.token_to_piece(id)
|
|
55
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
56
|
+
tokens.push(token);
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
Ok(ruby.ary_from_vec(tokens))
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
/// Encode multiple texts in batch
|
|
63
|
+
pub fn encode_batch(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
|
|
64
|
+
let ruby = Ruby::get().unwrap();
|
|
65
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
66
|
+
let texts: Vec<String> = texts.to_vec()?;
|
|
67
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
|
68
|
+
|
|
69
|
+
let token_ids_batch = self.0.encode_batch(texts, add_special)
|
|
70
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
71
|
+
|
|
72
|
+
let result = ruby.ary_new();
|
|
73
|
+
for token_ids in token_ids_batch {
|
|
74
|
+
result.push(ruby.ary_from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
Ok(result)
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
/// Encode multiple texts in batch, returning token strings
|
|
81
|
+
pub fn encode_batch_to_tokens(&self, texts: RArray, add_special_tokens: Option<bool>) -> Result<RArray> {
|
|
82
|
+
let ruby = Ruby::get().unwrap();
|
|
83
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
84
|
+
let texts: Vec<String> = texts.to_vec()?;
|
|
85
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
|
86
|
+
|
|
87
|
+
let token_ids_batch = self.0.encode_batch(texts, add_special)
|
|
88
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
89
|
+
|
|
90
|
+
let result = ruby.ary_new();
|
|
91
|
+
for token_ids in token_ids_batch {
|
|
92
|
+
let mut tokens = Vec::new();
|
|
93
|
+
for id in token_ids {
|
|
94
|
+
let token = self.0.token_to_piece(id)
|
|
95
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
96
|
+
tokens.push(token);
|
|
97
|
+
}
|
|
98
|
+
result.push(ruby.ary_from_vec(tokens))?;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
Ok(result)
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/// Encode text and return both token IDs and token strings
|
|
105
|
+
pub fn encode_with_tokens(&self, text: String, add_special_tokens: Option<bool>) -> Result<RHash> {
|
|
106
|
+
let ruby = Ruby::get().unwrap();
|
|
107
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
108
|
+
let add_special = add_special_tokens.unwrap_or(true);
|
|
109
|
+
let token_ids = self.0.encode(&text, add_special)
|
|
110
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
111
|
+
|
|
112
|
+
let mut tokens = Vec::new();
|
|
113
|
+
for &id in &token_ids {
|
|
114
|
+
let token = self.0.token_to_piece(id)
|
|
115
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
116
|
+
tokens.push(token);
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
let hash = ruby.hash_new();
|
|
120
|
+
hash.aset(ruby.to_symbol("ids"), ruby.ary_from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
|
121
|
+
hash.aset(ruby.to_symbol("tokens"), ruby.ary_from_vec(tokens))?;
|
|
122
|
+
|
|
123
|
+
Ok(hash)
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
/// Decode token IDs back to text
|
|
127
|
+
pub fn decode(&self, token_ids: RArray, skip_special_tokens: Option<bool>) -> Result<String> {
|
|
128
|
+
let ruby = Ruby::get().unwrap();
|
|
129
|
+
let token_ids: Vec<i64> = token_ids.to_vec()?;
|
|
130
|
+
let token_ids: Vec<u32> = token_ids.into_iter()
|
|
131
|
+
.map(|id| id as u32)
|
|
132
|
+
.collect();
|
|
133
|
+
let skip_special = skip_special_tokens.unwrap_or(true);
|
|
134
|
+
|
|
135
|
+
self.0.decode(&token_ids, skip_special)
|
|
136
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
/// Get the string representation of a single token ID
|
|
140
|
+
pub fn id_to_token(&self, token_id: i64) -> Result<String> {
|
|
141
|
+
let ruby = Ruby::get().unwrap();
|
|
142
|
+
self.0.token_to_piece(token_id as u32)
|
|
143
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
/// Get the vocabulary as a hash of token string to ID
|
|
147
|
+
pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> Result<RHash> {
|
|
148
|
+
let ruby = Ruby::get().unwrap();
|
|
149
|
+
let with_added = with_added_tokens.unwrap_or(true);
|
|
150
|
+
let vocab = self.0.inner().get_vocab(with_added);
|
|
151
|
+
|
|
152
|
+
let hash = ruby.hash_new();
|
|
153
|
+
for (token, id) in vocab {
|
|
154
|
+
hash.aset(token, id as i64)?;
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
Ok(hash)
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
/// Get vocabulary size
|
|
161
|
+
pub fn vocab_size(&self, with_added_tokens: Option<bool>) -> usize {
|
|
162
|
+
let with_added = with_added_tokens.unwrap_or(true);
|
|
163
|
+
self.0.inner().get_vocab_size(with_added)
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
/// Enable padding - returns a new tokenizer with padding enabled
|
|
167
|
+
pub fn with_padding(&self, kwargs: RHash) -> Result<Self> {
|
|
168
|
+
use tokenizers::{PaddingParams, PaddingStrategy, PaddingDirection};
|
|
169
|
+
let ruby = Ruby::get().unwrap();
|
|
170
|
+
|
|
171
|
+
let mut params = PaddingParams::default();
|
|
172
|
+
|
|
173
|
+
// Extract parameters from kwargs
|
|
174
|
+
if let Some(length) = kwargs.get(ruby.to_symbol("length")) {
|
|
175
|
+
if let Ok(len) = usize::try_convert(length) {
|
|
176
|
+
params.strategy = PaddingStrategy::Fixed(len);
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
if let Some(max_length) = kwargs.get(ruby.to_symbol("max_length")) {
|
|
181
|
+
if let Ok(_) = usize::try_convert(max_length) {
|
|
182
|
+
params.strategy = PaddingStrategy::BatchLongest;
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
if let Some(direction) = kwargs.get(ruby.to_symbol("direction")) {
|
|
187
|
+
if let Ok(dir) = String::try_convert(direction) {
|
|
188
|
+
params.direction = match dir.as_str() {
|
|
189
|
+
"right" => PaddingDirection::Right,
|
|
190
|
+
"left" => PaddingDirection::Left,
|
|
191
|
+
_ => PaddingDirection::Right,
|
|
192
|
+
};
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
if let Some(pad_id) = kwargs.get(ruby.to_symbol("pad_id")) {
|
|
197
|
+
if let Ok(id) = u32::try_convert(pad_id) {
|
|
198
|
+
params.pad_id = id;
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
if let Some(pad_token) = kwargs.get(ruby.to_symbol("pad_token")) {
|
|
203
|
+
if let Ok(token) = String::try_convert(pad_token) {
|
|
204
|
+
params.pad_token = token;
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
let mut new_tokenizer = self.0.clone();
|
|
209
|
+
let _ = new_tokenizer.inner_mut().with_padding(Some(params));
|
|
210
|
+
Ok(Self(new_tokenizer))
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
/// Enable truncation - returns a new tokenizer with truncation enabled
|
|
214
|
+
pub fn with_truncation(&self, max_length: usize) -> Result<Self> {
|
|
215
|
+
use tokenizers::{TruncationParams, TruncationStrategy, TruncationDirection};
|
|
216
|
+
|
|
217
|
+
let params = TruncationParams {
|
|
218
|
+
max_length,
|
|
219
|
+
strategy: TruncationStrategy::LongestFirst,
|
|
220
|
+
stride: 0,
|
|
221
|
+
direction: TruncationDirection::Right,
|
|
222
|
+
};
|
|
223
|
+
|
|
224
|
+
let mut new_tokenizer = self.0.clone();
|
|
225
|
+
let _ = new_tokenizer.inner_mut().with_truncation(Some(params));
|
|
226
|
+
Ok(Self(new_tokenizer))
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
/// Get special tokens information
|
|
230
|
+
pub fn get_special_tokens(&self) -> Result<RHash> {
|
|
231
|
+
let ruby = Ruby::get().unwrap();
|
|
232
|
+
let hash = ruby.hash_new();
|
|
233
|
+
|
|
234
|
+
// Common special tokens
|
|
235
|
+
let special_tokens = vec![
|
|
236
|
+
("[CLS]", "cls_token"),
|
|
237
|
+
("[SEP]", "sep_token"),
|
|
238
|
+
("[PAD]", "pad_token"),
|
|
239
|
+
("[UNK]", "unk_token"),
|
|
240
|
+
("[MASK]", "mask_token"),
|
|
241
|
+
("<s>", "bos_token"),
|
|
242
|
+
("</s>", "eos_token"),
|
|
243
|
+
];
|
|
244
|
+
|
|
245
|
+
let vocab = self.0.inner().get_vocab(true);
|
|
246
|
+
|
|
247
|
+
for (token, name) in special_tokens {
|
|
248
|
+
if let Some(id) = vocab.get(token) {
|
|
249
|
+
hash.aset(name, *id as i64)?;
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
Ok(hash)
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
/// Get tokenizer options as a hash
|
|
257
|
+
pub fn options(&self) -> Result<RHash> {
|
|
258
|
+
let ruby = Ruby::get().unwrap();
|
|
259
|
+
let hash = ruby.hash_new();
|
|
260
|
+
|
|
261
|
+
// Get vocab size
|
|
262
|
+
hash.aset("vocab_size", self.vocab_size(Some(true)))?;
|
|
263
|
+
hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
|
|
264
|
+
|
|
265
|
+
// Get special tokens info
|
|
266
|
+
let special_tokens = self.get_special_tokens()?;
|
|
267
|
+
hash.aset("special_tokens", special_tokens)?;
|
|
268
|
+
|
|
269
|
+
// Get padding/truncation info if available
|
|
270
|
+
let inner_tokenizer = self.0.inner();
|
|
271
|
+
|
|
272
|
+
// Check if padding is enabled
|
|
273
|
+
if let Some(_padding) = inner_tokenizer.get_padding() {
|
|
274
|
+
let padding_info = ruby.hash_new();
|
|
275
|
+
padding_info.aset("enabled", true)?;
|
|
276
|
+
hash.aset("padding", padding_info)?;
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
// Check if truncation is enabled
|
|
280
|
+
if let Some(truncation) = inner_tokenizer.get_truncation() {
|
|
281
|
+
let truncation_info = ruby.hash_new();
|
|
282
|
+
truncation_info.aset("enabled", true)?;
|
|
283
|
+
truncation_info.aset("max_length", truncation.max_length)?;
|
|
284
|
+
hash.aset("truncation", truncation_info)?;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
Ok(hash)
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
/// String representation
|
|
291
|
+
pub fn inspect(&self) -> String {
|
|
292
|
+
let vocab_size = self.vocab_size(Some(true));
|
|
293
|
+
let special_tokens = self.get_special_tokens()
|
|
294
|
+
.ok()
|
|
295
|
+
.map(|h| h.len())
|
|
296
|
+
.unwrap_or(0);
|
|
297
|
+
|
|
298
|
+
let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
|
|
299
|
+
|
|
300
|
+
if special_tokens > 0 {
|
|
301
|
+
parts.push(format!("special_tokens={}", special_tokens));
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
// Check for padding/truncation
|
|
305
|
+
let inner_tokenizer = self.0.inner();
|
|
306
|
+
if inner_tokenizer.get_padding().is_some() {
|
|
307
|
+
parts.push("padding=enabled".to_string());
|
|
308
|
+
}
|
|
309
|
+
if let Some(truncation) = inner_tokenizer.get_truncation() {
|
|
310
|
+
parts.push(format!("truncation={}", truncation.max_length));
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
parts.join(" ") + ">"
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
318
|
+
let ruby = Ruby::get().unwrap();
|
|
319
|
+
let tokenizer_class = rb_candle.define_class("Tokenizer", ruby.class_object())?;
|
|
320
|
+
|
|
321
|
+
// Class methods
|
|
322
|
+
tokenizer_class.define_singleton_method("from_file", function!(Tokenizer::from_file, 1))?;
|
|
323
|
+
tokenizer_class.define_singleton_method("from_pretrained", function!(Tokenizer::from_pretrained, 1))?;
|
|
324
|
+
|
|
325
|
+
// Instance methods
|
|
326
|
+
tokenizer_class.define_method("encode", method!(Tokenizer::encode, 2))?;
|
|
327
|
+
tokenizer_class.define_method("encode_to_tokens", method!(Tokenizer::encode_to_tokens, 2))?;
|
|
328
|
+
tokenizer_class.define_method("encode_with_tokens", method!(Tokenizer::encode_with_tokens, 2))?;
|
|
329
|
+
tokenizer_class.define_method("encode_batch", method!(Tokenizer::encode_batch, 2))?;
|
|
330
|
+
tokenizer_class.define_method("encode_batch_to_tokens", method!(Tokenizer::encode_batch_to_tokens, 2))?;
|
|
331
|
+
tokenizer_class.define_method("decode", method!(Tokenizer::decode, 2))?;
|
|
332
|
+
tokenizer_class.define_method("id_to_token", method!(Tokenizer::id_to_token, 1))?;
|
|
333
|
+
tokenizer_class.define_method("get_vocab", method!(Tokenizer::get_vocab, 1))?;
|
|
334
|
+
tokenizer_class.define_method("vocab_size", method!(Tokenizer::vocab_size, 1))?;
|
|
335
|
+
tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
|
|
336
|
+
tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
|
|
337
|
+
tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
|
|
338
|
+
tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
|
|
339
|
+
tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
|
|
340
|
+
tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
|
|
341
|
+
|
|
342
|
+
Ok(())
|
|
343
|
+
}
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
use magnus::{function, Module, Object};
|
|
2
|
+
|
|
3
|
+
use ::candle_core::Tensor as CoreTensor;
|
|
4
|
+
|
|
5
|
+
use crate::ruby::Result;
|
|
6
|
+
|
|
7
|
+
/// Ensures the HuggingFace cache directory exists before `Api::new()` is called.
///
/// The hf_hub crate stores downloaded models in a "hub" subdirectory under the
/// cache root. When the parent directory doesn't exist, hf_hub may fail to
/// create the full path or silently produce an empty cache. This function
/// pre-creates the directory tree to avoid the race condition described in
/// issue #72.
///
/// Resolution order for the cache root:
/// 1. `$HF_HOME` (if set; read exactly as before)
/// 2. `$XDG_CACHE_HOME/huggingface` (if `XDG_CACHE_HOME` is set and non-empty)
/// 3. `$HOME/.cache/huggingface` (if `HOME` is set and non-empty)
///
/// Empty `XDG_CACHE_HOME` / `HOME` values are treated as unset. The XDG Base Directory spec
/// requires this for `XDG_CACHE_HOME`. Without it, an empty value resolves to a path relative
/// to the current working directory, leaving stray `huggingface/hub` or
/// `.cache/huggingface/hub` directories there.
///
/// NOTE(review): hf_hub's default cache location may not consult `XDG_CACHE_HOME` at all;
/// verify that branch targets the directory hf_hub actually uses.
///
/// Best-effort: if no root can be determined, or directory creation fails, this returns
/// silently and leaves error reporting to hf_hub.
pub fn ensure_hf_cache_dir() {
    // Reads a UTF-8 environment variable, treating an empty value as unset.
    fn non_empty_var(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }

    let cache_root = if let Ok(hf_home) = std::env::var("HF_HOME") {
        std::path::PathBuf::from(hf_home)
    } else if let Some(xdg) = non_empty_var("XDG_CACHE_HOME") {
        std::path::PathBuf::from(xdg).join("huggingface")
    } else if let Some(home) = non_empty_var("HOME") {
        std::path::PathBuf::from(home).join(".cache").join("huggingface")
    } else {
        return;
    };

    // Failure is deliberately ignored: this is only a pre-creation convenience.
    let _ = std::fs::create_dir_all(cache_root.join("hub"));
}
|
|
32
|
+
|
|
33
|
+
/// Resolves a possibly-negative `index` along dimension `dim` of `t` to an absolute position.
///
/// Non-negative indices must be `< size`. Negative indices count from the end, so `-1` is
/// the last element and `-size` is the first.
///
/// # Errors
/// Returns an error if `dim` is not a valid dimension of `t`, or if `index` is out of range.
/// The comparison uses `unsigned_abs` / `usize::try_from` rather than `-index` and `as`
/// casts. `-i64::MIN` overflows, and in release builds the wrapped value would bypass the
/// bounds check and yield a bogus index. Large positive indices could also truncate on
/// targets with a 32-bit `usize`.
pub fn actual_index(t: &CoreTensor, dim: usize, index: i64) -> candle_core::Result<usize> {
    let dim = t.dim(dim)?;
    if index >= 0 {
        match usize::try_from(index) {
            Ok(i) if i < dim => Ok(i),
            _ => candle_core::bail!("index {index} is too large for tensor dimension {dim}"),
        }
    } else {
        // Distance back from the end; computed without negation so i64::MIN is safe.
        match usize::try_from(index.unsigned_abs()) {
            Ok(back) if back <= dim => Ok(dim - back),
            _ => candle_core::bail!("index {index} is too low for tensor dimension {dim}"),
        }
    }
}
|
|
48
|
+
|
|
49
|
+
/// Resolves a possibly-negative dimension index `dim` against the rank of `t`.
///
/// Non-negative values must be `< rank`. Negative values count from the last dimension,
/// so `-1` is the last dimension and `-rank` is the first.
///
/// # Errors
/// Returns an error if `dim` is out of range for the tensor's rank. `unsigned_abs` /
/// `usize::try_from` are used instead of `-dim` and `as` casts. This avoids overflow for
/// `i64::MIN`, which in release builds would wrap past the bounds check, and truncation on
/// 32-bit targets.
pub fn actual_dim(t: &CoreTensor, dim: i64) -> candle_core::Result<usize> {
    let rank = t.rank();
    if dim >= 0 {
        match usize::try_from(dim) {
            Ok(d) if d < rank => Ok(d),
            _ => candle_core::bail!("dimension index {dim} is too large for tensor rank {rank}"),
        }
    } else {
        // Distance back from the last dimension; computed without negation.
        match usize::try_from(dim.unsigned_abs()) {
            Ok(back) if back <= rank => Ok(rank - back),
            _ => candle_core::bail!("dimension index {dim} is too low for tensor rank {rank}"),
        }
    }
}
|
|
64
|
+
|
|
65
|
+
/// Returns true if the 'cuda' backend is available.
|
|
66
|
+
/// &RETURNS&: bool
|
|
67
|
+
fn cuda_is_available() -> bool {
|
|
68
|
+
candle_core::utils::cuda_is_available()
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/// Returns true if candle was compiled with 'accelerate' support.
|
|
72
|
+
/// &RETURNS&: bool
|
|
73
|
+
fn has_accelerate() -> bool {
|
|
74
|
+
candle_core::utils::has_accelerate()
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/// Returns true if candle was compiled with MKL support.
|
|
78
|
+
/// &RETURNS&: bool
|
|
79
|
+
fn has_mkl() -> bool {
|
|
80
|
+
candle_core::utils::has_mkl()
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/// Returns the number of threads used by the candle.
|
|
84
|
+
/// &RETURNS&: int
|
|
85
|
+
fn get_num_threads() -> usize {
|
|
86
|
+
candle_core::utils::get_num_threads()
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
pub fn candle_utils(rb_candle: magnus::RModule) -> Result<()> {
|
|
90
|
+
let rb_utils = rb_candle.define_module("Utils")?;
|
|
91
|
+
rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
|
|
92
|
+
rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
|
|
93
|
+
rb_utils.define_singleton_method("has_accelerate", function!(has_accelerate, 0))?;
|
|
94
|
+
rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
|
|
95
|
+
Ok(())
|
|
96
|
+
}
|