red-candle 1.8.0.pre2-x86_64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5193 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +33 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
pub mod embedding_model;
|
|
2
|
+
pub mod tensor;
|
|
3
|
+
pub mod device;
|
|
4
|
+
pub mod dtype;
|
|
5
|
+
pub mod result;
|
|
6
|
+
pub mod errors;
|
|
7
|
+
pub mod utils;
|
|
8
|
+
pub mod llm;
|
|
9
|
+
pub mod tokenizer;
|
|
10
|
+
pub mod structured;
|
|
11
|
+
pub mod reranker;
|
|
12
|
+
pub mod ner;
|
|
13
|
+
pub mod vlm;
|
|
14
|
+
|
|
15
|
+
pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
|
|
16
|
+
pub use tensor::Tensor;
|
|
17
|
+
pub use device::Device;
|
|
18
|
+
pub use dtype::DType;
|
|
19
|
+
pub use result::Result;
|
|
20
|
+
|
|
21
|
+
// Re-export for convenience
|
|
22
|
+
pub use embedding_model::init as init_embedding_model;
|
|
23
|
+
pub use utils::candle_utils;
|
|
24
|
+
pub use llm::init_llm;
|
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
|
|
2
|
+
use candle_transformers::models::bert::{BertModel, Config};
|
|
3
|
+
use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
|
|
4
|
+
use candle_nn::{VarBuilder, Linear};
|
|
5
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
6
|
+
use std::collections::{HashMap, HashSet};
|
|
7
|
+
use serde::{Deserialize, Serialize};
|
|
8
|
+
use crate::ruby::{Device, Result};
|
|
9
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
10
|
+
|
|
11
|
+
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
12
|
+
pub struct NERConfig {
|
|
13
|
+
pub id2label: HashMap<i64, String>,
|
|
14
|
+
pub label2id: HashMap<String, i64>,
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
/// A single entity recovered from BIO-tagged token predictions.
#[derive(Clone, Debug)]
pub struct EntitySpan {
    /// The slice of the input text covered by the entity.
    pub text: String,
    /// Entity type with its `B-` / `I-` prefix removed.
    pub label: String,
    /// Start offset of the entity in the input text (taken from the first token's offset).
    pub start: usize,
    /// End offset (exclusive) of the entity in the input text (from the last token's offset).
    pub end: usize,
    /// Index of the first token belonging to the entity.
    pub token_start: usize,
    /// One past the index of the last token belonging to the entity.
    pub token_end: usize,
    /// Mean of the winning-label probabilities of the entity's tokens.
    pub confidence: f32,
}
|
|
27
|
+
|
|
28
|
+
#[magnus::wrap(class = "Candle::NER", free_immediately, size)]
|
|
29
|
+
pub struct NER {
|
|
30
|
+
model: BertModel,
|
|
31
|
+
tokenizer: TokenizerWrapper,
|
|
32
|
+
classifier: Linear,
|
|
33
|
+
config: NERConfig,
|
|
34
|
+
device: CoreDevice,
|
|
35
|
+
model_id: String,
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
impl NER {
|
|
39
|
+
pub fn new(model_id: String, device: Option<Device>, tokenizer: Option<String>) -> Result<Self> {
|
|
40
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
|
41
|
+
|
|
42
|
+
let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
|
|
43
|
+
let api = Api::new()?;
|
|
44
|
+
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
|
45
|
+
|
|
46
|
+
// Download model files
|
|
47
|
+
let config_filename = repo.get("config.json")?;
|
|
48
|
+
|
|
49
|
+
// Handle tokenizer loading with optional tokenizer
|
|
50
|
+
let tokenizer_wrapper = if let Some(tok_id) = tokenizer {
|
|
51
|
+
// Use the specified tokenizer
|
|
52
|
+
let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
|
|
53
|
+
let tokenizer_filename = tok_repo.get("tokenizer.json")?;
|
|
54
|
+
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
|
55
|
+
TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
|
|
56
|
+
} else {
|
|
57
|
+
// Try to load tokenizer from model repo
|
|
58
|
+
let tokenizer_filename = repo.get("tokenizer.json")?;
|
|
59
|
+
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
|
60
|
+
TokenizerWrapper::new(TokenizerLoader::with_padding(tokenizer, None))
|
|
61
|
+
};
|
|
62
|
+
let weights_filename = repo.get("pytorch_model.safetensors")
|
|
63
|
+
.or_else(|_| repo.get("model.safetensors"))?;
|
|
64
|
+
|
|
65
|
+
// Load BERT config
|
|
66
|
+
let config_str = std::fs::read_to_string(&config_filename)?;
|
|
67
|
+
let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
|
|
68
|
+
let bert_config: Config = serde_json::from_value(config_json.clone())?;
|
|
69
|
+
|
|
70
|
+
// Extract NER label configuration
|
|
71
|
+
let id2label = config_json["id2label"]
|
|
72
|
+
.as_object()
|
|
73
|
+
.ok_or("Missing id2label in config")?
|
|
74
|
+
.iter()
|
|
75
|
+
.map(|(k, v)| {
|
|
76
|
+
let id = k.parse::<i64>().unwrap_or(0);
|
|
77
|
+
let label = v.as_str().unwrap_or("O").to_string();
|
|
78
|
+
(id, label)
|
|
79
|
+
})
|
|
80
|
+
.collect::<HashMap<_, _>>();
|
|
81
|
+
|
|
82
|
+
let label2id = id2label.iter()
|
|
83
|
+
.map(|(id, label)| (label.clone(), *id))
|
|
84
|
+
.collect::<HashMap<_, _>>();
|
|
85
|
+
|
|
86
|
+
let num_labels = id2label.len();
|
|
87
|
+
let ner_config = NERConfig { id2label, label2id };
|
|
88
|
+
|
|
89
|
+
// Load model weights
|
|
90
|
+
let vb = unsafe {
|
|
91
|
+
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
|
|
92
|
+
};
|
|
93
|
+
|
|
94
|
+
// Load BERT model
|
|
95
|
+
let model = BertModel::load(vb.pp("bert"), &bert_config)?;
|
|
96
|
+
|
|
97
|
+
// Load classification head for token classification
|
|
98
|
+
let classifier = candle_nn::linear(
|
|
99
|
+
bert_config.hidden_size,
|
|
100
|
+
num_labels,
|
|
101
|
+
vb.pp("classifier")
|
|
102
|
+
)?;
|
|
103
|
+
|
|
104
|
+
Ok((model, tokenizer_wrapper, classifier, ner_config))
|
|
105
|
+
})();
|
|
106
|
+
|
|
107
|
+
match result {
|
|
108
|
+
Ok((model, tokenizer, classifier, config)) => {
|
|
109
|
+
Ok(Self {
|
|
110
|
+
model,
|
|
111
|
+
tokenizer,
|
|
112
|
+
classifier,
|
|
113
|
+
config,
|
|
114
|
+
device,
|
|
115
|
+
model_id,
|
|
116
|
+
})
|
|
117
|
+
}
|
|
118
|
+
Err(e) => {
|
|
119
|
+
let ruby = Ruby::get().unwrap();
|
|
120
|
+
Err(Error::new(
|
|
121
|
+
ruby.exception_runtime_error(),
|
|
122
|
+
format!("Failed to load NER model: {}", e)
|
|
123
|
+
))
|
|
124
|
+
},
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
/// Common tokenization and prediction logic
|
|
129
|
+
fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
|
|
130
|
+
let ruby = Ruby::get().unwrap();
|
|
131
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
132
|
+
|
|
133
|
+
// Tokenize the text
|
|
134
|
+
let encoding = self.tokenizer.inner().encode(text, true)
|
|
135
|
+
.map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
|
|
136
|
+
|
|
137
|
+
let token_ids = encoding.get_ids();
|
|
138
|
+
|
|
139
|
+
// Convert to tensors
|
|
140
|
+
let input_ids = Tensor::new(token_ids, &self.device)
|
|
141
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?
|
|
142
|
+
.unsqueeze(0)
|
|
143
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?; // Add batch dimension
|
|
144
|
+
|
|
145
|
+
let attention_mask = Tensor::ones_like(&input_ids)
|
|
146
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
147
|
+
let token_type_ids = Tensor::zeros_like(&input_ids)
|
|
148
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
149
|
+
|
|
150
|
+
// Forward pass through BERT
|
|
151
|
+
let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
|
|
152
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
153
|
+
|
|
154
|
+
// Apply classifier to get logits for each token
|
|
155
|
+
let logits = self.classifier.forward(&output)
|
|
156
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
157
|
+
|
|
158
|
+
// Apply softmax to get probabilities
|
|
159
|
+
let probs = candle_nn::ops::softmax(&logits, 2)
|
|
160
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
161
|
+
|
|
162
|
+
// Get predictions and confidence scores
|
|
163
|
+
let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
|
|
164
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?
|
|
165
|
+
.to_vec2()
|
|
166
|
+
.map_err(|e| Error::new(runtime_error, e.to_string()))?;
|
|
167
|
+
|
|
168
|
+
Ok((encoding, probs_vec))
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
/// Extract entities from text with confidence scores
|
|
172
|
+
pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
|
|
173
|
+
let ruby = Ruby::get().unwrap();
|
|
174
|
+
let threshold = confidence_threshold.unwrap_or(0.9) as f32;
|
|
175
|
+
|
|
176
|
+
// Release GVL during tokenization + model forward pass
|
|
177
|
+
let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
|
|
178
|
+
self.tokenize_and_predict(&text)
|
|
179
|
+
})?;
|
|
180
|
+
|
|
181
|
+
let tokens = encoding.get_tokens();
|
|
182
|
+
let offsets = encoding.get_offsets();
|
|
183
|
+
|
|
184
|
+
// Extract entities with BIO decoding
|
|
185
|
+
let entities = self.decode_entities(
|
|
186
|
+
&text,
|
|
187
|
+
&tokens.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
|
|
188
|
+
offsets,
|
|
189
|
+
&probs_vec,
|
|
190
|
+
threshold
|
|
191
|
+
)?;
|
|
192
|
+
|
|
193
|
+
// Convert to Ruby array
|
|
194
|
+
let result = ruby.ary_new();
|
|
195
|
+
for entity in entities {
|
|
196
|
+
let hash = ruby.hash_new();
|
|
197
|
+
hash.aset(ruby.to_symbol("text"), entity.text)?;
|
|
198
|
+
hash.aset(ruby.to_symbol("label"), entity.label)?;
|
|
199
|
+
hash.aset(ruby.to_symbol("start"), entity.start)?;
|
|
200
|
+
hash.aset(ruby.to_symbol("end"), entity.end)?;
|
|
201
|
+
hash.aset(ruby.to_symbol("confidence"), entity.confidence)?;
|
|
202
|
+
hash.aset(ruby.to_symbol("token_start"), entity.token_start)?;
|
|
203
|
+
hash.aset(ruby.to_symbol("token_end"), entity.token_end)?;
|
|
204
|
+
result.push(hash)?;
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
Ok(result)
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
/// Get token-level predictions with labels and confidence scores
|
|
211
|
+
pub fn predict_tokens(&self, text: String) -> Result<RArray> {
|
|
212
|
+
let ruby = Ruby::get().unwrap();
|
|
213
|
+
// Release GVL during tokenization + model forward pass
|
|
214
|
+
let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
|
|
215
|
+
self.tokenize_and_predict(&text)
|
|
216
|
+
})?;
|
|
217
|
+
|
|
218
|
+
let tokens = encoding.get_tokens();
|
|
219
|
+
|
|
220
|
+
// Build result array
|
|
221
|
+
let result = ruby.ary_new();
|
|
222
|
+
for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
|
|
223
|
+
// Find best label
|
|
224
|
+
let (label_id, confidence) = probs.iter()
|
|
225
|
+
.enumerate()
|
|
226
|
+
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
|
227
|
+
.map(|(idx, conf)| (idx as i64, *conf))
|
|
228
|
+
.unwrap_or((0, 0.0));
|
|
229
|
+
|
|
230
|
+
let label = self.config.id2label.get(&label_id)
|
|
231
|
+
.unwrap_or(&"O".to_string())
|
|
232
|
+
.clone();
|
|
233
|
+
|
|
234
|
+
let token_info = ruby.hash_new();
|
|
235
|
+
token_info.aset("token", token.to_string())?;
|
|
236
|
+
token_info.aset("label", label)?;
|
|
237
|
+
token_info.aset("confidence", confidence)?;
|
|
238
|
+
token_info.aset("index", i)?;
|
|
239
|
+
|
|
240
|
+
// Add probability distribution if needed
|
|
241
|
+
let probs_hash = ruby.hash_new();
|
|
242
|
+
for (id, label) in &self.config.id2label {
|
|
243
|
+
if let Some(prob) = probs.get(*id as usize) {
|
|
244
|
+
probs_hash.aset(label.as_str(), *prob)?;
|
|
245
|
+
}
|
|
246
|
+
}
|
|
247
|
+
token_info.aset("probabilities", probs_hash)?;
|
|
248
|
+
|
|
249
|
+
result.push(token_info)?;
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
Ok(result)
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
/// Decode BIO-tagged sequences into entity spans
|
|
256
|
+
fn decode_entities(
|
|
257
|
+
&self,
|
|
258
|
+
text: &str,
|
|
259
|
+
tokens: &[&str],
|
|
260
|
+
offsets: &[(usize, usize)],
|
|
261
|
+
probs: &[Vec<f32>],
|
|
262
|
+
threshold: f32,
|
|
263
|
+
) -> Result<Vec<EntitySpan>> {
|
|
264
|
+
let mut entities = Vec::new();
|
|
265
|
+
let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
|
|
266
|
+
|
|
267
|
+
for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
|
|
268
|
+
// Skip special tokens
|
|
269
|
+
if token.starts_with("[") && token.ends_with("]") {
|
|
270
|
+
continue;
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
// Get predicted label
|
|
274
|
+
let (label_id, confidence) = probs_vec.iter()
|
|
275
|
+
.enumerate()
|
|
276
|
+
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
|
277
|
+
.map(|(idx, conf)| (idx as i64, *conf))
|
|
278
|
+
.unwrap_or((0, 0.0));
|
|
279
|
+
|
|
280
|
+
let label = self.config.id2label.get(&label_id)
|
|
281
|
+
.unwrap_or(&"O".to_string())
|
|
282
|
+
.clone();
|
|
283
|
+
|
|
284
|
+
// BIO decoding logic
|
|
285
|
+
if label == "O" || confidence < threshold {
|
|
286
|
+
// End current entity if exists
|
|
287
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
|
|
288
|
+
if let (Some(start_offset), Some(end_offset)) =
|
|
289
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
|
290
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
|
291
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
|
292
|
+
|
|
293
|
+
entities.push(EntitySpan {
|
|
294
|
+
text: entity_text,
|
|
295
|
+
label: entity_type,
|
|
296
|
+
start: start_offset.0,
|
|
297
|
+
end: end_offset.1,
|
|
298
|
+
token_start: start_idx,
|
|
299
|
+
token_end: end_idx,
|
|
300
|
+
confidence: avg_confidence,
|
|
301
|
+
});
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
} else if label.starts_with("B-") {
|
|
305
|
+
// Begin new entity
|
|
306
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
|
|
307
|
+
if let (Some(start_offset), Some(end_offset)) =
|
|
308
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
|
309
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
|
310
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
|
311
|
+
|
|
312
|
+
entities.push(EntitySpan {
|
|
313
|
+
text: entity_text,
|
|
314
|
+
label: entity_type,
|
|
315
|
+
start: start_offset.0,
|
|
316
|
+
end: end_offset.1,
|
|
317
|
+
token_start: start_idx,
|
|
318
|
+
token_end: end_idx,
|
|
319
|
+
confidence: avg_confidence,
|
|
320
|
+
});
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
let entity_type = label[2..].to_string();
|
|
325
|
+
current_entity = Some((entity_type, i, i + 1, vec![confidence]));
|
|
326
|
+
} else if label.starts_with("I-") {
|
|
327
|
+
// Continue entity
|
|
328
|
+
if let Some((ref mut entity_type, _, ref mut end_idx, ref mut confidences)) = current_entity {
|
|
329
|
+
let new_type = label[2..].to_string();
|
|
330
|
+
if *entity_type == new_type {
|
|
331
|
+
*end_idx = i + 1;
|
|
332
|
+
confidences.push(confidence);
|
|
333
|
+
} else {
|
|
334
|
+
// Type mismatch, start new entity
|
|
335
|
+
current_entity = Some((new_type, i, i + 1, vec![confidence]));
|
|
336
|
+
}
|
|
337
|
+
} else {
|
|
338
|
+
// I- tag without B- tag, treat as beginning
|
|
339
|
+
let entity_type = label[2..].to_string();
|
|
340
|
+
current_entity = Some((entity_type, i, i + 1, vec![confidence]));
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
// Handle final entity
|
|
346
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
|
|
347
|
+
if let (Some(start_offset), Some(end_offset)) =
|
|
348
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
|
349
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
|
350
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
|
351
|
+
|
|
352
|
+
entities.push(EntitySpan {
|
|
353
|
+
text: entity_text,
|
|
354
|
+
label: entity_type,
|
|
355
|
+
start: start_offset.0,
|
|
356
|
+
end: end_offset.1,
|
|
357
|
+
token_start: start_idx,
|
|
358
|
+
token_end: end_idx,
|
|
359
|
+
confidence: avg_confidence,
|
|
360
|
+
});
|
|
361
|
+
}
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
Ok(entities)
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
/// Get the label configuration
|
|
368
|
+
pub fn labels(&self) -> Result<RHash> {
|
|
369
|
+
let ruby = Ruby::get().unwrap();
|
|
370
|
+
let hash = ruby.hash_new();
|
|
371
|
+
|
|
372
|
+
let id2label = ruby.hash_new();
|
|
373
|
+
for (id, label) in &self.config.id2label {
|
|
374
|
+
id2label.aset(*id, label.as_str())?;
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
let label2id = ruby.hash_new();
|
|
378
|
+
for (label, id) in &self.config.label2id {
|
|
379
|
+
label2id.aset(label.as_str(), *id)?;
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
hash.aset("id2label", id2label)?;
|
|
383
|
+
hash.aset("label2id", label2id)?;
|
|
384
|
+
hash.aset("num_labels", self.config.id2label.len())?;
|
|
385
|
+
|
|
386
|
+
Ok(hash)
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
/// Get the tokenizer
|
|
390
|
+
pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
|
|
391
|
+
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
/// Get model info
|
|
395
|
+
pub fn model_info(&self) -> String {
|
|
396
|
+
format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
/// Get the model_id
|
|
400
|
+
pub fn model_id(&self) -> String {
|
|
401
|
+
self.model_id.clone()
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
/// Get the device
|
|
405
|
+
pub fn device(&self) -> Device {
|
|
406
|
+
Device::from_device(&self.device)
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
/// Get all options as a hash
|
|
410
|
+
pub fn options(&self) -> Result<RHash> {
|
|
411
|
+
let ruby = Ruby::get().unwrap();
|
|
412
|
+
let hash = ruby.hash_new();
|
|
413
|
+
hash.aset("model_id", self.model_id.clone())?;
|
|
414
|
+
hash.aset("device", self.device().__str__())?;
|
|
415
|
+
hash.aset("num_labels", self.config.id2label.len())?;
|
|
416
|
+
|
|
417
|
+
// Add entity types as a list
|
|
418
|
+
let entity_types: Vec<String> = self.config.label2id.keys()
|
|
419
|
+
.filter(|l| *l != "O")
|
|
420
|
+
.map(|l| l.trim_start_matches("B-").trim_start_matches("I-").to_string())
|
|
421
|
+
.collect::<HashSet<_>>()
|
|
422
|
+
.into_iter()
|
|
423
|
+
.collect();
|
|
424
|
+
hash.aset("entity_types", entity_types)?;
|
|
425
|
+
|
|
426
|
+
Ok(hash)
|
|
427
|
+
}
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
431
|
+
let ruby = Ruby::get().unwrap();
|
|
432
|
+
let ner_class = rb_candle.define_class("NER", ruby.class_object())?;
|
|
433
|
+
ner_class.define_singleton_method("new", function!(NER::new, 3))?;
|
|
434
|
+
ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
|
|
435
|
+
ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
|
|
436
|
+
ner_class.define_method("labels", method!(NER::labels, 0))?;
|
|
437
|
+
ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
|
|
438
|
+
ner_class.define_method("model_info", method!(NER::model_info, 0))?;
|
|
439
|
+
ner_class.define_method("model_id", method!(NER::model_id, 0))?;
|
|
440
|
+
ner_class.define_method("device", method!(NER::device, 0))?;
|
|
441
|
+
ner_class.define_method("options", method!(NER::options, 0))?;
|
|
442
|
+
|
|
443
|
+
Ok(())
|
|
444
|
+
}
|