red-candle 1.0.0.pre.6 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +481 -4
  4. data/Rakefile +1 -3
  5. data/ext/candle/src/lib.rs +6 -3
  6. data/ext/candle/src/llm/gemma.rs +21 -79
  7. data/ext/candle/src/llm/generation_config.rs +3 -0
  8. data/ext/candle/src/llm/llama.rs +21 -79
  9. data/ext/candle/src/llm/mistral.rs +21 -89
  10. data/ext/candle/src/llm/mod.rs +3 -33
  11. data/ext/candle/src/llm/quantized_gguf.rs +501 -0
  12. data/ext/candle/src/llm/text_generation.rs +0 -4
  13. data/ext/candle/src/ner.rs +423 -0
  14. data/ext/candle/src/reranker.rs +24 -21
  15. data/ext/candle/src/ruby/device.rs +6 -6
  16. data/ext/candle/src/ruby/dtype.rs +4 -4
  17. data/ext/candle/src/ruby/embedding_model.rs +36 -34
  18. data/ext/candle/src/ruby/llm.rs +110 -49
  19. data/ext/candle/src/ruby/mod.rs +1 -2
  20. data/ext/candle/src/ruby/tensor.rs +66 -66
  21. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  22. data/ext/candle/src/ruby/utils.rs +6 -24
  23. data/ext/candle/src/tokenizer/loader.rs +108 -0
  24. data/ext/candle/src/tokenizer/mod.rs +103 -0
  25. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  26. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  27. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  28. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  29. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  30. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  31. data/lib/candle/build_info.rb +2 -0
  32. data/lib/candle/device_utils.rb +2 -0
  33. data/lib/candle/llm.rb +91 -2
  34. data/lib/candle/ner.rb +345 -0
  35. data/lib/candle/reranker.rb +1 -1
  36. data/lib/candle/tensor.rb +2 -0
  37. data/lib/candle/tokenizer.rb +139 -0
  38. data/lib/candle/version.rb +4 -2
  39. data/lib/candle.rb +2 -0
  40. metadata +127 -3
  41. data/ext/candle/src/ruby/qtensor.rs +0 -69
@@ -0,0 +1,423 @@
1
+ use magnus::{class, function, method, prelude::*, Error, RModule, RArray, RHash};
2
+ use candle_transformers::models::bert::{BertModel, Config};
3
+ use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
4
+ use candle_nn::{VarBuilder, Linear};
5
+ use hf_hub::{api::sync::Api, Repo, RepoType};
6
+ use std::collections::HashMap;
7
+ use serde::{Deserialize, Serialize};
8
+ use crate::ruby::{Device, Result};
9
+ use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
10
+
11
+ #[derive(Debug, Clone, Serialize, Deserialize)]
12
+ pub struct NERConfig {
13
+ pub id2label: HashMap<i64, String>,
14
+ pub label2id: HashMap<String, i64>,
15
+ }
16
+
17
/// One entity produced by BIO decoding.
///
/// `PartialEq` is derived so spans can be compared directly (there is no
/// `Eq` because `confidence` is a float).
#[derive(Debug, Clone, PartialEq)]
pub struct EntitySpan {
    /// Slice of the input text covered by the entity.
    pub text: String,
    /// Entity type with the `B-`/`I-` prefix removed.
    pub label: String,
    /// Start offset into the input text, taken from the first token's offsets.
    pub start: usize,
    /// End offset into the input text, taken from the last token's offsets.
    pub end: usize,
    /// Index of the first token belonging to the entity.
    pub token_start: usize,
    /// Index one past the last token belonging to the entity.
    pub token_end: usize,
    /// Mean of the per-token confidences of the entity's tokens.
    pub confidence: f32,
}
27
+
28
+ #[magnus::wrap(class = "Candle::NER", free_immediately, size)]
29
+ pub struct NER {
30
+ model: BertModel,
31
+ tokenizer: TokenizerWrapper,
32
+ classifier: Linear,
33
+ config: NERConfig,
34
+ device: CoreDevice,
35
+ model_id: String,
36
+ }
37
+
38
+ impl NER {
39
+ pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
40
+ let device = device.unwrap_or(Device::Cpu).as_device()?;
41
+
42
+ // Load model in a separate thread to avoid blocking
43
+ let device_clone = device.clone();
44
+ let model_id_clone = model_id.clone();
45
+
46
+ let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
47
+ let api = Api::new()?;
48
+ let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
49
+
50
+ // Download model files
51
+ let config_filename = repo.get("config.json")?;
52
+
53
+ // Handle tokenizer loading with optional tokenizer_id
54
+ let tokenizer = if let Some(tok_id) = tokenizer_id {
55
+ // Use the specified tokenizer
56
+ let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
57
+ let tokenizer_filename = tok_repo.get("tokenizer.json")?;
58
+ let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
59
+ TokenizerLoader::with_padding(tokenizer, None)
60
+ } else {
61
+ // Try to load tokenizer from model repo
62
+ let tokenizer_filename = repo.get("tokenizer.json")?;
63
+ let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
64
+ TokenizerLoader::with_padding(tokenizer, None)
65
+ };
66
+ let weights_filename = repo.get("pytorch_model.safetensors")
67
+ .or_else(|_| repo.get("model.safetensors"))?;
68
+
69
+ // Load BERT config
70
+ let config_str = std::fs::read_to_string(&config_filename)?;
71
+ let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
72
+ let bert_config: Config = serde_json::from_value(config_json.clone())?;
73
+
74
+ // Extract NER label configuration
75
+ let id2label = config_json["id2label"]
76
+ .as_object()
77
+ .ok_or("Missing id2label in config")?
78
+ .iter()
79
+ .map(|(k, v)| {
80
+ let id = k.parse::<i64>().unwrap_or(0);
81
+ let label = v.as_str().unwrap_or("O").to_string();
82
+ (id, label)
83
+ })
84
+ .collect::<HashMap<_, _>>();
85
+
86
+ let label2id = id2label.iter()
87
+ .map(|(id, label)| (label.clone(), *id))
88
+ .collect::<HashMap<_, _>>();
89
+
90
+ let num_labels = id2label.len();
91
+ let ner_config = NERConfig { id2label, label2id };
92
+
93
+ // Load model weights
94
+ let vb = unsafe {
95
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
96
+ };
97
+
98
+ // Load BERT model
99
+ let model = BertModel::load(vb.pp("bert"), &bert_config)?;
100
+
101
+ // Load classification head for token classification
102
+ let classifier = candle_nn::linear(
103
+ bert_config.hidden_size,
104
+ num_labels,
105
+ vb.pp("classifier")
106
+ )?;
107
+
108
+ Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
109
+ });
110
+
111
+ match handle.join() {
112
+ Ok(Ok((model, tokenizer, classifier, config))) => {
113
+ Ok(Self {
114
+ model,
115
+ tokenizer,
116
+ classifier,
117
+ config,
118
+ device,
119
+ model_id,
120
+ })
121
+ }
122
+ Ok(Err(e)) => Err(Error::new(
123
+ magnus::exception::runtime_error(),
124
+ format!("Failed to load NER model: {}", e)
125
+ )),
126
+ Err(_) => Err(Error::new(
127
+ magnus::exception::runtime_error(),
128
+ "Thread panicked while loading NER model"
129
+ )),
130
+ }
131
+ }
132
+
133
+ /// Extract entities from text with confidence scores
134
+ pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
135
+ let threshold = confidence_threshold.unwrap_or(0.9) as f32;
136
+
137
+ // Tokenize the text
138
+ let encoding = self.tokenizer.inner().encode(text.as_str(), true)
139
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
140
+
141
+ let token_ids = encoding.get_ids();
142
+ let tokens = encoding.get_tokens();
143
+ let offsets = encoding.get_offsets();
144
+
145
+ // Convert to tensors
146
+ let input_ids = Tensor::new(token_ids, &self.device)
147
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
148
+ .unsqueeze(0)
149
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; // Add batch dimension
150
+
151
+ let attention_mask = Tensor::ones_like(&input_ids)
152
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
153
+ let token_type_ids = Tensor::zeros_like(&input_ids)
154
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
155
+
156
+ // Forward pass through BERT
157
+ let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
158
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
159
+
160
+ // Apply classifier to get logits for each token
161
+ let logits = self.classifier.forward(&output)
162
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
163
+
164
+ // Apply softmax to get probabilities
165
+ let probs = candle_nn::ops::softmax(&logits, 2)
166
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
167
+
168
+ // Get predictions and confidence scores
169
+ let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
170
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
171
+ .to_vec2()
172
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
173
+
174
+ // Extract entities with BIO decoding
175
+ let entities = self.decode_entities(
176
+ &text,
177
+ &tokens.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
178
+ offsets,
179
+ &probs_vec,
180
+ threshold
181
+ )?;
182
+
183
+ // Convert to Ruby array
184
+ let result = RArray::new();
185
+ for entity in entities {
186
+ let hash = RHash::new();
187
+ hash.aset("text", entity.text)?;
188
+ hash.aset("label", entity.label)?;
189
+ hash.aset("start", entity.start)?;
190
+ hash.aset("end", entity.end)?;
191
+ hash.aset("confidence", entity.confidence)?;
192
+ hash.aset("token_start", entity.token_start)?;
193
+ hash.aset("token_end", entity.token_end)?;
194
+ result.push(hash)?;
195
+ }
196
+
197
+ Ok(result)
198
+ }
199
+
200
+ /// Get token-level predictions with labels and confidence scores
201
+ pub fn predict_tokens(&self, text: String) -> Result<RArray> {
202
+ // Tokenize the text
203
+ let encoding = self.tokenizer.inner().encode(text.as_str(), true)
204
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
205
+
206
+ let token_ids = encoding.get_ids();
207
+ let tokens = encoding.get_tokens();
208
+
209
+ // Convert to tensors
210
+ let input_ids = Tensor::new(token_ids, &self.device)
211
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
212
+ .unsqueeze(0)
213
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
214
+
215
+ let attention_mask = Tensor::ones_like(&input_ids)
216
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
217
+ let token_type_ids = Tensor::zeros_like(&input_ids)
218
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
219
+
220
+ // Forward pass
221
+ let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
222
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
223
+ let logits = self.classifier.forward(&output)
224
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
225
+ let probs = candle_nn::ops::softmax(&logits, 2)
226
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
227
+
228
+ // Get predictions
229
+ let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
230
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
231
+ .to_vec2()
232
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
233
+
234
+ // Build result array
235
+ let result = RArray::new();
236
+ for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
237
+ // Find best label
238
+ let (label_id, confidence) = probs.iter()
239
+ .enumerate()
240
+ .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
241
+ .map(|(idx, conf)| (idx as i64, *conf))
242
+ .unwrap_or((0, 0.0));
243
+
244
+ let label = self.config.id2label.get(&label_id)
245
+ .unwrap_or(&"O".to_string())
246
+ .clone();
247
+
248
+ let token_info = RHash::new();
249
+ token_info.aset("token", token.to_string())?;
250
+ token_info.aset("label", label)?;
251
+ token_info.aset("confidence", confidence)?;
252
+ token_info.aset("index", i)?;
253
+
254
+ // Add probability distribution if needed
255
+ let probs_hash = RHash::new();
256
+ for (id, label) in &self.config.id2label {
257
+ if let Some(prob) = probs.get(*id as usize) {
258
+ probs_hash.aset(label.as_str(), *prob)?;
259
+ }
260
+ }
261
+ token_info.aset("probabilities", probs_hash)?;
262
+
263
+ result.push(token_info)?;
264
+ }
265
+
266
+ Ok(result)
267
+ }
268
+
269
+ /// Decode BIO-tagged sequences into entity spans
270
+ fn decode_entities(
271
+ &self,
272
+ text: &str,
273
+ tokens: &[&str],
274
+ offsets: &[(usize, usize)],
275
+ probs: &[Vec<f32>],
276
+ threshold: f32,
277
+ ) -> Result<Vec<EntitySpan>> {
278
+ let mut entities = Vec::new();
279
+ let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
280
+
281
+ for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
282
+ // Skip special tokens
283
+ if token.starts_with("[") && token.ends_with("]") {
284
+ continue;
285
+ }
286
+
287
+ // Get predicted label
288
+ let (label_id, confidence) = probs_vec.iter()
289
+ .enumerate()
290
+ .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
291
+ .map(|(idx, conf)| (idx as i64, *conf))
292
+ .unwrap_or((0, 0.0));
293
+
294
+ let label = self.config.id2label.get(&label_id)
295
+ .unwrap_or(&"O".to_string())
296
+ .clone();
297
+
298
+ // BIO decoding logic
299
+ if label == "O" || confidence < threshold {
300
+ // End current entity if exists
301
+ if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
302
+ if let (Some(start_offset), Some(end_offset)) =
303
+ (offsets.get(start_idx), offsets.get(end_idx - 1)) {
304
+ let entity_text = text[start_offset.0..end_offset.1].to_string();
305
+ let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
306
+
307
+ entities.push(EntitySpan {
308
+ text: entity_text,
309
+ label: entity_type,
310
+ start: start_offset.0,
311
+ end: end_offset.1,
312
+ token_start: start_idx,
313
+ token_end: end_idx,
314
+ confidence: avg_confidence,
315
+ });
316
+ }
317
+ }
318
+ } else if label.starts_with("B-") {
319
+ // Begin new entity
320
+ if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
321
+ if let (Some(start_offset), Some(end_offset)) =
322
+ (offsets.get(start_idx), offsets.get(end_idx - 1)) {
323
+ let entity_text = text[start_offset.0..end_offset.1].to_string();
324
+ let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
325
+
326
+ entities.push(EntitySpan {
327
+ text: entity_text,
328
+ label: entity_type,
329
+ start: start_offset.0,
330
+ end: end_offset.1,
331
+ token_start: start_idx,
332
+ token_end: end_idx,
333
+ confidence: avg_confidence,
334
+ });
335
+ }
336
+ }
337
+
338
+ let entity_type = label[2..].to_string();
339
+ current_entity = Some((entity_type, i, i + 1, vec![confidence]));
340
+ } else if label.starts_with("I-") {
341
+ // Continue entity
342
+ if let Some((ref mut entity_type, _, ref mut end_idx, ref mut confidences)) = current_entity {
343
+ let new_type = label[2..].to_string();
344
+ if *entity_type == new_type {
345
+ *end_idx = i + 1;
346
+ confidences.push(confidence);
347
+ } else {
348
+ // Type mismatch, start new entity
349
+ current_entity = Some((new_type, i, i + 1, vec![confidence]));
350
+ }
351
+ } else {
352
+ // I- tag without B- tag, treat as beginning
353
+ let entity_type = label[2..].to_string();
354
+ current_entity = Some((entity_type, i, i + 1, vec![confidence]));
355
+ }
356
+ }
357
+ }
358
+
359
+ // Handle final entity
360
+ if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
361
+ if let (Some(start_offset), Some(end_offset)) =
362
+ (offsets.get(start_idx), offsets.get(end_idx - 1)) {
363
+ let entity_text = text[start_offset.0..end_offset.1].to_string();
364
+ let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
365
+
366
+ entities.push(EntitySpan {
367
+ text: entity_text,
368
+ label: entity_type,
369
+ start: start_offset.0,
370
+ end: end_offset.1,
371
+ token_start: start_idx,
372
+ token_end: end_idx,
373
+ confidence: avg_confidence,
374
+ });
375
+ }
376
+ }
377
+
378
+ Ok(entities)
379
+ }
380
+
381
+ /// Get the label configuration
382
+ pub fn labels(&self) -> Result<RHash> {
383
+ let hash = RHash::new();
384
+
385
+ let id2label = RHash::new();
386
+ for (id, label) in &self.config.id2label {
387
+ id2label.aset(*id, label.as_str())?;
388
+ }
389
+
390
+ let label2id = RHash::new();
391
+ for (label, id) in &self.config.label2id {
392
+ label2id.aset(label.as_str(), *id)?;
393
+ }
394
+
395
+ hash.aset("id2label", id2label)?;
396
+ hash.aset("label2id", label2id)?;
397
+ hash.aset("num_labels", self.config.id2label.len())?;
398
+
399
+ Ok(hash)
400
+ }
401
+
402
+ /// Get the tokenizer
403
+ pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
404
+ Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
405
+ }
406
+
407
+ /// Get model info
408
+ pub fn model_info(&self) -> String {
409
+ format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
410
+ }
411
+ }
412
+
413
+ pub fn init(rb_candle: RModule) -> Result<()> {
414
+ let ner_class = rb_candle.define_class("NER", class::object())?;
415
+ ner_class.define_singleton_method("new", function!(NER::new, 3))?;
416
+ ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
417
+ ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
418
+ ner_class.define_method("labels", method!(NER::labels, 0))?;
419
+ ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
420
+ ner_class.define_method("model_info", method!(NER::model_info, 0))?;
421
+
422
+ Ok(())
423
+ }
@@ -3,28 +3,29 @@ use candle_transformers::models::bert::{BertModel, Config};
3
3
  use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
4
4
  use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
5
5
  use hf_hub::{api::sync::Api, Repo, RepoType};
6
- use tokenizers::{PaddingParams, Tokenizer, EncodeInput};
6
+ use tokenizers::{EncodeInput, Tokenizer};
7
7
  use std::thread;
8
- use crate::ruby::{Device as RbDevice, Result as RbResult};
8
+ use crate::ruby::{Device, Result};
9
+ use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
9
10
 
10
11
  #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
11
12
  pub struct Reranker {
12
13
  model: BertModel,
13
- tokenizer: Tokenizer,
14
+ tokenizer: TokenizerWrapper,
14
15
  pooler: Linear,
15
16
  classifier: Linear,
16
17
  device: CoreDevice,
17
18
  }
18
19
 
19
20
  impl Reranker {
20
- pub fn new(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
21
- let device = device.unwrap_or(RbDevice::Cpu).as_device()?;
21
+ pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
22
+ let device = device.unwrap_or(Device::Cpu).as_device()?;
22
23
  Self::new_with_core_device(model_id, device)
23
24
  }
24
25
 
25
- fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
26
+ fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
26
27
  let device_clone = device.clone();
27
- let handle = thread::spawn(move || -> Result<(BertModel, Tokenizer, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
28
+ let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
28
29
  let api = Api::new()?;
29
30
  let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
30
31
 
@@ -38,12 +39,8 @@ impl Reranker {
38
39
  let config: Config = serde_json::from_str(&config)?;
39
40
 
40
41
  // Setup tokenizer with padding
41
- let mut tokenizer = Tokenizer::from_file(tokenizer_filename)?;
42
- let pp = PaddingParams {
43
- strategy: tokenizers::PaddingStrategy::BatchLongest,
44
- ..Default::default()
45
- };
46
- tokenizer.with_padding(Some(pp));
42
+ let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
43
+ let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
47
44
 
48
45
  // Load model weights
49
46
  let vb = unsafe {
@@ -59,7 +56,7 @@ impl Reranker {
59
56
  // Load classifier layer for cross-encoder (single output score)
60
57
  let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
61
58
 
62
- Ok((model, tokenizer, pooler, classifier))
59
+ Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
63
60
  });
64
61
 
65
62
  match handle.join() {
@@ -71,12 +68,12 @@ impl Reranker {
71
68
  }
72
69
  }
73
70
 
74
- pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
71
+ pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
75
72
  // Create query-document pair for cross-encoder
76
73
  let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
77
74
 
78
- // Tokenize
79
- let encoding = self.tokenizer.encode(query_doc_pair, true)
75
+ // Tokenize using the inner tokenizer for detailed info
76
+ let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
80
77
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
81
78
 
82
79
  // Get token information
@@ -95,7 +92,7 @@ impl Reranker {
95
92
  Ok(result)
96
93
  }
97
94
 
98
- pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
95
+ pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
99
96
  let documents: Vec<String> = documents.to_vec()?;
100
97
 
101
98
  // Create query-document pairs for cross-encoder
@@ -104,8 +101,8 @@ impl Reranker {
104
101
  .map(|d| (query.clone(), d.clone()).into())
105
102
  .collect();
106
103
 
107
- // Tokenize batch
108
- let encodings = self.tokenizer.encode_batch(query_and_docs, true)
104
+ // Tokenize batch using inner tokenizer for access to token type IDs
105
+ let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
109
106
  .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
110
107
 
111
108
  // Convert to tensors
@@ -256,12 +253,18 @@ impl Reranker {
256
253
  }
257
254
  Ok(result_array)
258
255
  }
256
+
257
+ /// Get the tokenizer used by this model
258
+ pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
259
+ Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
260
+ }
259
261
  }
260
262
 
261
- pub fn init(rb_candle: RModule) -> Result<(), Error> {
263
+ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
262
264
  let c_reranker = rb_candle.define_class("Reranker", class::object())?;
263
265
  c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
264
266
  c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
265
267
  c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
268
+ c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
266
269
  Ok(())
267
270
  }
@@ -2,7 +2,7 @@ use magnus::Error;
2
2
  use magnus::{function, method, class, RModule, Module, Object};
3
3
 
4
4
  use ::candle_core::Device as CoreDevice;
5
- use crate::ruby::Result as RbResult;
5
+ use crate::ruby::Result;
6
6
 
7
7
  #[cfg(any(feature = "cuda", feature = "metal"))]
8
8
  use crate::ruby::errors::wrap_candle_err;
@@ -68,7 +68,7 @@ impl Device {
68
68
  }
69
69
 
70
70
  /// Create a CUDA device (GPU)
71
- pub fn cuda() -> RbResult<Self> {
71
+ pub fn cuda() -> Result<Self> {
72
72
  #[cfg(not(feature = "cuda"))]
73
73
  {
74
74
  return Err(Error::new(
@@ -82,7 +82,7 @@ impl Device {
82
82
  }
83
83
 
84
84
  /// Create a Metal device (Apple GPU)
85
- pub fn metal() -> RbResult<Self> {
85
+ pub fn metal() -> Result<Self> {
86
86
  #[cfg(not(feature = "metal"))]
87
87
  {
88
88
  return Err(Error::new(
@@ -103,7 +103,7 @@ impl Device {
103
103
  }
104
104
  }
105
105
 
106
- pub fn as_device(&self) -> RbResult<CoreDevice> {
106
+ pub fn as_device(&self) -> Result<CoreDevice> {
107
107
  match self {
108
108
  Self::Cpu => Ok(CoreDevice::Cpu),
109
109
  Self::Cuda => {
@@ -165,7 +165,7 @@ impl Device {
165
165
  }
166
166
 
167
167
  impl magnus::TryConvert for Device {
168
- fn try_convert(val: magnus::Value) -> RbResult<Self> {
168
+ fn try_convert(val: magnus::Value) -> Result<Self> {
169
169
  // First check if it's already a wrapped Device object
170
170
  if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
171
171
  return Ok(*device);
@@ -184,7 +184,7 @@ impl magnus::TryConvert for Device {
184
184
  }
185
185
  }
186
186
 
187
- pub fn init(rb_candle: RModule) -> Result<(), Error> {
187
+ pub fn init(rb_candle: RModule) -> Result<()> {
188
188
  let rb_device = rb_candle.define_class("Device", class::object())?;
189
189
  rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
190
190
  rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
@@ -1,8 +1,8 @@
1
1
  use magnus::value::ReprValue;
2
- use magnus::{method, class, RModule, Error, Module};
2
+ use magnus::{method, class, RModule, Module};
3
3
 
4
4
  use ::candle_core::DType as CoreDType;
5
- use crate::ruby::Result as RbResult;
5
+ use crate::ruby::Result;
6
6
 
7
7
  #[derive(Clone, Copy, Debug, PartialEq, Eq)]
8
8
  #[magnus::wrap(class = "Candle::DType", free_immediately, size)]
@@ -21,7 +21,7 @@ impl DType {
21
21
  }
22
22
 
23
23
  impl DType {
24
- pub fn from_rbobject(dtype: magnus::Symbol) -> RbResult<Self> {
24
+ pub fn from_rbobject(dtype: magnus::Symbol) -> Result<Self> {
25
25
  let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
26
26
  use std::str::FromStr;
27
27
  let dtype = CoreDType::from_str(&dtype).unwrap();
@@ -29,7 +29,7 @@ impl DType {
29
29
  }
30
30
  }
31
31
 
32
- pub fn init(rb_candle: RModule) -> Result<(), Error> {
32
+ pub fn init(rb_candle: RModule) -> Result<()> {
33
33
  let rb_dtype = rb_candle.define_class("DType", class::object())?;
34
34
  rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
35
35
  rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;