red-candle 1.0.0.pre.6 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +481 -4
- data/Rakefile +1 -3
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +21 -79
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +21 -79
- data/ext/candle/src/llm/mistral.rs +21 -89
- data/ext/candle/src/llm/mod.rs +3 -33
- data/ext/candle/src/llm/quantized_gguf.rs +501 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -34
- data/ext/candle/src/ruby/llm.rs +110 -49
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +127 -3
- data/ext/candle/src/ruby/qtensor.rs +0 -69
@@ -0,0 +1,423 @@
|
|
1
|
+
use magnus::{class, function, method, prelude::*, Error, RModule, RArray, RHash};
|
2
|
+
use candle_transformers::models::bert::{BertModel, Config};
|
3
|
+
use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
|
4
|
+
use candle_nn::{VarBuilder, Linear};
|
5
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
6
|
+
use std::collections::HashMap;
|
7
|
+
use serde::{Deserialize, Serialize};
|
8
|
+
use crate::ruby::{Device, Result};
|
9
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
10
|
+
|
11
|
+
/// Label vocabulary of a token-classification (NER) model.
///
/// `id2label` is filled from the `id2label` object of the model's `config.json` in
/// `NER::new`; `label2id` is the reverse lookup built from it there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERConfig {
    /// Classifier output index -> label string. The entity decoder in this file treats
    /// labels as BIO tags (`"O"`, `"B-<TYPE>"`, `"I-<TYPE>"`).
    pub id2label: HashMap<i64, String>,
    /// Label string -> classifier output index (reverse of `id2label`).
    pub label2id: HashMap<String, i64>,
}
|
16
|
+
|
17
|
+
/// One entity produced by BIO decoding in `NER::extract_entities`.
#[derive(Debug, Clone)]
pub struct EntitySpan {
    /// The slice of the input text covered by the entity.
    pub text: String,
    /// Entity type with the `B-`/`I-` prefix removed (e.g. `"PER"`).
    pub label: String,
    /// Byte offset of the first character of the entity in the input text.
    pub start: usize,
    /// Byte offset one past the last character of the entity (exclusive).
    pub end: usize,
    /// Index of the first token of the entity in the tokenizer encoding
    /// (special tokens such as `[CLS]` are counted).
    pub token_start: usize,
    /// Index one past the last token of the entity (exclusive).
    pub token_end: usize,
    /// Mean of the winning-label probabilities of the entity's tokens.
    pub confidence: f32,
}
|
27
|
+
|
28
|
+
/// BERT-based token classifier exposed to Ruby as `Candle::NER`.
#[magnus::wrap(class = "Candle::NER", free_immediately, size)]
pub struct NER {
    /// BERT encoder loaded from the `bert.*` weights.
    model: BertModel,
    /// Tokenizer used to turn input text into token ids and offsets.
    tokenizer: TokenizerWrapper,
    /// Token classification head (`classifier.*` weights): hidden state -> label logits.
    classifier: Linear,
    /// Label mapping read from the model's `config.json`.
    config: NERConfig,
    /// Device the weights live on; input tensors are created on it too.
    device: CoreDevice,
    /// Hugging Face Hub id the model was loaded from (reported by `model_info`).
    model_id: String,
}
|
37
|
+
|
38
|
+
impl NER {
|
39
|
+
pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
|
40
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
41
|
+
|
42
|
+
// Load model in a separate thread to avoid blocking
|
43
|
+
let device_clone = device.clone();
|
44
|
+
let model_id_clone = model_id.clone();
|
45
|
+
|
46
|
+
let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
|
47
|
+
let api = Api::new()?;
|
48
|
+
let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
|
49
|
+
|
50
|
+
// Download model files
|
51
|
+
let config_filename = repo.get("config.json")?;
|
52
|
+
|
53
|
+
// Handle tokenizer loading with optional tokenizer_id
|
54
|
+
let tokenizer = if let Some(tok_id) = tokenizer_id {
|
55
|
+
// Use the specified tokenizer
|
56
|
+
let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
|
57
|
+
let tokenizer_filename = tok_repo.get("tokenizer.json")?;
|
58
|
+
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
59
|
+
TokenizerLoader::with_padding(tokenizer, None)
|
60
|
+
} else {
|
61
|
+
// Try to load tokenizer from model repo
|
62
|
+
let tokenizer_filename = repo.get("tokenizer.json")?;
|
63
|
+
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
|
64
|
+
TokenizerLoader::with_padding(tokenizer, None)
|
65
|
+
};
|
66
|
+
let weights_filename = repo.get("pytorch_model.safetensors")
|
67
|
+
.or_else(|_| repo.get("model.safetensors"))?;
|
68
|
+
|
69
|
+
// Load BERT config
|
70
|
+
let config_str = std::fs::read_to_string(&config_filename)?;
|
71
|
+
let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
|
72
|
+
let bert_config: Config = serde_json::from_value(config_json.clone())?;
|
73
|
+
|
74
|
+
// Extract NER label configuration
|
75
|
+
let id2label = config_json["id2label"]
|
76
|
+
.as_object()
|
77
|
+
.ok_or("Missing id2label in config")?
|
78
|
+
.iter()
|
79
|
+
.map(|(k, v)| {
|
80
|
+
let id = k.parse::<i64>().unwrap_or(0);
|
81
|
+
let label = v.as_str().unwrap_or("O").to_string();
|
82
|
+
(id, label)
|
83
|
+
})
|
84
|
+
.collect::<HashMap<_, _>>();
|
85
|
+
|
86
|
+
let label2id = id2label.iter()
|
87
|
+
.map(|(id, label)| (label.clone(), *id))
|
88
|
+
.collect::<HashMap<_, _>>();
|
89
|
+
|
90
|
+
let num_labels = id2label.len();
|
91
|
+
let ner_config = NERConfig { id2label, label2id };
|
92
|
+
|
93
|
+
// Load model weights
|
94
|
+
let vb = unsafe {
|
95
|
+
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
|
96
|
+
};
|
97
|
+
|
98
|
+
// Load BERT model
|
99
|
+
let model = BertModel::load(vb.pp("bert"), &bert_config)?;
|
100
|
+
|
101
|
+
// Load classification head for token classification
|
102
|
+
let classifier = candle_nn::linear(
|
103
|
+
bert_config.hidden_size,
|
104
|
+
num_labels,
|
105
|
+
vb.pp("classifier")
|
106
|
+
)?;
|
107
|
+
|
108
|
+
Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
|
109
|
+
});
|
110
|
+
|
111
|
+
match handle.join() {
|
112
|
+
Ok(Ok((model, tokenizer, classifier, config))) => {
|
113
|
+
Ok(Self {
|
114
|
+
model,
|
115
|
+
tokenizer,
|
116
|
+
classifier,
|
117
|
+
config,
|
118
|
+
device,
|
119
|
+
model_id,
|
120
|
+
})
|
121
|
+
}
|
122
|
+
Ok(Err(e)) => Err(Error::new(
|
123
|
+
magnus::exception::runtime_error(),
|
124
|
+
format!("Failed to load NER model: {}", e)
|
125
|
+
)),
|
126
|
+
Err(_) => Err(Error::new(
|
127
|
+
magnus::exception::runtime_error(),
|
128
|
+
"Thread panicked while loading NER model"
|
129
|
+
)),
|
130
|
+
}
|
131
|
+
}
|
132
|
+
|
133
|
+
/// Extract entities from text with confidence scores
|
134
|
+
pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
|
135
|
+
let threshold = confidence_threshold.unwrap_or(0.9) as f32;
|
136
|
+
|
137
|
+
// Tokenize the text
|
138
|
+
let encoding = self.tokenizer.inner().encode(text.as_str(), true)
|
139
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
140
|
+
|
141
|
+
let token_ids = encoding.get_ids();
|
142
|
+
let tokens = encoding.get_tokens();
|
143
|
+
let offsets = encoding.get_offsets();
|
144
|
+
|
145
|
+
// Convert to tensors
|
146
|
+
let input_ids = Tensor::new(token_ids, &self.device)
|
147
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
|
148
|
+
.unsqueeze(0)
|
149
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; // Add batch dimension
|
150
|
+
|
151
|
+
let attention_mask = Tensor::ones_like(&input_ids)
|
152
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
153
|
+
let token_type_ids = Tensor::zeros_like(&input_ids)
|
154
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
155
|
+
|
156
|
+
// Forward pass through BERT
|
157
|
+
let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
|
158
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
159
|
+
|
160
|
+
// Apply classifier to get logits for each token
|
161
|
+
let logits = self.classifier.forward(&output)
|
162
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
163
|
+
|
164
|
+
// Apply softmax to get probabilities
|
165
|
+
let probs = candle_nn::ops::softmax(&logits, 2)
|
166
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
167
|
+
|
168
|
+
// Get predictions and confidence scores
|
169
|
+
let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
|
170
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
|
171
|
+
.to_vec2()
|
172
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
173
|
+
|
174
|
+
// Extract entities with BIO decoding
|
175
|
+
let entities = self.decode_entities(
|
176
|
+
&text,
|
177
|
+
&tokens.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
|
178
|
+
offsets,
|
179
|
+
&probs_vec,
|
180
|
+
threshold
|
181
|
+
)?;
|
182
|
+
|
183
|
+
// Convert to Ruby array
|
184
|
+
let result = RArray::new();
|
185
|
+
for entity in entities {
|
186
|
+
let hash = RHash::new();
|
187
|
+
hash.aset("text", entity.text)?;
|
188
|
+
hash.aset("label", entity.label)?;
|
189
|
+
hash.aset("start", entity.start)?;
|
190
|
+
hash.aset("end", entity.end)?;
|
191
|
+
hash.aset("confidence", entity.confidence)?;
|
192
|
+
hash.aset("token_start", entity.token_start)?;
|
193
|
+
hash.aset("token_end", entity.token_end)?;
|
194
|
+
result.push(hash)?;
|
195
|
+
}
|
196
|
+
|
197
|
+
Ok(result)
|
198
|
+
}
|
199
|
+
|
200
|
+
/// Get token-level predictions with labels and confidence scores
|
201
|
+
pub fn predict_tokens(&self, text: String) -> Result<RArray> {
|
202
|
+
// Tokenize the text
|
203
|
+
let encoding = self.tokenizer.inner().encode(text.as_str(), true)
|
204
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
205
|
+
|
206
|
+
let token_ids = encoding.get_ids();
|
207
|
+
let tokens = encoding.get_tokens();
|
208
|
+
|
209
|
+
// Convert to tensors
|
210
|
+
let input_ids = Tensor::new(token_ids, &self.device)
|
211
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
|
212
|
+
.unsqueeze(0)
|
213
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
214
|
+
|
215
|
+
let attention_mask = Tensor::ones_like(&input_ids)
|
216
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
217
|
+
let token_type_ids = Tensor::zeros_like(&input_ids)
|
218
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
219
|
+
|
220
|
+
// Forward pass
|
221
|
+
let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
|
222
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
223
|
+
let logits = self.classifier.forward(&output)
|
224
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
225
|
+
let probs = candle_nn::ops::softmax(&logits, 2)
|
226
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
227
|
+
|
228
|
+
// Get predictions
|
229
|
+
let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
|
230
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
|
231
|
+
.to_vec2()
|
232
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
|
233
|
+
|
234
|
+
// Build result array
|
235
|
+
let result = RArray::new();
|
236
|
+
for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
|
237
|
+
// Find best label
|
238
|
+
let (label_id, confidence) = probs.iter()
|
239
|
+
.enumerate()
|
240
|
+
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
241
|
+
.map(|(idx, conf)| (idx as i64, *conf))
|
242
|
+
.unwrap_or((0, 0.0));
|
243
|
+
|
244
|
+
let label = self.config.id2label.get(&label_id)
|
245
|
+
.unwrap_or(&"O".to_string())
|
246
|
+
.clone();
|
247
|
+
|
248
|
+
let token_info = RHash::new();
|
249
|
+
token_info.aset("token", token.to_string())?;
|
250
|
+
token_info.aset("label", label)?;
|
251
|
+
token_info.aset("confidence", confidence)?;
|
252
|
+
token_info.aset("index", i)?;
|
253
|
+
|
254
|
+
// Add probability distribution if needed
|
255
|
+
let probs_hash = RHash::new();
|
256
|
+
for (id, label) in &self.config.id2label {
|
257
|
+
if let Some(prob) = probs.get(*id as usize) {
|
258
|
+
probs_hash.aset(label.as_str(), *prob)?;
|
259
|
+
}
|
260
|
+
}
|
261
|
+
token_info.aset("probabilities", probs_hash)?;
|
262
|
+
|
263
|
+
result.push(token_info)?;
|
264
|
+
}
|
265
|
+
|
266
|
+
Ok(result)
|
267
|
+
}
|
268
|
+
|
269
|
+
/// Decode BIO-tagged sequences into entity spans
|
270
|
+
fn decode_entities(
|
271
|
+
&self,
|
272
|
+
text: &str,
|
273
|
+
tokens: &[&str],
|
274
|
+
offsets: &[(usize, usize)],
|
275
|
+
probs: &[Vec<f32>],
|
276
|
+
threshold: f32,
|
277
|
+
) -> Result<Vec<EntitySpan>> {
|
278
|
+
let mut entities = Vec::new();
|
279
|
+
let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
|
280
|
+
|
281
|
+
for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
|
282
|
+
// Skip special tokens
|
283
|
+
if token.starts_with("[") && token.ends_with("]") {
|
284
|
+
continue;
|
285
|
+
}
|
286
|
+
|
287
|
+
// Get predicted label
|
288
|
+
let (label_id, confidence) = probs_vec.iter()
|
289
|
+
.enumerate()
|
290
|
+
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
291
|
+
.map(|(idx, conf)| (idx as i64, *conf))
|
292
|
+
.unwrap_or((0, 0.0));
|
293
|
+
|
294
|
+
let label = self.config.id2label.get(&label_id)
|
295
|
+
.unwrap_or(&"O".to_string())
|
296
|
+
.clone();
|
297
|
+
|
298
|
+
// BIO decoding logic
|
299
|
+
if label == "O" || confidence < threshold {
|
300
|
+
// End current entity if exists
|
301
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
|
302
|
+
if let (Some(start_offset), Some(end_offset)) =
|
303
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
304
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
305
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
306
|
+
|
307
|
+
entities.push(EntitySpan {
|
308
|
+
text: entity_text,
|
309
|
+
label: entity_type,
|
310
|
+
start: start_offset.0,
|
311
|
+
end: end_offset.1,
|
312
|
+
token_start: start_idx,
|
313
|
+
token_end: end_idx,
|
314
|
+
confidence: avg_confidence,
|
315
|
+
});
|
316
|
+
}
|
317
|
+
}
|
318
|
+
} else if label.starts_with("B-") {
|
319
|
+
// Begin new entity
|
320
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
|
321
|
+
if let (Some(start_offset), Some(end_offset)) =
|
322
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
323
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
324
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
325
|
+
|
326
|
+
entities.push(EntitySpan {
|
327
|
+
text: entity_text,
|
328
|
+
label: entity_type,
|
329
|
+
start: start_offset.0,
|
330
|
+
end: end_offset.1,
|
331
|
+
token_start: start_idx,
|
332
|
+
token_end: end_idx,
|
333
|
+
confidence: avg_confidence,
|
334
|
+
});
|
335
|
+
}
|
336
|
+
}
|
337
|
+
|
338
|
+
let entity_type = label[2..].to_string();
|
339
|
+
current_entity = Some((entity_type, i, i + 1, vec![confidence]));
|
340
|
+
} else if label.starts_with("I-") {
|
341
|
+
// Continue entity
|
342
|
+
if let Some((ref mut entity_type, _, ref mut end_idx, ref mut confidences)) = current_entity {
|
343
|
+
let new_type = label[2..].to_string();
|
344
|
+
if *entity_type == new_type {
|
345
|
+
*end_idx = i + 1;
|
346
|
+
confidences.push(confidence);
|
347
|
+
} else {
|
348
|
+
// Type mismatch, start new entity
|
349
|
+
current_entity = Some((new_type, i, i + 1, vec![confidence]));
|
350
|
+
}
|
351
|
+
} else {
|
352
|
+
// I- tag without B- tag, treat as beginning
|
353
|
+
let entity_type = label[2..].to_string();
|
354
|
+
current_entity = Some((entity_type, i, i + 1, vec![confidence]));
|
355
|
+
}
|
356
|
+
}
|
357
|
+
}
|
358
|
+
|
359
|
+
// Handle final entity
|
360
|
+
if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
|
361
|
+
if let (Some(start_offset), Some(end_offset)) =
|
362
|
+
(offsets.get(start_idx), offsets.get(end_idx - 1)) {
|
363
|
+
let entity_text = text[start_offset.0..end_offset.1].to_string();
|
364
|
+
let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
|
365
|
+
|
366
|
+
entities.push(EntitySpan {
|
367
|
+
text: entity_text,
|
368
|
+
label: entity_type,
|
369
|
+
start: start_offset.0,
|
370
|
+
end: end_offset.1,
|
371
|
+
token_start: start_idx,
|
372
|
+
token_end: end_idx,
|
373
|
+
confidence: avg_confidence,
|
374
|
+
});
|
375
|
+
}
|
376
|
+
}
|
377
|
+
|
378
|
+
Ok(entities)
|
379
|
+
}
|
380
|
+
|
381
|
+
/// Get the label configuration
|
382
|
+
pub fn labels(&self) -> Result<RHash> {
|
383
|
+
let hash = RHash::new();
|
384
|
+
|
385
|
+
let id2label = RHash::new();
|
386
|
+
for (id, label) in &self.config.id2label {
|
387
|
+
id2label.aset(*id, label.as_str())?;
|
388
|
+
}
|
389
|
+
|
390
|
+
let label2id = RHash::new();
|
391
|
+
for (label, id) in &self.config.label2id {
|
392
|
+
label2id.aset(label.as_str(), *id)?;
|
393
|
+
}
|
394
|
+
|
395
|
+
hash.aset("id2label", id2label)?;
|
396
|
+
hash.aset("label2id", label2id)?;
|
397
|
+
hash.aset("num_labels", self.config.id2label.len())?;
|
398
|
+
|
399
|
+
Ok(hash)
|
400
|
+
}
|
401
|
+
|
402
|
+
/// Get the tokenizer
|
403
|
+
pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
|
404
|
+
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
405
|
+
}
|
406
|
+
|
407
|
+
/// Get model info
|
408
|
+
pub fn model_info(&self) -> String {
|
409
|
+
format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
|
410
|
+
}
|
411
|
+
}
|
412
|
+
|
413
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
414
|
+
let ner_class = rb_candle.define_class("NER", class::object())?;
|
415
|
+
ner_class.define_singleton_method("new", function!(NER::new, 3))?;
|
416
|
+
ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
|
417
|
+
ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
|
418
|
+
ner_class.define_method("labels", method!(NER::labels, 0))?;
|
419
|
+
ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
|
420
|
+
ner_class.define_method("model_info", method!(NER::model_info, 0))?;
|
421
|
+
|
422
|
+
Ok(())
|
423
|
+
}
|
data/ext/candle/src/reranker.rs
CHANGED
@@ -3,28 +3,29 @@ use candle_transformers::models::bert::{BertModel, Config};
|
|
3
3
|
use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
|
4
4
|
use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
|
5
5
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
6
|
-
use tokenizers::{
|
6
|
+
use tokenizers::{EncodeInput, Tokenizer};
|
7
7
|
use std::thread;
|
8
|
-
use crate::ruby::{Device
|
8
|
+
use crate::ruby::{Device, Result};
|
9
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
9
10
|
|
10
11
|
#[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
|
11
12
|
pub struct Reranker {
|
12
13
|
model: BertModel,
|
13
|
-
tokenizer:
|
14
|
+
tokenizer: TokenizerWrapper,
|
14
15
|
pooler: Linear,
|
15
16
|
classifier: Linear,
|
16
17
|
device: CoreDevice,
|
17
18
|
}
|
18
19
|
|
19
20
|
impl Reranker {
|
20
|
-
pub fn new(model_id: String, device: Option<
|
21
|
-
let device = device.unwrap_or(
|
21
|
+
pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
|
22
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
22
23
|
Self::new_with_core_device(model_id, device)
|
23
24
|
}
|
24
25
|
|
25
|
-
fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
|
26
|
+
fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
|
26
27
|
let device_clone = device.clone();
|
27
|
-
let handle = thread::spawn(move || -> Result<(BertModel,
|
28
|
+
let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
|
28
29
|
let api = Api::new()?;
|
29
30
|
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
30
31
|
|
@@ -38,12 +39,8 @@ impl Reranker {
|
|
38
39
|
let config: Config = serde_json::from_str(&config)?;
|
39
40
|
|
40
41
|
// Setup tokenizer with padding
|
41
|
-
let
|
42
|
-
let
|
43
|
-
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
44
|
-
..Default::default()
|
45
|
-
};
|
46
|
-
tokenizer.with_padding(Some(pp));
|
42
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
43
|
+
let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
|
47
44
|
|
48
45
|
// Load model weights
|
49
46
|
let vb = unsafe {
|
@@ -59,7 +56,7 @@ impl Reranker {
|
|
59
56
|
// Load classifier layer for cross-encoder (single output score)
|
60
57
|
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
61
58
|
|
62
|
-
Ok((model, tokenizer, pooler, classifier))
|
59
|
+
Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
|
63
60
|
});
|
64
61
|
|
65
62
|
match handle.join() {
|
@@ -71,12 +68,12 @@ impl Reranker {
|
|
71
68
|
}
|
72
69
|
}
|
73
70
|
|
74
|
-
pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
|
71
|
+
pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
|
75
72
|
// Create query-document pair for cross-encoder
|
76
73
|
let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
|
77
74
|
|
78
|
-
// Tokenize
|
79
|
-
let encoding = self.tokenizer.encode(query_doc_pair, true)
|
75
|
+
// Tokenize using the inner tokenizer for detailed info
|
76
|
+
let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
|
80
77
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
81
78
|
|
82
79
|
// Get token information
|
@@ -95,7 +92,7 @@ impl Reranker {
|
|
95
92
|
Ok(result)
|
96
93
|
}
|
97
94
|
|
98
|
-
pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
|
95
|
+
pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
|
99
96
|
let documents: Vec<String> = documents.to_vec()?;
|
100
97
|
|
101
98
|
// Create query-document pairs for cross-encoder
|
@@ -104,8 +101,8 @@ impl Reranker {
|
|
104
101
|
.map(|d| (query.clone(), d.clone()).into())
|
105
102
|
.collect();
|
106
103
|
|
107
|
-
// Tokenize batch
|
108
|
-
let encodings = self.tokenizer.encode_batch(query_and_docs, true)
|
104
|
+
// Tokenize batch using inner tokenizer for access to token type IDs
|
105
|
+
let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
|
109
106
|
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
110
107
|
|
111
108
|
// Convert to tensors
|
@@ -256,12 +253,18 @@ impl Reranker {
|
|
256
253
|
}
|
257
254
|
Ok(result_array)
|
258
255
|
}
|
256
|
+
|
257
|
+
/// Get the tokenizer used by this model
|
258
|
+
pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
|
259
|
+
Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
|
260
|
+
}
|
259
261
|
}
|
260
262
|
|
261
|
-
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
263
|
+
pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
|
262
264
|
let c_reranker = rb_candle.define_class("Reranker", class::object())?;
|
263
265
|
c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
|
264
266
|
c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
|
265
267
|
c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
|
268
|
+
c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
|
266
269
|
Ok(())
|
267
270
|
}
|
@@ -2,7 +2,7 @@ use magnus::Error;
|
|
2
2
|
use magnus::{function, method, class, RModule, Module, Object};
|
3
3
|
|
4
4
|
use ::candle_core::Device as CoreDevice;
|
5
|
-
use crate::ruby::Result
|
5
|
+
use crate::ruby::Result;
|
6
6
|
|
7
7
|
#[cfg(any(feature = "cuda", feature = "metal"))]
|
8
8
|
use crate::ruby::errors::wrap_candle_err;
|
@@ -68,7 +68,7 @@ impl Device {
|
|
68
68
|
}
|
69
69
|
|
70
70
|
/// Create a CUDA device (GPU)
|
71
|
-
pub fn cuda() ->
|
71
|
+
pub fn cuda() -> Result<Self> {
|
72
72
|
#[cfg(not(feature = "cuda"))]
|
73
73
|
{
|
74
74
|
return Err(Error::new(
|
@@ -82,7 +82,7 @@ impl Device {
|
|
82
82
|
}
|
83
83
|
|
84
84
|
/// Create a Metal device (Apple GPU)
|
85
|
-
pub fn metal() ->
|
85
|
+
pub fn metal() -> Result<Self> {
|
86
86
|
#[cfg(not(feature = "metal"))]
|
87
87
|
{
|
88
88
|
return Err(Error::new(
|
@@ -103,7 +103,7 @@ impl Device {
|
|
103
103
|
}
|
104
104
|
}
|
105
105
|
|
106
|
-
pub fn as_device(&self) ->
|
106
|
+
pub fn as_device(&self) -> Result<CoreDevice> {
|
107
107
|
match self {
|
108
108
|
Self::Cpu => Ok(CoreDevice::Cpu),
|
109
109
|
Self::Cuda => {
|
@@ -165,7 +165,7 @@ impl Device {
|
|
165
165
|
}
|
166
166
|
|
167
167
|
impl magnus::TryConvert for Device {
|
168
|
-
fn try_convert(val: magnus::Value) ->
|
168
|
+
fn try_convert(val: magnus::Value) -> Result<Self> {
|
169
169
|
// First check if it's already a wrapped Device object
|
170
170
|
if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
|
171
171
|
return Ok(*device);
|
@@ -184,7 +184,7 @@ impl magnus::TryConvert for Device {
|
|
184
184
|
}
|
185
185
|
}
|
186
186
|
|
187
|
-
pub fn init(rb_candle: RModule) -> Result<()
|
187
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
188
188
|
let rb_device = rb_candle.define_class("Device", class::object())?;
|
189
189
|
rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
|
190
190
|
rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
|
@@ -1,8 +1,8 @@
|
|
1
1
|
use magnus::value::ReprValue;
|
2
|
-
use magnus::{method, class, RModule,
|
2
|
+
use magnus::{method, class, RModule, Module};
|
3
3
|
|
4
4
|
use ::candle_core::DType as CoreDType;
|
5
|
-
use crate::ruby::Result
|
5
|
+
use crate::ruby::Result;
|
6
6
|
|
7
7
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
8
8
|
#[magnus::wrap(class = "Candle::DType", free_immediately, size)]
|
@@ -21,7 +21,7 @@ impl DType {
|
|
21
21
|
}
|
22
22
|
|
23
23
|
impl DType {
|
24
|
-
pub fn from_rbobject(dtype: magnus::Symbol) ->
|
24
|
+
pub fn from_rbobject(dtype: magnus::Symbol) -> Result<Self> {
|
25
25
|
let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
|
26
26
|
use std::str::FromStr;
|
27
27
|
let dtype = CoreDType::from_str(&dtype).unwrap();
|
@@ -29,7 +29,7 @@ impl DType {
|
|
29
29
|
}
|
30
30
|
}
|
31
31
|
|
32
|
-
pub fn init(rb_candle: RModule) -> Result<()
|
32
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
33
33
|
let rb_dtype = rb_candle.define_class("DType", class::object())?;
|
34
34
|
rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
|
35
35
|
rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
|