red-candle 1.1.2 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
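At a glance, 1.2.0 makes the following changes to the Rust extension (the hunks below span several source files; per-file headers are not preserved in this extract):

- Renames the EmbeddingModel accessors and fields: embedding_model_type → model_type, model_path → model_id, tokenizer_path → tokenizer_id (the constructor parameter is now tokenizer).
- Changes the default device in every constructor from Device::Cpu to the new Device::best().
- Registers the reranker and ner modules in the crate root.
- Exposes new introspection methods to Ruby: options on EmbeddingModel, GenerationConfig, LLM, NER, and Reranker, plus model_id and device where applicable.
- Stores model_id on Reranker, wraps the NER tokenizer in TokenizerWrapper at load time, and switches NER entity hashes from string to symbol keys.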
@@ -13,7 +13,7 @@ use candle_transformers::models::{
     jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
     distilbert::{DistilBertModel, Config as DistilBertConfig}
 };
-use magnus::{class, function, method, prelude::*, Error, RModule};
+use magnus::{class, function, method, prelude::*, Error, RModule, RHash};
 use std::path::Path;
 use serde_json;
 
@@ -53,7 +53,7 @@ pub enum EmbeddingModelVariant {
 }
 
 impl EmbeddingModelVariant {
-    pub fn embedding_model_type(&self) -> EmbeddingModelType {
+    pub fn model_type(&self) -> EmbeddingModelType {
         match self {
             EmbeddingModelVariant::JinaBert(_) => EmbeddingModelType::JinaBert,
             EmbeddingModelVariant::StandardBert(_) => EmbeddingModelType::StandardBert,
@@ -66,31 +66,31 @@ impl EmbeddingModelVariant {
 
 pub struct EmbeddingModelInner {
     device: CoreDevice,
-    tokenizer_path: Option<String>,
-    model_path: Option<String>,
-    embedding_model_type: Option<EmbeddingModelType>,
+    tokenizer_id: Option<String>,
+    model_id: Option<String>,
+    model_type: Option<EmbeddingModelType>,
     model: Option<EmbeddingModelVariant>,
     tokenizer: Option<TokenizerWrapper>,
     embedding_size: Option<usize>,
 }
 
 impl EmbeddingModel {
-    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
-        let embedding_model_type = embedding_model_type
+    pub fn new(model_id: Option<String>, tokenizer: Option<String>, device: Option<Device>, model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
+        let device = device.unwrap_or(Device::best()).as_device()?;
+        let model_type = model_type
             .and_then(|mt| EmbeddingModelType::from_string(&mt))
             .unwrap_or(EmbeddingModelType::JinaBert);
         Ok(EmbeddingModel(EmbeddingModelInner {
             device: device.clone(),
-            model_path: model_path.clone(),
-            tokenizer_path: tokenizer_path.clone(),
-            embedding_model_type: Some(embedding_model_type),
-            model: match model_path {
-                Some(mp) => Some(Self::build_embedding_model(Path::new(&mp), device, embedding_model_type, embedding_size)?),
+            model_id: model_id.clone(),
+            tokenizer_id: tokenizer.clone(),
+            model_type: Some(model_type),
+            model: match model_id.as_ref() {
+                Some(id) => Some(Self::build_embedding_model(id, device, model_type, embedding_size)?),
                 None => None
             },
-            tokenizer: match tokenizer_path {
-                Some(tp) => Some(Self::build_tokenizer(tp)?),
+            tokenizer: match tokenizer {
+                Some(tid) => Some(Self::build_tokenizer(tid)?),
                 None => None
             },
             embedding_size,
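Every constructor in this release swaps Device::Cpu for Device::best() as the default. Device::best() itself is defined elsewhere in the crate and does not appear in this diff; a hypothetical sketch of the selection it presumably performs, assuming the Cpu/Cuda/Metal variants visible in the LLM::options hunk further down:

    // Hypothetical sketch only -- the real Device::best() is not shown in this diff.
    // Assumed behavior: prefer an accelerator when the crate was built with one.
    impl Device {
        pub fn best() -> Self {
            if cfg!(feature = "cuda") {
                Device::Cuda
            } else if cfg!(feature = "metal") {
                Device::Metal
            } else {
                Device::Cpu
            }
        }
    }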
@@ -170,11 +170,11 @@ impl EmbeddingModel {
         }
     }
 
-    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
+    fn build_embedding_model(model_id: &str, device: CoreDevice, model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let api = Api::new().map_err(wrap_hf_err)?;
-        let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
-        match embedding_model_type {
+        let repo = Repo::new(model_id.to_string(), RepoType::Model);
+        match model_type {
             EmbeddingModelType::JinaBert => {
                 let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
                 if !std::path::Path::new(&model_path).exists() {
@@ -257,12 +257,12 @@ impl EmbeddingModel {
         }
     }
 
-    fn build_tokenizer(tokenizer_path: String) -> Result<TokenizerWrapper> {
+    fn build_tokenizer(tokenizer_id: String) -> Result<TokenizerWrapper> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let tokenizer_path = Api::new()
             .map_err(wrap_hf_err)?
             .repo(Repo::new(
-                tokenizer_path,
+                tokenizer_id,
                 RepoType::Model,
             ))
             .get("tokenizer.json")
@@ -365,19 +365,19 @@ impl EmbeddingModel {
         v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
     }
 
-    pub fn embedding_model_type(&self) -> String {
-        match self.0.embedding_model_type {
-            Some(model_type) => format!("{:?}", model_type),
+    pub fn model_type(&self) -> String {
+        match self.0.model_type {
+            Some(mt) => format!("{:?}", mt),
             None => "nil".to_string(),
         }
     }
 
     pub fn __repr__(&self) -> String {
         format!(
-            "#<Candle::EmbeddingModel embedding_model_type: {}, model_path: {}, tokenizer_path: {}, embedding_size: {}>",
-            self.embedding_model_type(),
-            self.0.model_path.as_deref().unwrap_or("nil"),
-            self.0.tokenizer_path.as_deref().unwrap_or("nil"),
+            "#<Candle::EmbeddingModel model_type: {}, model_id: {}, tokenizer: {}, embedding_size: {}>",
+            self.model_type(),
+            self.0.model_id.as_deref().unwrap_or("nil"),
+            self.0.tokenizer_id.as_deref().unwrap_or("nil"),
             self.0.embedding_size.map(|x| x.to_string()).unwrap_or("nil".to_string())
         )
     }
@@ -393,6 +393,49 @@ impl EmbeddingModel {
             None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
         }
     }
+
+    /// Get the model_id
+    pub fn model_id(&self) -> Result<String> {
+        match &self.0.model_id {
+            Some(id) => Ok(id.clone()),
+            None => Ok("unknown".to_string())
+        }
+    }
+
+    /// Get the device
+    pub fn device(&self) -> Device {
+        Device::from_device(&self.0.device)
+    }
+
+    /// Get all options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Add model_id
+        if let Some(model_id) = &self.0.model_id {
+            hash.aset("model_id", model_id.clone())?;
+        }
+
+        // Add tokenizer
+        if let Some(tokenizer_id) = &self.0.tokenizer_id {
+            hash.aset("tokenizer", tokenizer_id.clone())?;
+        }
+
+        // Add device
+        hash.aset("device", self.device().__str__())?;
+
+        // Add model_type
+        if let Some(model_type) = &self.0.model_type {
+            hash.aset("model_type", format!("{:?}", model_type))?;
+        }
+
+        // Add embedding_size
+        if let Some(size) = self.0.embedding_size {
+            hash.aset("embedding_size", size)?;
+        }
+
+        Ok(hash)
+    }
 }
 
 pub fn init(rb_candle: RModule) -> Result<()> {
@@ -404,9 +447,12 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_embedding_model.define_method("pool_embedding", method!(EmbeddingModel::pool_embedding, 1))?;
     rb_embedding_model.define_method("pool_and_normalize_embedding", method!(EmbeddingModel::pool_and_normalize_embedding, 1))?;
     rb_embedding_model.define_method("pool_cls_embedding", method!(EmbeddingModel::pool_cls_embedding, 1))?;
-    rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
+    rb_embedding_model.define_method("model_type", method!(EmbeddingModel::model_type, 0))?;
     rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
     rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
     rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
+    rb_embedding_model.define_method("model_id", method!(EmbeddingModel::model_id, 0))?;
+    rb_embedding_model.define_method("device", method!(EmbeddingModel::device, 0))?;
+    rb_embedding_model.define_method("options", method!(EmbeddingModel::options, 0))?;
     Ok(())
 }
@@ -201,6 +201,37 @@ impl GenerationConfig {
             index: Arc::clone(c),
         })
     }
+
+    /// Get all options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        hash.aset("max_length", self.inner.max_length)?;
+        hash.aset("temperature", self.inner.temperature)?;
+
+        if let Some(top_p) = self.inner.top_p {
+            hash.aset("top_p", top_p)?;
+        }
+
+        if let Some(top_k) = self.inner.top_k {
+            hash.aset("top_k", top_k)?;
+        }
+
+        hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
+        hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
+        hash.aset("seed", self.inner.seed)?;
+        hash.aset("stop_sequences", self.inner.stop_sequences.clone())?;
+        hash.aset("include_prompt", self.inner.include_prompt)?;
+        hash.aset("debug_tokens", self.inner.debug_tokens)?;
+        hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
+        hash.aset("stop_on_match", self.inner.stop_on_match)?;
+
+        if self.inner.constraint.is_some() {
+            hash.aset("has_constraint", true)?;
+        }
+
+        Ok(hash)
+    }
 }
 
 #[derive(Clone)]
@@ -214,7 +245,7 @@ pub struct LLM {
 impl LLM {
     /// Create a new LLM from a pretrained model
     pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu);
+        let device = device.unwrap_or(Device::best());
         let candle_device = device.as_device()?;
 
         // For now, we'll use tokio runtime directly
@@ -448,6 +479,67 @@ impl LLM {
         model_ref.apply_chat_template(&json_messages)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to apply chat template: {}", e)))
     }
+
+    /// Get the model ID
+    pub fn model_id(&self) -> String {
+        self.model_id.clone()
+    }
+
+    /// Get model options
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Basic metadata
+        hash.aset("model_id", self.model_id.clone())?;
+        let device_str = match self.device {
+            Device::Cpu => "cpu",
+            Device::Cuda => "cuda",
+            Device::Metal => "metal",
+        };
+        hash.aset("device", device_str)?;
+
+        // Parse model_id to extract GGUF file if present
+        if let Some(at_pos) = self.model_id.find('@') {
+            let (base_model, gguf_part) = self.model_id.split_at(at_pos);
+            let gguf_part = &gguf_part[1..]; // Skip the @ character
+
+            // Check for tokenizer (@@)
+            if let Some(tokenizer_pos) = gguf_part.find("@@") {
+                let (gguf_file, tokenizer) = gguf_part.split_at(tokenizer_pos);
+                hash.aset("base_model", base_model)?;
+                hash.aset("gguf_file", gguf_file)?;
+                hash.aset("tokenizer_source", &tokenizer[2..])?;
+            } else {
+                hash.aset("base_model", base_model)?;
+                hash.aset("gguf_file", gguf_part)?;
+            }
+        }
+
+        // Add model type
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
+        let model_ref = model.borrow();
+
+        let model_type = match &*model_ref {
+            ModelType::Mistral(_) => "Mistral",
+            ModelType::Llama(_) => "Llama",
+            ModelType::Gemma(_) => "Gemma",
+            ModelType::Qwen(_) => "Qwen",
+            ModelType::Phi(_) => "Phi",
+            ModelType::QuantizedGGUF(_) => "QuantizedGGUF",
+        };
+        hash.aset("model_type", model_type)?;
+
+        // For GGUF models, add architecture info
+        if let ModelType::QuantizedGGUF(gguf) = &*model_ref {
+            hash.aset("architecture", gguf.architecture.clone())?;
+            hash.aset("eos_token_id", gguf.eos_token_id())?;
+        }
+
+        Ok(hash)
+    }
 }
 
 // Define a standalone function for from_pretrained that handles variable arguments
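The `@`/`@@` parsing above implements the combined model-id convention for quantized models: `base_model@gguf_file`, optionally followed by `@@tokenizer_source`. A standalone sketch of the same split (illustrative only; `split_model_id` is not a function in the gem):

    /// Splits "base[@gguf[@@tokenizer]]" the way LLM::options does above.
    /// Returns (base_model, gguf_file, tokenizer_source).
    fn split_model_id(model_id: &str) -> (&str, Option<&str>, Option<&str>) {
        match model_id.split_once('@') {
            None => (model_id, None, None),
            Some((base, rest)) => match rest.split_once("@@") {
                None => (base, Some(rest), None),
                Some((gguf, tok)) => (base, Some(gguf), Some(tok)),
            },
        }
    }

Note also the lock handling: `options` recovers the guard from a poisoned mutex via `into_inner()` rather than propagating a panic from another thread.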
@@ -486,6 +578,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
     rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
     rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
+    rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
@@ -497,6 +590,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
+    rb_llm.define_method("model_id", method!(LLM::model_id, 0))?;
+    rb_llm.define_method("options", method!(LLM::options, 0))?;
 
     Ok(())
 }
@@ -8,6 +8,8 @@ pub mod utils;
 pub mod llm;
 pub mod tokenizer;
 pub mod structured;
+pub mod reranker;
+pub mod ner;
 
 pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
 pub use tensor::Tensor;
@@ -3,7 +3,7 @@ use candle_transformers::models::bert::{BertModel, Config};
 use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
 use candle_nn::{VarBuilder, Linear};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use serde::{Deserialize, Serialize};
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
@@ -36,8 +36,8 @@ pub struct NER {
 }
 
 impl NER {
-    pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+    pub fn new(model_id: String, device: Option<Device>, tokenizer: Option<String>) -> Result<Self> {
+        let device = device.unwrap_or(Device::best()).as_device()?;
 
         let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
@@ -46,18 +46,18 @@ impl NER {
             // Download model files
             let config_filename = repo.get("config.json")?;
 
-            // Handle tokenizer loading with optional tokenizer_id
-            let tokenizer = if let Some(tok_id) = tokenizer_id {
+            // Handle tokenizer loading with optional tokenizer
+            let tokenizer_wrapper = if let Some(tok_id) = tokenizer {
                 // Use the specified tokenizer
                 let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
                 let tokenizer_filename = tok_repo.get("tokenizer.json")?;
                 let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
-                TokenizerLoader::with_padding(tokenizer, None)
+                TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
             } else {
                 // Try to load tokenizer from model repo
                 let tokenizer_filename = repo.get("tokenizer.json")?;
                 let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
-                TokenizerLoader::with_padding(tokenizer, None)
+                TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
             };
             let weights_filename = repo.get("pytorch_model.safetensors")
                 .or_else(|_| repo.get("model.safetensors"))?;
@@ -101,7 +101,7 @@ impl NER {
                 vb.pp("classifier")
             )?;
 
-            Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
+            Ok((model, tokenizer_wrapper, classifier, ner_config))
         })();
 
         match result {
@@ -185,13 +185,13 @@ impl NER {
         let result = RArray::new();
         for entity in entities {
             let hash = RHash::new();
-            hash.aset("text", entity.text)?;
-            hash.aset("label", entity.label)?;
-            hash.aset("start", entity.start)?;
-            hash.aset("end", entity.end)?;
-            hash.aset("confidence", entity.confidence)?;
-            hash.aset("token_start", entity.token_start)?;
-            hash.aset("token_end", entity.token_end)?;
+            hash.aset(magnus::Symbol::new("text"), entity.text)?;
+            hash.aset(magnus::Symbol::new("label"), entity.label)?;
+            hash.aset(magnus::Symbol::new("start"), entity.start)?;
+            hash.aset(magnus::Symbol::new("end"), entity.end)?;
+            hash.aset(magnus::Symbol::new("confidence"), entity.confidence)?;
+            hash.aset(magnus::Symbol::new("token_start"), entity.token_start)?;
+            hash.aset(magnus::Symbol::new("token_end"), entity.token_end)?;
             result.push(hash)?;
         }
 
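This is a breaking change for Ruby callers: entity hashes are now keyed with Symbols (entity[:text], entity[:label], ...) rather than Strings (entity["text"]). The options hashes added elsewhere in this release still use string keys.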
@@ -382,6 +382,35 @@ impl NER {
     pub fn model_info(&self) -> String {
         format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
     }
+
+    /// Get the model_id
+    pub fn model_id(&self) -> String {
+        self.model_id.clone()
+    }
+
+    /// Get the device
+    pub fn device(&self) -> Device {
+        Device::from_device(&self.device)
+    }
+
+    /// Get all options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+        hash.aset("model_id", self.model_id.clone())?;
+        hash.aset("device", self.device().__str__())?;
+        hash.aset("num_labels", self.config.id2label.len())?;
+
+        // Add entity types as a list
+        let entity_types: Vec<String> = self.config.label2id.keys()
+            .filter(|l| *l != "O")
+            .map(|l| l.trim_start_matches("B-").trim_start_matches("I-").to_string())
+            .collect::<HashSet<_>>()
+            .into_iter()
+            .collect();
+        hash.aset("entity_types", entity_types)?;
+
+        Ok(hash)
+    }
 }
 
 pub fn init(rb_candle: RModule) -> Result<()> {
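The entity_types derivation above collapses BIO-style labels into bare entity names: the "O" label is dropped, "B-"/"I-" prefixes are stripped, and the intermediate HashSet deduplicates (so the resulting order is unspecified). A worked example of the same logic in isolation (the free function is illustrative, not part of the gem):

    use std::collections::HashSet;

    // Mirrors the entity_types pipeline in NER::options above.
    fn entity_types(labels: &[&str]) -> Vec<String> {
        labels.iter()
            .filter(|l| **l != "O")
            .map(|l| l.trim_start_matches("B-").trim_start_matches("I-").to_string())
            .collect::<HashSet<_>>()
            .into_iter()
            .collect()
    }

    // entity_types(&["O", "B-PER", "I-PER", "B-LOC"]) yields ["PER", "LOC"] in some order.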
@@ -392,6 +421,9 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     ner_class.define_method("labels", method!(NER::labels, 0))?;
     ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
     ner_class.define_method("model_info", method!(NER::model_info, 0))?;
+    ner_class.define_method("model_id", method!(NER::model_id, 0))?;
+    ner_class.define_method("device", method!(NER::device, 0))?;
+    ner_class.define_method("options", method!(NER::options, 0))?;
 
     Ok(())
 }
@@ -14,11 +14,12 @@ pub struct Reranker {
     pooler: Linear,
     classifier: Linear,
     device: CoreDevice,
+    model_id: String,
 }
 
 impl Reranker {
     pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Self::new_with_core_device(model_id, device)
     }
 
@@ -59,7 +60,7 @@ impl Reranker {
 
         match result {
             Ok((model, tokenizer, pooler, classifier)) => {
-                Ok(Self { model, tokenizer, pooler, classifier, device })
+                Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
             }
             Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
         }
@@ -231,6 +232,24 @@ impl Reranker {
     pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
         Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
     }
+
+    /// Get the model_id
+    pub fn model_id(&self) -> String {
+        self.model_id.clone()
+    }
+
+    /// Get the device
+    pub fn device(&self) -> Device {
+        Device::from_device(&self.device)
+    }
+
+    /// Get all options as a hash
+    pub fn options(&self) -> std::result::Result<magnus::RHash, Error> {
+        let hash = magnus::RHash::new();
+        hash.aset("model_id", self.model_id.clone())?;
+        hash.aset("device", self.device().__str__())?;
+        Ok(hash)
+    }
 }
 
 pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
@@ -239,5 +258,8 @@ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
     c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
     c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
     c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
+    c_reranker.define_method("model_id", method!(Reranker::model_id, 0))?;
+    c_reranker.define_method("device", method!(Reranker::device, 0))?;
+    c_reranker.define_method("options", method!(Reranker::options, 0))?;
     Ok(())
 }