red-candle 1.1.2 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +39 -45
- data/Rakefile +79 -88
- data/ext/candle/src/lib.rs +2 -4
- data/ext/candle/src/llm/quantized_gguf.rs +1 -1
- data/ext/candle/src/ruby/device.rs +30 -0
- data/ext/candle/src/ruby/embedding_model.rs +74 -28
- data/ext/candle/src/ruby/llm.rs +96 -1
- data/ext/candle/src/ruby/mod.rs +2 -0
- data/ext/candle/src/{ner.rs → ruby/ner.rs} +47 -15
- data/ext/candle/src/{reranker.rs → ruby/reranker.rs} +24 -2
- data/ext/candle/src/ruby/tensor.rs +101 -26
- data/ext/candle/src/ruby/tokenizer.rs +60 -3
- data/lib/candle/device_utils.rb +3 -15
- data/lib/candle/embedding_model.rb +44 -1
- data/lib/candle/llm.rb +63 -1
- data/lib/candle/ner.rb +34 -22
- data/lib/candle/reranker.rb +20 -1
- data/lib/candle/tensor.rb +15 -0
- data/lib/candle/version.rb +1 -1
- metadata +18 -4
@@ -13,7 +13,7 @@ use candle_transformers::models::{
|
|
13
13
|
jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
|
14
14
|
distilbert::{DistilBertModel, Config as DistilBertConfig}
|
15
15
|
};
|
16
|
-
use magnus::{class, function, method, prelude::*, Error, RModule};
|
16
|
+
use magnus::{class, function, method, prelude::*, Error, RModule, RHash};
|
17
17
|
use std::path::Path;
|
18
18
|
use serde_json;
|
19
19
|
|
@@ -53,7 +53,7 @@ pub enum EmbeddingModelVariant {
|
|
53
53
|
}
|
54
54
|
|
55
55
|
impl EmbeddingModelVariant {
|
56
|
-
pub fn
|
56
|
+
pub fn model_type(&self) -> EmbeddingModelType {
|
57
57
|
match self {
|
58
58
|
EmbeddingModelVariant::JinaBert(_) => EmbeddingModelType::JinaBert,
|
59
59
|
EmbeddingModelVariant::StandardBert(_) => EmbeddingModelType::StandardBert,
|
@@ -66,31 +66,31 @@ impl EmbeddingModelVariant {
|
|
66
66
|
|
67
67
|
pub struct EmbeddingModelInner {
|
68
68
|
device: CoreDevice,
|
69
|
-
|
70
|
-
|
71
|
-
|
69
|
+
tokenizer_id: Option<String>,
|
70
|
+
model_id: Option<String>,
|
71
|
+
model_type: Option<EmbeddingModelType>,
|
72
72
|
model: Option<EmbeddingModelVariant>,
|
73
73
|
tokenizer: Option<TokenizerWrapper>,
|
74
74
|
embedding_size: Option<usize>,
|
75
75
|
}
|
76
76
|
|
77
77
|
impl EmbeddingModel {
|
78
|
-
pub fn new(
|
79
|
-
let device = device.unwrap_or(Device::
|
80
|
-
let
|
78
|
+
pub fn new(model_id: Option<String>, tokenizer: Option<String>, device: Option<Device>, model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
|
79
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
80
|
+
let model_type = model_type
|
81
81
|
.and_then(|mt| EmbeddingModelType::from_string(&mt))
|
82
82
|
.unwrap_or(EmbeddingModelType::JinaBert);
|
83
83
|
Ok(EmbeddingModel(EmbeddingModelInner {
|
84
84
|
device: device.clone(),
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
model: match
|
89
|
-
Some(
|
85
|
+
model_id: model_id.clone(),
|
86
|
+
tokenizer_id: tokenizer.clone(),
|
87
|
+
model_type: Some(model_type),
|
88
|
+
model: match model_id.as_ref() {
|
89
|
+
Some(id) => Some(Self::build_embedding_model(id, device, model_type, embedding_size)?),
|
90
90
|
None => None
|
91
91
|
},
|
92
|
-
tokenizer: match
|
93
|
-
Some(
|
92
|
+
tokenizer: match tokenizer {
|
93
|
+
Some(tid) => Some(Self::build_tokenizer(tid)?),
|
94
94
|
None => None
|
95
95
|
},
|
96
96
|
embedding_size,
|
@@ -170,11 +170,11 @@ impl EmbeddingModel {
|
|
170
170
|
}
|
171
171
|
}
|
172
172
|
|
173
|
-
fn build_embedding_model(
|
173
|
+
fn build_embedding_model(model_id: &str, device: CoreDevice, model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
|
174
174
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
175
175
|
let api = Api::new().map_err(wrap_hf_err)?;
|
176
|
-
let repo = Repo::new(
|
177
|
-
match
|
176
|
+
let repo = Repo::new(model_id.to_string(), RepoType::Model);
|
177
|
+
match model_type {
|
178
178
|
EmbeddingModelType::JinaBert => {
|
179
179
|
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
180
180
|
if !std::path::Path::new(&model_path).exists() {
|
@@ -257,12 +257,12 @@ impl EmbeddingModel {
|
|
257
257
|
}
|
258
258
|
}
|
259
259
|
|
260
|
-
fn build_tokenizer(
|
260
|
+
fn build_tokenizer(tokenizer_id: String) -> Result<TokenizerWrapper> {
|
261
261
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
262
262
|
let tokenizer_path = Api::new()
|
263
263
|
.map_err(wrap_hf_err)?
|
264
264
|
.repo(Repo::new(
|
265
|
-
|
265
|
+
tokenizer_id,
|
266
266
|
RepoType::Model,
|
267
267
|
))
|
268
268
|
.get("tokenizer.json")
|
@@ -365,19 +365,19 @@ impl EmbeddingModel {
|
|
365
365
|
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
366
366
|
}
|
367
367
|
|
368
|
-
pub fn
|
369
|
-
match self.0.
|
370
|
-
Some(
|
368
|
+
pub fn model_type(&self) -> String {
|
369
|
+
match self.0.model_type {
|
370
|
+
Some(mt) => format!("{:?}", mt),
|
371
371
|
None => "nil".to_string(),
|
372
372
|
}
|
373
373
|
}
|
374
374
|
|
375
375
|
pub fn __repr__(&self) -> String {
|
376
376
|
format!(
|
377
|
-
"#<Candle::EmbeddingModel
|
378
|
-
self.
|
379
|
-
self.0.
|
380
|
-
self.0.
|
377
|
+
"#<Candle::EmbeddingModel model_type: {}, model_id: {}, tokenizer: {}, embedding_size: {}>",
|
378
|
+
self.model_type(),
|
379
|
+
self.0.model_id.as_deref().unwrap_or("nil"),
|
380
|
+
self.0.tokenizer_id.as_deref().unwrap_or("nil"),
|
381
381
|
self.0.embedding_size.map(|x| x.to_string()).unwrap_or("nil".to_string())
|
382
382
|
)
|
383
383
|
}
|
@@ -393,6 +393,49 @@ impl EmbeddingModel {
|
|
393
393
|
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
|
394
394
|
}
|
395
395
|
}
|
396
|
+
|
397
|
+
/// Get the model_id
|
398
|
+
pub fn model_id(&self) -> Result<String> {
|
399
|
+
match &self.0.model_id {
|
400
|
+
Some(id) => Ok(id.clone()),
|
401
|
+
None => Ok("unknown".to_string())
|
402
|
+
}
|
403
|
+
}
|
404
|
+
|
405
|
+
/// Get the device
|
406
|
+
pub fn device(&self) -> Device {
|
407
|
+
Device::from_device(&self.0.device)
|
408
|
+
}
|
409
|
+
|
410
|
+
/// Get all options as a hash
|
411
|
+
pub fn options(&self) -> Result<RHash> {
|
412
|
+
let hash = RHash::new();
|
413
|
+
|
414
|
+
// Add model_id
|
415
|
+
if let Some(model_id) = &self.0.model_id {
|
416
|
+
hash.aset("model_id", model_id.clone())?;
|
417
|
+
}
|
418
|
+
|
419
|
+
// Add tokenizer
|
420
|
+
if let Some(tokenizer_id) = &self.0.tokenizer_id {
|
421
|
+
hash.aset("tokenizer", tokenizer_id.clone())?;
|
422
|
+
}
|
423
|
+
|
424
|
+
// Add device
|
425
|
+
hash.aset("device", self.device().__str__())?;
|
426
|
+
|
427
|
+
// Add model_type
|
428
|
+
if let Some(model_type) = &self.0.model_type {
|
429
|
+
hash.aset("model_type", format!("{:?}", model_type))?;
|
430
|
+
}
|
431
|
+
|
432
|
+
// Add embedding_size
|
433
|
+
if let Some(size) = self.0.embedding_size {
|
434
|
+
hash.aset("embedding_size", size)?;
|
435
|
+
}
|
436
|
+
|
437
|
+
Ok(hash)
|
438
|
+
}
|
396
439
|
}
|
397
440
|
|
398
441
|
pub fn init(rb_candle: RModule) -> Result<()> {
|
@@ -404,9 +447,12 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
404
447
|
rb_embedding_model.define_method("pool_embedding", method!(EmbeddingModel::pool_embedding, 1))?;
|
405
448
|
rb_embedding_model.define_method("pool_and_normalize_embedding", method!(EmbeddingModel::pool_and_normalize_embedding, 1))?;
|
406
449
|
rb_embedding_model.define_method("pool_cls_embedding", method!(EmbeddingModel::pool_cls_embedding, 1))?;
|
407
|
-
rb_embedding_model.define_method("
|
450
|
+
rb_embedding_model.define_method("model_type", method!(EmbeddingModel::model_type, 0))?;
|
408
451
|
rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
|
409
452
|
rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
|
410
453
|
rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
|
454
|
+
rb_embedding_model.define_method("model_id", method!(EmbeddingModel::model_id, 0))?;
|
455
|
+
rb_embedding_model.define_method("device", method!(EmbeddingModel::device, 0))?;
|
456
|
+
rb_embedding_model.define_method("options", method!(EmbeddingModel::options, 0))?;
|
411
457
|
Ok(())
|
412
458
|
}
|
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -201,6 +201,37 @@ impl GenerationConfig {
|
|
201
201
|
index: Arc::clone(c),
|
202
202
|
})
|
203
203
|
}
|
204
|
+
|
205
|
+
/// Get all options as a hash
|
206
|
+
pub fn options(&self) -> Result<RHash> {
|
207
|
+
let hash = RHash::new();
|
208
|
+
|
209
|
+
hash.aset("max_length", self.inner.max_length)?;
|
210
|
+
hash.aset("temperature", self.inner.temperature)?;
|
211
|
+
|
212
|
+
if let Some(top_p) = self.inner.top_p {
|
213
|
+
hash.aset("top_p", top_p)?;
|
214
|
+
}
|
215
|
+
|
216
|
+
if let Some(top_k) = self.inner.top_k {
|
217
|
+
hash.aset("top_k", top_k)?;
|
218
|
+
}
|
219
|
+
|
220
|
+
hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
|
221
|
+
hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
|
222
|
+
hash.aset("seed", self.inner.seed)?;
|
223
|
+
hash.aset("stop_sequences", self.inner.stop_sequences.clone())?;
|
224
|
+
hash.aset("include_prompt", self.inner.include_prompt)?;
|
225
|
+
hash.aset("debug_tokens", self.inner.debug_tokens)?;
|
226
|
+
hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
|
227
|
+
hash.aset("stop_on_match", self.inner.stop_on_match)?;
|
228
|
+
|
229
|
+
if self.inner.constraint.is_some() {
|
230
|
+
hash.aset("has_constraint", true)?;
|
231
|
+
}
|
232
|
+
|
233
|
+
Ok(hash)
|
234
|
+
}
|
204
235
|
}
|
205
236
|
|
206
237
|
#[derive(Clone)]
|
@@ -214,7 +245,7 @@ pub struct LLM {
|
|
214
245
|
impl LLM {
|
215
246
|
/// Create a new LLM from a pretrained model
|
216
247
|
pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
|
217
|
-
let device = device.unwrap_or(Device::
|
248
|
+
let device = device.unwrap_or(Device::best());
|
218
249
|
let candle_device = device.as_device()?;
|
219
250
|
|
220
251
|
// For now, we'll use tokio runtime directly
|
@@ -448,6 +479,67 @@ impl LLM {
|
|
448
479
|
model_ref.apply_chat_template(&json_messages)
|
449
480
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to apply chat template: {}", e)))
|
450
481
|
}
|
482
|
+
|
483
|
+
/// Get the model ID
|
484
|
+
pub fn model_id(&self) -> String {
|
485
|
+
self.model_id.clone()
|
486
|
+
}
|
487
|
+
|
488
|
+
/// Get model options
|
489
|
+
pub fn options(&self) -> Result<RHash> {
|
490
|
+
let hash = RHash::new();
|
491
|
+
|
492
|
+
// Basic metadata
|
493
|
+
hash.aset("model_id", self.model_id.clone())?;
|
494
|
+
let device_str = match self.device {
|
495
|
+
Device::Cpu => "cpu",
|
496
|
+
Device::Cuda => "cuda",
|
497
|
+
Device::Metal => "metal",
|
498
|
+
};
|
499
|
+
hash.aset("device", device_str)?;
|
500
|
+
|
501
|
+
// Parse model_id to extract GGUF file if present
|
502
|
+
if let Some(at_pos) = self.model_id.find('@') {
|
503
|
+
let (base_model, gguf_part) = self.model_id.split_at(at_pos);
|
504
|
+
let gguf_part = &gguf_part[1..]; // Skip the @ character
|
505
|
+
|
506
|
+
// Check for tokenizer (@@)
|
507
|
+
if let Some(tokenizer_pos) = gguf_part.find("@@") {
|
508
|
+
let (gguf_file, tokenizer) = gguf_part.split_at(tokenizer_pos);
|
509
|
+
hash.aset("base_model", base_model)?;
|
510
|
+
hash.aset("gguf_file", gguf_file)?;
|
511
|
+
hash.aset("tokenizer_source", &tokenizer[2..])?;
|
512
|
+
} else {
|
513
|
+
hash.aset("base_model", base_model)?;
|
514
|
+
hash.aset("gguf_file", gguf_part)?;
|
515
|
+
}
|
516
|
+
}
|
517
|
+
|
518
|
+
// Add model type
|
519
|
+
let model = match self.model.lock() {
|
520
|
+
Ok(guard) => guard,
|
521
|
+
Err(poisoned) => poisoned.into_inner(),
|
522
|
+
};
|
523
|
+
let model_ref = model.borrow();
|
524
|
+
|
525
|
+
let model_type = match &*model_ref {
|
526
|
+
ModelType::Mistral(_) => "Mistral",
|
527
|
+
ModelType::Llama(_) => "Llama",
|
528
|
+
ModelType::Gemma(_) => "Gemma",
|
529
|
+
ModelType::Qwen(_) => "Qwen",
|
530
|
+
ModelType::Phi(_) => "Phi",
|
531
|
+
ModelType::QuantizedGGUF(_) => "QuantizedGGUF",
|
532
|
+
};
|
533
|
+
hash.aset("model_type", model_type)?;
|
534
|
+
|
535
|
+
// For GGUF models, add architecture info
|
536
|
+
if let ModelType::QuantizedGGUF(gguf) = &*model_ref {
|
537
|
+
hash.aset("architecture", gguf.architecture.clone())?;
|
538
|
+
hash.aset("eos_token_id", gguf.eos_token_id())?;
|
539
|
+
}
|
540
|
+
|
541
|
+
Ok(hash)
|
542
|
+
}
|
451
543
|
}
|
452
544
|
|
453
545
|
// Define a standalone function for from_pretrained that handles variable arguments
|
@@ -486,6 +578,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
486
578
|
rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
|
487
579
|
rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
|
488
580
|
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
581
|
+
rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
|
489
582
|
|
490
583
|
let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
|
491
584
|
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
@@ -497,6 +590,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
497
590
|
rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
|
498
591
|
rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
|
499
592
|
rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
|
593
|
+
rb_llm.define_method("model_id", method!(LLM::model_id, 0))?;
|
594
|
+
rb_llm.define_method("options", method!(LLM::options, 0))?;
|
500
595
|
|
501
596
|
Ok(())
|
502
597
|
}
|
data/ext/candle/src/ruby/mod.rs
CHANGED
@@ -3,7 +3,7 @@ use candle_transformers::models::bert::{BertModel, Config};
|
|
3
3
|
use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
|
4
4
|
use candle_nn::{VarBuilder, Linear};
|
5
5
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
6
|
-
use std::collections::HashMap;
|
6
|
+
use std::collections::{HashMap, HashSet};
|
7
7
|
use serde::{Deserialize, Serialize};
|
8
8
|
use crate::ruby::{Device, Result};
|
9
9
|
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
@@ -36,8 +36,8 @@ pub struct NER {
|
|
36
36
|
}
|
37
37
|
|
38
38
|
impl NER {
|
39
|
-
pub fn new(model_id: String, device: Option<Device>,
|
40
|
-
let device = device.unwrap_or(Device::
|
39
|
+
pub fn new(model_id: String, device: Option<Device>, tokenizer: Option<String>) -> Result<Self> {
|
40
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
41
41
|
|
42
42
|
let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
|
43
43
|
let api = Api::new()?;
|
@@ -46,18 +46,18 @@ impl NER {
|
|
46
46
|
// Download model files
|
47
47
|
let config_filename = repo.get("config.json")?;
|
48
48
|
|
49
|
-
// Handle tokenizer loading with optional
|
50
|
-
let
|
49
|
+
// Handle tokenizer loading with optional tokenizer
|
50
|
+
let tokenizer_wrapper = if let Some(tok_id) = tokenizer {
|
51
51
|
// Use the specified tokenizer
|
52
52
|
let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
|
53
53
|
let tokenizer_filename = tok_repo.get("tokenizer.json")?;
|
54
54
|
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
55
|
-
TokenizerLoader::with_padding(tokenizer, None)
|
55
|
+
TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
|
56
56
|
} else {
|
57
57
|
// Try to load tokenizer from model repo
|
58
58
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
59
59
|
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
60
|
-
TokenizerLoader::with_padding(tokenizer, None)
|
60
|
+
TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
|
61
61
|
};
|
62
62
|
let weights_filename = repo.get("pytorch_model.safetensors")
|
63
63
|
.or_else(|_| repo.get("model.safetensors"))?;
|
@@ -101,7 +101,7 @@ impl NER {
|
|
101
101
|
vb.pp("classifier")
|
102
102
|
)?;
|
103
103
|
|
104
|
-
Ok((model,
|
104
|
+
Ok((model, tokenizer_wrapper, classifier, ner_config))
|
105
105
|
})();
|
106
106
|
|
107
107
|
match result {
|
@@ -185,13 +185,13 @@ impl NER {
|
|
185
185
|
let result = RArray::new();
|
186
186
|
for entity in entities {
|
187
187
|
let hash = RHash::new();
|
188
|
-
hash.aset("text", entity.text)?;
|
189
|
-
hash.aset("label", entity.label)?;
|
190
|
-
hash.aset("start", entity.start)?;
|
191
|
-
hash.aset("end", entity.end)?;
|
192
|
-
hash.aset("confidence", entity.confidence)?;
|
193
|
-
hash.aset("token_start", entity.token_start)?;
|
194
|
-
hash.aset("token_end", entity.token_end)?;
|
188
|
+
hash.aset(magnus::Symbol::new("text"), entity.text)?;
|
189
|
+
hash.aset(magnus::Symbol::new("label"), entity.label)?;
|
190
|
+
hash.aset(magnus::Symbol::new("start"), entity.start)?;
|
191
|
+
hash.aset(magnus::Symbol::new("end"), entity.end)?;
|
192
|
+
hash.aset(magnus::Symbol::new("confidence"), entity.confidence)?;
|
193
|
+
hash.aset(magnus::Symbol::new("token_start"), entity.token_start)?;
|
194
|
+
hash.aset(magnus::Symbol::new("token_end"), entity.token_end)?;
|
195
195
|
result.push(hash)?;
|
196
196
|
}
|
197
197
|
|
@@ -382,6 +382,35 @@ impl NER {
|
|
382
382
|
pub fn model_info(&self) -> String {
|
383
383
|
format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
|
384
384
|
}
|
385
|
+
|
386
|
+
/// Get the model_id
|
387
|
+
pub fn model_id(&self) -> String {
|
388
|
+
self.model_id.clone()
|
389
|
+
}
|
390
|
+
|
391
|
+
/// Get the device
|
392
|
+
pub fn device(&self) -> Device {
|
393
|
+
Device::from_device(&self.device)
|
394
|
+
}
|
395
|
+
|
396
|
+
/// Get all options as a hash
|
397
|
+
pub fn options(&self) -> Result<RHash> {
|
398
|
+
let hash = RHash::new();
|
399
|
+
hash.aset("model_id", self.model_id.clone())?;
|
400
|
+
hash.aset("device", self.device().__str__())?;
|
401
|
+
hash.aset("num_labels", self.config.id2label.len())?;
|
402
|
+
|
403
|
+
// Add entity types as a list
|
404
|
+
let entity_types: Vec<String> = self.config.label2id.keys()
|
405
|
+
.filter(|l| *l != "O")
|
406
|
+
.map(|l| l.trim_start_matches("B-").trim_start_matches("I-").to_string())
|
407
|
+
.collect::<HashSet<_>>()
|
408
|
+
.into_iter()
|
409
|
+
.collect();
|
410
|
+
hash.aset("entity_types", entity_types)?;
|
411
|
+
|
412
|
+
Ok(hash)
|
413
|
+
}
|
385
414
|
}
|
386
415
|
|
387
416
|
pub fn init(rb_candle: RModule) -> Result<()> {
|
@@ -392,6 +421,9 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
392
421
|
ner_class.define_method("labels", method!(NER::labels, 0))?;
|
393
422
|
ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
|
394
423
|
ner_class.define_method("model_info", method!(NER::model_info, 0))?;
|
424
|
+
ner_class.define_method("model_id", method!(NER::model_id, 0))?;
|
425
|
+
ner_class.define_method("device", method!(NER::device, 0))?;
|
426
|
+
ner_class.define_method("options", method!(NER::options, 0))?;
|
395
427
|
|
396
428
|
Ok(())
|
397
429
|
}
|
@@ -14,11 +14,12 @@ pub struct Reranker {
|
|
14
14
|
pooler: Linear,
|
15
15
|
classifier: Linear,
|
16
16
|
device: CoreDevice,
|
17
|
+
model_id: String,
|
17
18
|
}
|
18
19
|
|
19
20
|
impl Reranker {
|
20
21
|
pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
|
21
|
-
let device = device.unwrap_or(Device::
|
22
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
22
23
|
Self::new_with_core_device(model_id, device)
|
23
24
|
}
|
24
25
|
|
@@ -59,7 +60,7 @@ impl Reranker {
|
|
59
60
|
|
60
61
|
match result {
|
61
62
|
Ok((model, tokenizer, pooler, classifier)) => {
|
62
|
-
Ok(Self { model, tokenizer, pooler, classifier, device })
|
63
|
+
Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
|
63
64
|
}
|
64
65
|
Err(e) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
|
65
66
|
}
|
@@ -231,6 +232,24 @@ impl Reranker {
|
|
231
232
|
pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
|
232
233
|
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
233
234
|
}
|
235
|
+
|
236
|
+
/// Get the model_id
|
237
|
+
pub fn model_id(&self) -> String {
|
238
|
+
self.model_id.clone()
|
239
|
+
}
|
240
|
+
|
241
|
+
/// Get the device
|
242
|
+
pub fn device(&self) -> Device {
|
243
|
+
Device::from_device(&self.device)
|
244
|
+
}
|
245
|
+
|
246
|
+
/// Get all options as a hash
|
247
|
+
pub fn options(&self) -> std::result::Result<magnus::RHash, Error> {
|
248
|
+
let hash = magnus::RHash::new();
|
249
|
+
hash.aset("model_id", self.model_id.clone())?;
|
250
|
+
hash.aset("device", self.device().__str__())?;
|
251
|
+
Ok(hash)
|
252
|
+
}
|
234
253
|
}
|
235
254
|
|
236
255
|
pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
@@ -239,5 +258,8 @@ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
|
239
258
|
c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
|
240
259
|
c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
|
241
260
|
c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
|
261
|
+
c_reranker.define_method("model_id", method!(Reranker::model_id, 0))?;
|
262
|
+
c_reranker.define_method("device", method!(Reranker::device, 0))?;
|
263
|
+
c_reranker.define_method("options", method!(Reranker::options, 0))?;
|
242
264
|
Ok(())
|
243
265
|
}
|