red-candle 1.3.1 → 1.4.0

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry the package was published to. It is provided for informational purposes only.
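The changes in this release follow one pattern: the extension moves from magnus's module-level helpers (class::object(), magnus::exception::runtime_error(), RArray::new(), RHash::new(), magnus::Symbol::new()) to the handle-based API, obtaining a Ruby handle with Ruby::get() and creating classes, exceptions, arrays, hashes, and symbols through it. A minimal sketch of that pattern is shown here; it reuses the magnus calls that appear in the diff below, but the build_entity helper itself is hypothetical and not part of red-candle:

    use magnus::{Error, RHash, Ruby};

    // Hypothetical helper illustrating the handle-based magnus API adopted in 1.4.0:
    // fetch the Ruby handle once, then build Ruby objects and exception classes
    // through it instead of the old free functions.
    fn build_entity(text: &str, label: &str, confidence: f32) -> Result<RHash, Error> {
        // Only valid on a thread where the Ruby VM is initialised, which is the
        // case for methods invoked from Ruby (as in the diff below).
        let ruby = Ruby::get().unwrap();
        let hash = ruby.hash_new();
        hash.aset(ruby.to_symbol("text"), text)?;
        hash.aset(ruby.to_symbol("label"), label)?;
        hash.aset(ruby.to_symbol("confidence"), confidence)?;
        Ok(hash)
    }

The entity hashes built in extract_entities below follow the same shape, with start/end offsets and token indices added.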
@@ -1,4 +1,4 @@
-use magnus::{class, function, method, prelude::*, Error, RModule, RArray, RHash};
+use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
 use candle_transformers::models::bert::{BertModel, Config};
 use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
 use candle_nn::{VarBuilder, Linear};
@@ -38,14 +38,14 @@ pub struct NER {
 impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::best()).as_device()?;
-
+
         let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
-
+
             // Download model files
             let config_filename = repo.get("config.json")?;
-
+
             // Handle tokenizer loading with optional tokenizer
             let tokenizer_wrapper = if let Some(tok_id) = tokenizer {
                 // Use the specified tokenizer
@@ -61,12 +61,12 @@ impl NER {
             };
             let weights_filename = repo.get("pytorch_model.safetensors")
                 .or_else(|_| repo.get("model.safetensors"))?;
-
+
             // Load BERT config
             let config_str = std::fs::read_to_string(&config_filename)?;
             let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
             let bert_config: Config = serde_json::from_value(config_json.clone())?;
-
+
             // Extract NER label configuration
             let id2label = config_json["id2label"]
                 .as_object()
@@ -78,32 +78,32 @@ impl NER {
                     (id, label)
                 })
                 .collect::<HashMap<_, _>>();
-
+
             let label2id = id2label.iter()
                 .map(|(id, label)| (label.clone(), *id))
                 .collect::<HashMap<_, _>>();
-
+
             let num_labels = id2label.len();
             let ner_config = NERConfig { id2label, label2id };
-
+
             // Load model weights
             let vb = unsafe {
                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
-
+
             // Load BERT model
             let model = BertModel::load(vb.pp("bert"), &bert_config)?;
-
+
             // Load classification head for token classification
             let classifier = candle_nn::linear(
                 bert_config.hidden_size,
                 num_labels,
                 vb.pp("classifier")
             )?;
-
+
             Ok((model, tokenizer_wrapper, classifier, ner_config))
         })();
-
+
         match result {
             Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
@@ -115,63 +115,70 @@ impl NER {
                     model_id,
                 })
             }
-            Err(e) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                format!("Failed to load NER model: {}", e)
-            )),
+            Err(e) => {
+                let ruby = Ruby::get().unwrap();
+                Err(Error::new(
+                    ruby.exception_runtime_error(),
+                    format!("Failed to load NER model: {}", e)
+                ))
+            },
         }
     }
-
+
     /// Common tokenization and prediction logic
     fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
+        let ruby = Ruby::get().unwrap();
+        let runtime_error = ruby.exception_runtime_error();
+
         // Tokenize the text
         let encoding = self.tokenizer.inner().encode(text, true)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
-
+            .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
+
         let token_ids = encoding.get_ids();
-
+
         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?
             .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; // Add batch dimension
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?; // Add batch dimension
+
         let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
         let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
+
         // Forward pass through BERT
         let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
+
         // Apply classifier to get logits for each token
         let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
+
         // Apply softmax to get probabilities
         let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
+
         // Get predictions and confidence scores
         let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?
             .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
+            .map_err(|e| Error::new(runtime_error, e.to_string()))?;
+
         Ok((encoding, probs_vec))
     }
-
+
     /// Extract entities from text with confidence scores
     pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let ruby = Ruby::get().unwrap();
        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+
        // Use common tokenization and prediction logic
        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
-
+
        let tokens = encoding.get_tokens();
        let offsets = encoding.get_offsets();
-
+
        // Extract entities with BIO decoding
        let entities = self.decode_entities(
            &text,
@@ -180,33 +187,34 @@ impl NER {
             &probs_vec,
             threshold
         )?;
-
+
         // Convert to Ruby array
-        let result = RArray::new();
+        let result = ruby.ary_new();
         for entity in entities {
-            let hash = RHash::new();
-            hash.aset(magnus::Symbol::new("text"), entity.text)?;
-            hash.aset(magnus::Symbol::new("label"), entity.label)?;
-            hash.aset(magnus::Symbol::new("start"), entity.start)?;
-            hash.aset(magnus::Symbol::new("end"), entity.end)?;
-            hash.aset(magnus::Symbol::new("confidence"), entity.confidence)?;
-            hash.aset(magnus::Symbol::new("token_start"), entity.token_start)?;
-            hash.aset(magnus::Symbol::new("token_end"), entity.token_end)?;
+            let hash = ruby.hash_new();
+            hash.aset(ruby.to_symbol("text"), entity.text)?;
+            hash.aset(ruby.to_symbol("label"), entity.label)?;
+            hash.aset(ruby.to_symbol("start"), entity.start)?;
+            hash.aset(ruby.to_symbol("end"), entity.end)?;
+            hash.aset(ruby.to_symbol("confidence"), entity.confidence)?;
+            hash.aset(ruby.to_symbol("token_start"), entity.token_start)?;
+            hash.aset(ruby.to_symbol("token_end"), entity.token_end)?;
             result.push(hash)?;
         }
-
+
         Ok(result)
     }
-
+
     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
+        let ruby = Ruby::get().unwrap();
         // Use common tokenization and prediction logic
         let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
-
+
         let tokens = encoding.get_tokens();
-
+
         // Build result array
-        let result = RArray::new();
+        let result = ruby.ary_new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
             // Find best label
             let (label_id, confidence) = probs.iter()
@@ -214,32 +222,32 @@ impl NER {
                 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                 .map(|(idx, conf)| (idx as i64, *conf))
                 .unwrap_or((0, 0.0));
-
+
             let label = self.config.id2label.get(&label_id)
                 .unwrap_or(&"O".to_string())
                 .clone();
-
-            let token_info = RHash::new();
+
+            let token_info = ruby.hash_new();
             token_info.aset("token", token.to_string())?;
             token_info.aset("label", label)?;
             token_info.aset("confidence", confidence)?;
             token_info.aset("index", i)?;
-
+
             // Add probability distribution if needed
-            let probs_hash = RHash::new();
+            let probs_hash = ruby.hash_new();
             for (id, label) in &self.config.id2label {
                 if let Some(prob) = probs.get(*id as usize) {
                     probs_hash.aset(label.as_str(), *prob)?;
                 }
             }
             token_info.aset("probabilities", probs_hash)?;
-
+
             result.push(token_info)?;
         }
-
+
         Ok(result)
     }
-
+
     /// Decode BIO-tagged sequences into entity spans
     fn decode_entities(
         &self,
@@ -251,33 +259,33 @@ impl NER {
     ) -> Result<Vec<EntitySpan>> {
         let mut entities = Vec::new();
         let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
-
+
         for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
             // Skip special tokens
             if token.starts_with("[") && token.ends_with("]") {
                 continue;
             }
-
+
             // Get predicted label
             let (label_id, confidence) = probs_vec.iter()
                 .enumerate()
                 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                 .map(|(idx, conf)| (idx as i64, *conf))
                 .unwrap_or((0, 0.0));
-
+
             let label = self.config.id2label.get(&label_id)
                 .unwrap_or(&"O".to_string())
                 .clone();
-
+
             // BIO decoding logic
             if label == "O" || confidence < threshold {
                 // End current entity if exists
                 if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
-                    if let (Some(start_offset), Some(end_offset)) =
+                    if let (Some(start_offset), Some(end_offset)) =
                         (offsets.get(start_idx), offsets.get(end_idx - 1)) {
                         let entity_text = text[start_offset.0..end_offset.1].to_string();
                         let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
-
+
                         entities.push(EntitySpan {
                             text: entity_text,
                             label: entity_type,
@@ -292,11 +300,11 @@ impl NER {
             } else if label.starts_with("B-") {
                 // Begin new entity
                 if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
-                    if let (Some(start_offset), Some(end_offset)) =
+                    if let (Some(start_offset), Some(end_offset)) =
                         (offsets.get(start_idx), offsets.get(end_idx - 1)) {
                         let entity_text = text[start_offset.0..end_offset.1].to_string();
                         let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
-
+
                         entities.push(EntitySpan {
                             text: entity_text,
                             label: entity_type,
@@ -308,7 +316,7 @@ impl NER {
                         });
                     }
                 }
-
+
                 let entity_type = label[2..].to_string();
                 current_entity = Some((entity_type, i, i + 1, vec![confidence]));
             } else if label.starts_with("I-") {
@@ -329,14 +337,14 @@ impl NER {
                 }
             }
         }
-
+
         // Handle final entity
         if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
-            if let (Some(start_offset), Some(end_offset)) =
+            if let (Some(start_offset), Some(end_offset)) =
                 (offsets.get(start_idx), offsets.get(end_idx - 1)) {
                 let entity_text = text[start_offset.0..end_offset.1].to_string();
                 let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
-
+
                 entities.push(EntitySpan {
                     text: entity_text,
                     label: entity_type,
@@ -348,58 +356,60 @@ impl NER {
                 });
             }
         }
-
+
         Ok(entities)
     }
-
+
     /// Get the label configuration
     pub fn labels(&self) -> Result<RHash> {
-        let hash = RHash::new();
-
-        let id2label = RHash::new();
+        let ruby = Ruby::get().unwrap();
+        let hash = ruby.hash_new();
+
+        let id2label = ruby.hash_new();
         for (id, label) in &self.config.id2label {
             id2label.aset(*id, label.as_str())?;
         }
-
-        let label2id = RHash::new();
+
+        let label2id = ruby.hash_new();
         for (label, id) in &self.config.label2id {
             label2id.aset(label.as_str(), *id)?;
         }
-
+
         hash.aset("id2label", id2label)?;
         hash.aset("label2id", label2id)?;
         hash.aset("num_labels", self.config.id2label.len())?;
-
+
         Ok(hash)
     }
-
+
     /// Get the tokenizer
     pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
         Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
     }
-
+
     /// Get model info
     pub fn model_info(&self) -> String {
         format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
     }
-
+
     /// Get the model_id
    pub fn model_id(&self) -> String {
        self.model_id.clone()
    }
-
+
    /// Get the device
    pub fn device(&self) -> Device {
        Device::from_device(&self.device)
    }
-
+
    /// Get all options as a hash
    pub fn options(&self) -> Result<RHash> {
-        let hash = RHash::new();
+        let ruby = Ruby::get().unwrap();
+        let hash = ruby.hash_new();
        hash.aset("model_id", self.model_id.clone())?;
        hash.aset("device", self.device().__str__())?;
        hash.aset("num_labels", self.config.id2label.len())?;
-
+
        // Add entity types as a list
        let entity_types: Vec<String> = self.config.label2id.keys()
            .filter(|l| *l != "O")
@@ -408,13 +418,14 @@ impl NER {
             .into_iter()
             .collect();
         hash.aset("entity_types", entity_types)?;
-
+
         Ok(hash)
     }
 }

 pub fn init(rb_candle: RModule) -> Result<()> {
-    let ner_class = rb_candle.define_class("NER", class::object())?;
+    let ruby = Ruby::get().unwrap();
+    let ner_class = rb_candle.define_class("NER", ruby.class_object())?;
     ner_class.define_singleton_method("new", function!(NER::new, 3))?;
     ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
     ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
@@ -424,6 +435,6 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     ner_class.define_method("model_id", method!(NER::model_id, 0))?;
     ner_class.define_method("device", method!(NER::device, 0))?;
     ner_class.define_method("options", method!(NER::options, 0))?;
-
+
     Ok(())
-}
+}
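For context, the registration done by the updated init() above follows the standard magnus shape: obtain the Ruby handle, define the class under the module passed in, then bind singleton and instance methods. A standalone sketch under that assumption follows; the Demo::Greeter class and its methods are illustrative only and not part of red-candle:

    use magnus::{function, method, prelude::*, Error, RModule, Ruby};

    // Illustrative wrapped struct; red-candle's NER class is registered the same way.
    // Assumes `demo` below is a Ruby module named Demo, so the class path matches.
    #[magnus::wrap(class = "Demo::Greeter")]
    struct Greeter;

    impl Greeter {
        fn new() -> Self {
            Greeter
        }

        fn greet(&self, name: String) -> String {
            format!("hello, {}", name)
        }
    }

    // Same registration shape as the updated init(): handle first, then
    // define_class with ruby.class_object() as the superclass.
    fn init(demo: RModule) -> Result<(), Error> {
        let ruby = Ruby::get().unwrap();
        let class = demo.define_class("Greeter", ruby.class_object())?;
        class.define_singleton_method("new", function!(Greeter::new, 0))?;
        class.define_method("greet", method!(Greeter::greet, 1))?;
        Ok(())
    }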