red-candle 1.0.0.pre.7 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +399 -18
  4. data/ext/candle/src/lib.rs +6 -3
  5. data/ext/candle/src/llm/gemma.rs +5 -0
  6. data/ext/candle/src/llm/llama.rs +5 -0
  7. data/ext/candle/src/llm/mistral.rs +5 -0
  8. data/ext/candle/src/llm/mod.rs +1 -89
  9. data/ext/candle/src/llm/quantized_gguf.rs +5 -0
  10. data/ext/candle/src/ner.rs +423 -0
  11. data/ext/candle/src/reranker.rs +24 -21
  12. data/ext/candle/src/ruby/device.rs +6 -6
  13. data/ext/candle/src/ruby/dtype.rs +4 -4
  14. data/ext/candle/src/ruby/embedding_model.rs +36 -33
  15. data/ext/candle/src/ruby/llm.rs +31 -13
  16. data/ext/candle/src/ruby/mod.rs +1 -2
  17. data/ext/candle/src/ruby/tensor.rs +66 -66
  18. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  19. data/ext/candle/src/ruby/utils.rs +6 -24
  20. data/ext/candle/src/tokenizer/loader.rs +108 -0
  21. data/ext/candle/src/tokenizer/mod.rs +103 -0
  22. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  23. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  24. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  25. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  26. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  27. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  28. data/lib/candle/build_info.rb +2 -0
  29. data/lib/candle/device_utils.rb +2 -0
  30. data/lib/candle/ner.rb +345 -0
  31. data/lib/candle/reranker.rb +1 -1
  32. data/lib/candle/tensor.rb +2 -0
  33. data/lib/candle/tokenizer.rb +139 -0
  34. data/lib/candle/version.rb +4 -2
  35. data/lib/candle.rb +2 -0
  36. metadata +128 -5
  37. data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/llm/mod.rs
@@ -1,5 +1,4 @@
 use candle_core::{Device, Result as CandleResult};
-use tokenizers::Tokenizer;
 
 pub mod mistral;
 pub mod llama;
@@ -11,6 +10,7 @@ pub mod quantized_gguf;
 pub use generation_config::GenerationConfig;
 pub use text_generation::TextGeneration;
 pub use quantized_gguf::QuantizedGGUF;
+pub use crate::tokenizer::TokenizerWrapper;
 
 /// Trait for text generation models
 pub trait TextGenerator: Send + Sync {
@@ -37,92 +37,4 @@ pub trait TextGenerator: Send + Sync {
 
     /// Clear any cached state (like KV cache)
     fn clear_cache(&mut self);
-}
-
-/// Common structure for managing tokenizer
-#[derive(Debug)]
-pub struct TokenizerWrapper {
-    tokenizer: Tokenizer,
-}
-
-impl TokenizerWrapper {
-    pub fn new(tokenizer: Tokenizer) -> Self {
-        Self { tokenizer }
-    }
-
-    pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
-        let encoding = self.tokenizer
-            .encode(text, add_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
-        Ok(encoding.get_ids().to_vec())
-    }
-
-    pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
-        self.tokenizer
-            .decode(tokens, skip_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
-    }
-
-    pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
-        self.tokenizer
-            .id_to_token(token)
-            .map(|s| s.to_string())
-            .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
-    }
-
-    /// Decode a single token for streaming output
-    pub fn decode_token(&self, token: u32) -> CandleResult<String> {
-        // Decode the single token properly
-        self.decode(&[token], true)
-    }
-
-    /// Decode tokens incrementally for streaming
-    /// This is more efficient than decoding single tokens
-    pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
-        if new_tokens_start >= all_tokens.len() {
-            return Ok(String::new());
-        }
-
-        // Decode all tokens up to this point
-        let full_text = self.decode(all_tokens, true)?;
-
-        // If we're at the start, return everything
-        if new_tokens_start == 0 {
-            return Ok(full_text);
-        }
-
-        // Otherwise, decode up to the previous token and return the difference
-        let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
-
-        // Find the common prefix between the two strings to handle cases where
-        // the tokenizer might produce slightly different text when decoding
-        // different token sequences
-        let common_prefix_len = full_text
-            .char_indices()
-            .zip(previous_text.chars())
-            .take_while(|((_, c1), c2)| c1 == c2)
-            .count();
-
-        // Find the byte position of the character boundary
-        let byte_pos = full_text
-            .char_indices()
-            .nth(common_prefix_len)
-            .map(|(pos, _)| pos)
-            .unwrap_or(full_text.len());
-
-        // Return only the new portion
-        Ok(full_text[byte_pos..].to_string())
-    }
-
-    /// Format tokens with debug information
-    pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
-        let mut result = String::new();
-
-        for &token in tokens {
-            let token_piece = self.token_to_piece(token)?;
-            result.push_str(&format!("[{}:{}]", token, token_piece));
-        }
-
-        Ok(result)
-    }
 }
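
The TokenizerWrapper removed above is not deleted outright: it moves to the new crate::tokenizer module (data/ext/candle/src/tokenizer/mod.rs in this release) and is re-exported for the LLM code via the added pub use line. The piece worth understanding is decode_incremental: to stream text safely, it decodes the full token sequence and the previously decoded prefix, then emits only the text past their common character prefix, so retokenization quirks at the boundary cannot duplicate or drop streamed characters. A minimal Ruby sketch of the same prefix-diff idea (illustrative only, not gem code; decode stands in for any token-ids-to-text function):

# Sketch of the prefix-diff streaming decode (not part of the gem).
# `decode` is any callable that maps an array of token ids to text.
def decode_incremental(decode, all_tokens, new_tokens_start)
  return "" if new_tokens_start >= all_tokens.length
  full_text = decode.call(all_tokens)
  return full_text if new_tokens_start.zero?
  previous_text = decode.call(all_tokens[0...new_tokens_start])
  # Emit only the suffix past the common character prefix.
  common = 0
  common += 1 while common < previous_text.length &&
                    common < full_text.length &&
                    full_text[common] == previous_text[common]
  full_text[common..]
end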
data/ext/candle/src/llm/quantized_gguf.rs
@@ -28,6 +28,11 @@ enum ModelType {
 }
 
 impl QuantizedGGUF {
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
     /// Load a quantized model from a GGUF file
     pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         // Check if user specified an exact GGUF filename
data/ext/candle/src/ner.rs
@@ -0,0 +1,423 @@
+use magnus::{class, function, method, prelude::*, Error, RModule, RArray, RHash};
+use candle_transformers::models::bert::{BertModel, Config};
+use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
+use candle_nn::{VarBuilder, Linear};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use std::collections::HashMap;
+use serde::{Deserialize, Serialize};
+use crate::ruby::{Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NERConfig {
+    pub id2label: HashMap<i64, String>,
+    pub label2id: HashMap<String, i64>,
+}
+
+#[derive(Debug, Clone)]
+pub struct EntitySpan {
+    pub text: String,
+    pub label: String,
+    pub start: usize,
+    pub end: usize,
+    pub token_start: usize,
+    pub token_end: usize,
+    pub confidence: f32,
+}
+
+#[magnus::wrap(class = "Candle::NER", free_immediately, size)]
+pub struct NER {
+    model: BertModel,
+    tokenizer: TokenizerWrapper,
+    classifier: Linear,
+    config: NERConfig,
+    device: CoreDevice,
+    model_id: String,
+}
+
+impl NER {
+    pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu).as_device()?;
+
+        // Load model in a separate thread to avoid blocking
+        let device_clone = device.clone();
+        let model_id_clone = model_id.clone();
+
+        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+            let api = Api::new()?;
+            let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
+
+            // Download model files
+            let config_filename = repo.get("config.json")?;
+
+            // Handle tokenizer loading with optional tokenizer_id
+            let tokenizer = if let Some(tok_id) = tokenizer_id {
+                // Use the specified tokenizer
+                let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
+                let tokenizer_filename = tok_repo.get("tokenizer.json")?;
+                let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
+                TokenizerLoader::with_padding(tokenizer, None)
+            } else {
+                // Try to load tokenizer from model repo
+                let tokenizer_filename = repo.get("tokenizer.json")?;
+                let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
+                TokenizerLoader::with_padding(tokenizer, None)
+            };
+            let weights_filename = repo.get("pytorch_model.safetensors")
+                .or_else(|_| repo.get("model.safetensors"))?;
+
+            // Load BERT config
+            let config_str = std::fs::read_to_string(&config_filename)?;
+            let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
+            let bert_config: Config = serde_json::from_value(config_json.clone())?;
+
+            // Extract NER label configuration
+            let id2label = config_json["id2label"]
+                .as_object()
+                .ok_or("Missing id2label in config")?
+                .iter()
+                .map(|(k, v)| {
+                    let id = k.parse::<i64>().unwrap_or(0);
+                    let label = v.as_str().unwrap_or("O").to_string();
+                    (id, label)
+                })
+                .collect::<HashMap<_, _>>();
+
+            let label2id = id2label.iter()
+                .map(|(id, label)| (label.clone(), *id))
+                .collect::<HashMap<_, _>>();
+
+            let num_labels = id2label.len();
+            let ner_config = NERConfig { id2label, label2id };
+
+            // Load model weights
+            let vb = unsafe {
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+            };
+
+            // Load BERT model
+            let model = BertModel::load(vb.pp("bert"), &bert_config)?;
+
+            // Load classification head for token classification
+            let classifier = candle_nn::linear(
+                bert_config.hidden_size,
+                num_labels,
+                vb.pp("classifier")
+            )?;
+
+            Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
+        });
+
+        match handle.join() {
+            Ok(Ok((model, tokenizer, classifier, config))) => {
+                Ok(Self {
+                    model,
+                    tokenizer,
+                    classifier,
+                    config,
+                    device,
+                    model_id,
+                })
+            }
+            Ok(Err(e)) => Err(Error::new(
+                magnus::exception::runtime_error(),
+                format!("Failed to load NER model: {}", e)
+            )),
+            Err(_) => Err(Error::new(
+                magnus::exception::runtime_error(),
+                "Thread panicked while loading NER model"
+            )),
+        }
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Tokenize the text
+        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+
+        let token_ids = encoding.get_ids();
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
+        // Convert to tensors
+        let input_ids = Tensor::new(token_ids, &self.device)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .unsqueeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; // Add batch dimension
+
+        let attention_mask = Tensor::ones_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let token_type_ids = Tensor::zeros_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Forward pass through BERT
+        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Apply classifier to get logits for each token
+        let logits = self.classifier.forward(&output)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Apply softmax to get probabilities
+        let probs = candle_nn::ops::softmax(&logits, 2)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Get predictions and confidence scores
+        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .to_vec2()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Extract entities with BIO decoding
+        let entities = self.decode_entities(
+            &text,
+            &tokens.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
+            offsets,
+            &probs_vec,
+            threshold
+        )?;
+
+        // Convert to Ruby array
+        let result = RArray::new();
+        for entity in entities {
+            let hash = RHash::new();
+            hash.aset("text", entity.text)?;
+            hash.aset("label", entity.label)?;
+            hash.aset("start", entity.start)?;
+            hash.aset("end", entity.end)?;
+            hash.aset("confidence", entity.confidence)?;
+            hash.aset("token_start", entity.token_start)?;
+            hash.aset("token_end", entity.token_end)?;
+            result.push(hash)?;
+        }
+
+        Ok(result)
+    }
+
+    /// Get token-level predictions with labels and confidence scores
+    pub fn predict_tokens(&self, text: String) -> Result<RArray> {
+        // Tokenize the text
+        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+
+        let token_ids = encoding.get_ids();
+        let tokens = encoding.get_tokens();
+
+        // Convert to tensors
+        let input_ids = Tensor::new(token_ids, &self.device)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .unsqueeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        let attention_mask = Tensor::ones_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let token_type_ids = Tensor::zeros_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Forward pass
+        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let logits = self.classifier.forward(&output)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let probs = candle_nn::ops::softmax(&logits, 2)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Get predictions
+        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .to_vec2()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Build result array
+        let result = RArray::new();
+        for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
+            // Find best label
+            let (label_id, confidence) = probs.iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, conf)| (idx as i64, *conf))
+                .unwrap_or((0, 0.0));
+
+            let label = self.config.id2label.get(&label_id)
+                .unwrap_or(&"O".to_string())
+                .clone();
+
+            let token_info = RHash::new();
+            token_info.aset("token", token.to_string())?;
+            token_info.aset("label", label)?;
+            token_info.aset("confidence", confidence)?;
+            token_info.aset("index", i)?;
+
+            // Add probability distribution if needed
+            let probs_hash = RHash::new();
+            for (id, label) in &self.config.id2label {
+                if let Some(prob) = probs.get(*id as usize) {
+                    probs_hash.aset(label.as_str(), *prob)?;
+                }
+            }
+            token_info.aset("probabilities", probs_hash)?;
+
+            result.push(token_info)?;
+        }
+
+        Ok(result)
+    }
+
+    /// Decode BIO-tagged sequences into entity spans
+    fn decode_entities(
+        &self,
+        text: &str,
+        tokens: &[&str],
+        offsets: &[(usize, usize)],
+        probs: &[Vec<f32>],
+        threshold: f32,
+    ) -> Result<Vec<EntitySpan>> {
+        let mut entities = Vec::new();
+        let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
+
+        for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
+            // Skip special tokens
+            if token.starts_with("[") && token.ends_with("]") {
+                continue;
+            }
+
+            // Get predicted label
+            let (label_id, confidence) = probs_vec.iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, conf)| (idx as i64, *conf))
+                .unwrap_or((0, 0.0));
+
+            let label = self.config.id2label.get(&label_id)
+                .unwrap_or(&"O".to_string())
+                .clone();
+
+            // BIO decoding logic
+            if label == "O" || confidence < threshold {
+                // End current entity if exists
+                if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
+                    if let (Some(start_offset), Some(end_offset)) =
+                        (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                        let entity_text = text[start_offset.0..end_offset.1].to_string();
+                        let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                        entities.push(EntitySpan {
+                            text: entity_text,
+                            label: entity_type,
+                            start: start_offset.0,
+                            end: end_offset.1,
+                            token_start: start_idx,
+                            token_end: end_idx,
+                            confidence: avg_confidence,
+                        });
+                    }
+                }
+            } else if label.starts_with("B-") {
+                // Begin new entity
+                if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
+                    if let (Some(start_offset), Some(end_offset)) =
+                        (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                        let entity_text = text[start_offset.0..end_offset.1].to_string();
+                        let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                        entities.push(EntitySpan {
+                            text: entity_text,
+                            label: entity_type,
+                            start: start_offset.0,
+                            end: end_offset.1,
+                            token_start: start_idx,
+                            token_end: end_idx,
+                            confidence: avg_confidence,
+                        });
+                    }
+                }
+
+                let entity_type = label[2..].to_string();
+                current_entity = Some((entity_type, i, i + 1, vec![confidence]));
+            } else if label.starts_with("I-") {
+                // Continue entity
+                if let Some((ref mut entity_type, _, ref mut end_idx, ref mut confidences)) = current_entity {
+                    let new_type = label[2..].to_string();
+                    if *entity_type == new_type {
+                        *end_idx = i + 1;
+                        confidences.push(confidence);
+                    } else {
+                        // Type mismatch, start new entity
+                        current_entity = Some((new_type, i, i + 1, vec![confidence]));
+                    }
+                } else {
+                    // I- tag without B- tag, treat as beginning
+                    let entity_type = label[2..].to_string();
+                    current_entity = Some((entity_type, i, i + 1, vec![confidence]));
+                }
+            }
+        }
+
+        // Handle final entity
+        if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
+            if let (Some(start_offset), Some(end_offset)) =
+                (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                let entity_text = text[start_offset.0..end_offset.1].to_string();
+                let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                entities.push(EntitySpan {
+                    text: entity_text,
+                    label: entity_type,
+                    start: start_offset.0,
+                    end: end_offset.1,
+                    token_start: start_idx,
+                    token_end: end_idx,
+                    confidence: avg_confidence,
+                });
+            }
+        }
+
+        Ok(entities)
+    }
+
+    /// Get the label configuration
+    pub fn labels(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        let id2label = RHash::new();
+        for (id, label) in &self.config.id2label {
+            id2label.aset(*id, label.as_str())?;
+        }
+
+        let label2id = RHash::new();
+        for (label, id) in &self.config.label2id {
+            label2id.aset(label.as_str(), *id)?;
+        }
+
+        hash.aset("id2label", id2label)?;
+        hash.aset("label2id", label2id)?;
+        hash.aset("num_labels", self.config.id2label.len())?;
+
+        Ok(hash)
+    }
+
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
+    }
+
+    /// Get model info
+    pub fn model_info(&self) -> String {
+        format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
+    }
+}
+
+pub fn init(rb_candle: RModule) -> Result<()> {
+    let ner_class = rb_candle.define_class("NER", class::object())?;
+    ner_class.define_singleton_method("new", function!(NER::new, 3))?;
+    ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
+    ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
+    ner_class.define_method("labels", method!(NER::labels, 0))?;
+    ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
+    ner_class.define_method("model_info", method!(NER::model_info, 0))?;
+
+    Ok(())
+}
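
Together with init at the bottom of the file, this registers Candle::NER for Ruby, with entity extraction, per-token prediction, and label inspection. A minimal usage sketch against the raw bindings shown above (the model id is only an example, and the wrapper added in data/lib/candle/ner.rb is not shown here and may expose a friendlier signature):

require "candle"

# Any BERT-style token-classification checkpoint whose config.json carries
# id2label should load; "dslim/bert-base-NER" is an example, not a default.
# Positional args match the binding: model_id, device (nil => CPU),
# tokenizer_id (nil => use the model repo's tokenizer.json).
ner = Candle::NER.new("dslim/bert-base-NER", nil, nil)

# Hashes with "text", "label", "start", "end", "confidence",
# "token_start", and "token_end" keys; 0.9 is the confidence threshold.
ner.extract_entities("Sarah lives in Boston.", 0.9).each do |e|
  puts "#{e["label"]}: #{e["text"]} (#{e["confidence"].round(2)})"
end

# Per-token labels plus the full per-label probability distribution.
tokens = ner.predict_tokens("Sarah lives in Boston.")

# {"id2label" => ..., "label2id" => ..., "num_labels" => n}
p ner.labels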
data/ext/candle/src/reranker.rs
@@ -3,28 +3,29 @@ use candle_transformers::models::bert::{BertModel, Config};
 use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::{PaddingParams, Tokenizer, EncodeInput};
+use tokenizers::{EncodeInput, Tokenizer};
 use std::thread;
-use crate::ruby::{Device as RbDevice, Result as RbResult};
+use crate::ruby::{Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
 #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
 pub struct Reranker {
     model: BertModel,
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerWrapper,
     pooler: Linear,
     classifier: Linear,
     device: CoreDevice,
 }
 
 impl Reranker {
-    pub fn new(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
-        let device = device.unwrap_or(RbDevice::Cpu).as_device()?;
+    pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu).as_device()?;
         Self::new_with_core_device(model_id, device)
     }
 
-    fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
+    fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
         let device_clone = device.clone();
-        let handle = thread::spawn(move || -> Result<(BertModel, Tokenizer, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
            let api = Api::new()?;
            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -38,12 +39,8 @@ impl Reranker {
            let config: Config = serde_json::from_str(&config)?;
 
            // Setup tokenizer with padding
-           let mut tokenizer = Tokenizer::from_file(tokenizer_filename)?;
-           let pp = PaddingParams {
-               strategy: tokenizers::PaddingStrategy::BatchLongest,
-               ..Default::default()
-           };
-           tokenizer.with_padding(Some(pp));
+           let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
+           let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
 
            // Load model weights
            let vb = unsafe {
@@ -59,7 +56,7 @@ impl Reranker {
            // Load classifier layer for cross-encoder (single output score)
            let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
-           Ok((model, tokenizer, pooler, classifier))
+           Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
        });
 
        match handle.join() {
@@ -71,12 +68,12 @@ impl Reranker {
        }
    }
 
-    pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
+    pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
        // Create query-document pair for cross-encoder
        let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
 
-       // Tokenize
-       let encoding = self.tokenizer.encode(query_doc_pair, true)
+       // Tokenize using the inner tokenizer for detailed info
+       let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
        // Get token information
@@ -95,7 +92,7 @@ impl Reranker {
        Ok(result)
    }
 
-    pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
+    pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
        let documents: Vec<String> = documents.to_vec()?;
 
        // Create query-document pairs for cross-encoder
@@ -104,8 +101,8 @@ impl Reranker {
            .map(|d| (query.clone(), d.clone()).into())
            .collect();
 
-       // Tokenize batch
-       let encodings = self.tokenizer.encode_batch(query_and_docs, true)
+       // Tokenize batch using inner tokenizer for access to token type IDs
+       let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
        // Convert to tensors
@@ -256,12 +253,18 @@ impl Reranker {
        }
        Ok(result_array)
    }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
+        Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
+    }
 }
 
-pub fn init(rb_candle: RModule) -> Result<(), Error> {
+pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
    let c_reranker = rb_candle.define_class("Reranker", class::object())?;
    c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
    c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
    c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
+   c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
    Ok(())
 }
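
The user-visible addition here is the tokenizer reader, matching the accessors gained by the LLM models and the new NER class. A short sketch (the model id is an example; _create is the raw binding registered above, which data/lib/candle/reranker.rb normally wraps):

require "candle"

# Cross-encoder checkpoint; "cross-encoder/ms-marco-MiniLM-L-12-v2" is an example.
reranker = Candle::Reranker._create("cross-encoder/ms-marco-MiniLM-L-12-v2", nil)

# New in this release: the reranker hands back its Candle::Tokenizer, so
# callers can inspect exactly how query/document pairs are split.
tokenizer = reranker.tokenizer

# debug_tokenization exposes the same information as a plain hash.
p reranker.debug_tokenization("how do I bake bread?", "Preheat the oven to 230C.")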