red-candle 1.0.0.pre.7 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +399 -18
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +5 -0
- data/ext/candle/src/llm/llama.rs +5 -0
- data/ext/candle/src/llm/mistral.rs +5 -0
- data/ext/candle/src/llm/mod.rs +1 -89
- data/ext/candle/src/llm/quantized_gguf.rs +5 -0
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -33
- data/ext/candle/src/ruby/llm.rs +31 -13
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +128 -5
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/llm/mod.rs
CHANGED
@@ -1,5 +1,4 @@
 use candle_core::{Device, Result as CandleResult};
-use tokenizers::Tokenizer;
 
 pub mod mistral;
 pub mod llama;
@@ -11,6 +10,7 @@ pub mod quantized_gguf;
 pub use generation_config::GenerationConfig;
 pub use text_generation::TextGeneration;
 pub use quantized_gguf::QuantizedGGUF;
+pub use crate::tokenizer::TokenizerWrapper;
 
 /// Trait for text generation models
 pub trait TextGenerator: Send + Sync {
@@ -37,92 +37,4 @@ pub trait TextGenerator: Send + Sync {
 
     /// Clear any cached state (like KV cache)
     fn clear_cache(&mut self);
-}
-
-/// Common structure for managing tokenizer
-#[derive(Debug)]
-pub struct TokenizerWrapper {
-    tokenizer: Tokenizer,
-}
-
-impl TokenizerWrapper {
-    pub fn new(tokenizer: Tokenizer) -> Self {
-        Self { tokenizer }
-    }
-
-    pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
-        let encoding = self.tokenizer
-            .encode(text, add_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
-        Ok(encoding.get_ids().to_vec())
-    }
-
-    pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
-        self.tokenizer
-            .decode(tokens, skip_special_tokens)
-            .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
-    }
-
-    pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
-        self.tokenizer
-            .id_to_token(token)
-            .map(|s| s.to_string())
-            .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
-    }
-
-    /// Decode a single token for streaming output
-    pub fn decode_token(&self, token: u32) -> CandleResult<String> {
-        // Decode the single token properly
-        self.decode(&[token], true)
-    }
-
-    /// Decode tokens incrementally for streaming
-    /// This is more efficient than decoding single tokens
-    pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
-        if new_tokens_start >= all_tokens.len() {
-            return Ok(String::new());
-        }
-
-        // Decode all tokens up to this point
-        let full_text = self.decode(all_tokens, true)?;
-
-        // If we're at the start, return everything
-        if new_tokens_start == 0 {
-            return Ok(full_text);
-        }
-
-        // Otherwise, decode up to the previous token and return the difference
-        let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
-
-        // Find the common prefix between the two strings to handle cases where
-        // the tokenizer might produce slightly different text when decoding
-        // different token sequences
-        let common_prefix_len = full_text
-            .char_indices()
-            .zip(previous_text.chars())
-            .take_while(|((_, c1), c2)| c1 == c2)
-            .count();
-
-        // Find the byte position of the character boundary
-        let byte_pos = full_text
-            .char_indices()
-            .nth(common_prefix_len)
-            .map(|(pos, _)| pos)
-            .unwrap_or(full_text.len());
-
-        // Return only the new portion
-        Ok(full_text[byte_pos..].to_string())
-    }
-
-    /// Format tokens with debug information
-    pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
-        let mut result = String::new();
-
-        for &token in tokens {
-            let token_piece = self.token_to_piece(token)?;
-            result.push_str(&format!("[{}:{}]", token, token_piece));
-        }
-
-        Ok(result)
-    }
 }
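The TokenizerWrapper removed above was not dropped: it moved to the new data/ext/candle/src/tokenizer/mod.rs (+103 in the summary) and is re-exported here for compatibility. Its decode_incremental method carries the one non-obvious idea: decoding each new token in isolation garbles merged or multi-byte tokens when streaming, so it re-decodes the whole sequence and emits only the suffix past the common character prefix of the old and new decodings. A minimal standalone sketch of that trick, independent of the gem's types:

// Sketch of the common-prefix trick used by decode_incremental: given the
// text decoded from the prefix tokens and the text decoded from all tokens,
// emit only the genuinely new suffix, cutting on a char boundary so
// multi-byte UTF-8 stays intact.
fn new_suffix(previous_text: &str, full_text: &str) -> String {
    // Count how many leading characters the two decodings share.
    let common = full_text
        .chars()
        .zip(previous_text.chars())
        .take_while(|(a, b)| a == b)
        .count();
    // Map the char count back to a byte offset on a boundary before slicing.
    let byte_pos = full_text
        .char_indices()
        .nth(common)
        .map(|(pos, _)| pos)
        .unwrap_or(full_text.len());
    full_text[byte_pos..].to_string()
}

fn main() {
    assert_eq!(new_suffix("Hello", "Hello, world"), ", world");
    // Multi-byte characters before the cut point are handled safely.
    assert_eq!(new_suffix("café", "café au lait"), " au lait");
    println!("ok");
}

Comparing characters rather than bytes is what keeps the cut on a UTF-8 boundary; re-decoding the full sequence also covers the case where the tokenizer renders earlier text slightly differently once later tokens arrive.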
data/ext/candle/src/llm/quantized_gguf.rs
CHANGED
@@ -28,6 +28,11 @@ enum ModelType {
 }
 
 impl QuantizedGGUF {
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
     /// Load a quantized model from a GGUF file
     pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
         // Check if user specified an exact GGUF filename
data/ext/candle/src/ner.rs
ADDED
@@ -0,0 +1,423 @@
+use magnus::{class, function, method, prelude::*, Error, RModule, RArray, RHash};
+use candle_transformers::models::bert::{BertModel, Config};
+use candle_core::{Device as CoreDevice, Tensor, DType, Module as CanModule};
+use candle_nn::{VarBuilder, Linear};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use std::collections::HashMap;
+use serde::{Deserialize, Serialize};
+use crate::ruby::{Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NERConfig {
+    pub id2label: HashMap<i64, String>,
+    pub label2id: HashMap<String, i64>,
+}
+
+#[derive(Debug, Clone)]
+pub struct EntitySpan {
+    pub text: String,
+    pub label: String,
+    pub start: usize,
+    pub end: usize,
+    pub token_start: usize,
+    pub token_end: usize,
+    pub confidence: f32,
+}
+
+#[magnus::wrap(class = "Candle::NER", free_immediately, size)]
+pub struct NER {
+    model: BertModel,
+    tokenizer: TokenizerWrapper,
+    classifier: Linear,
+    config: NERConfig,
+    device: CoreDevice,
+    model_id: String,
+}
+
+impl NER {
+    pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu).as_device()?;
+
+        // Load model in a separate thread to avoid blocking
+        let device_clone = device.clone();
+        let model_id_clone = model_id.clone();
+
+        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+            let api = Api::new()?;
+            let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
+
+            // Download model files
+            let config_filename = repo.get("config.json")?;
+
+            // Handle tokenizer loading with optional tokenizer_id
+            let tokenizer = if let Some(tok_id) = tokenizer_id {
+                // Use the specified tokenizer
+                let tok_repo = api.repo(Repo::new(tok_id, RepoType::Model));
+                let tokenizer_filename = tok_repo.get("tokenizer.json")?;
+                let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
+                TokenizerLoader::with_padding(tokenizer, None)
+            } else {
+                // Try to load tokenizer from model repo
+                let tokenizer_filename = repo.get("tokenizer.json")?;
+                let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename)?;
+                TokenizerLoader::with_padding(tokenizer, None)
+            };
+            let weights_filename = repo.get("pytorch_model.safetensors")
+                .or_else(|_| repo.get("model.safetensors"))?;
+
+            // Load BERT config
+            let config_str = std::fs::read_to_string(&config_filename)?;
+            let config_json: serde_json::Value = serde_json::from_str(&config_str)?;
+            let bert_config: Config = serde_json::from_value(config_json.clone())?;
+
+            // Extract NER label configuration
+            let id2label = config_json["id2label"]
+                .as_object()
+                .ok_or("Missing id2label in config")?
+                .iter()
+                .map(|(k, v)| {
+                    let id = k.parse::<i64>().unwrap_or(0);
+                    let label = v.as_str().unwrap_or("O").to_string();
+                    (id, label)
+                })
+                .collect::<HashMap<_, _>>();
+
+            let label2id = id2label.iter()
+                .map(|(id, label)| (label.clone(), *id))
+                .collect::<HashMap<_, _>>();
+
+            let num_labels = id2label.len();
+            let ner_config = NERConfig { id2label, label2id };
+
+            // Load model weights
+            let vb = unsafe {
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+            };
+
+            // Load BERT model
+            let model = BertModel::load(vb.pp("bert"), &bert_config)?;
+
+            // Load classification head for token classification
+            let classifier = candle_nn::linear(
+                bert_config.hidden_size,
+                num_labels,
+                vb.pp("classifier")
+            )?;
+
+            Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
+        });
+
+        match handle.join() {
+            Ok(Ok((model, tokenizer, classifier, config))) => {
+                Ok(Self {
+                    model,
+                    tokenizer,
+                    classifier,
+                    config,
+                    device,
+                    model_id,
+                })
+            }
+            Ok(Err(e)) => Err(Error::new(
+                magnus::exception::runtime_error(),
+                format!("Failed to load NER model: {}", e)
+            )),
+            Err(_) => Err(Error::new(
+                magnus::exception::runtime_error(),
+                "Thread panicked while loading NER model"
+            )),
+        }
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Tokenize the text
+        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+
+        let token_ids = encoding.get_ids();
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
+        // Convert to tensors
+        let input_ids = Tensor::new(token_ids, &self.device)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .unsqueeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; // Add batch dimension
+
+        let attention_mask = Tensor::ones_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let token_type_ids = Tensor::zeros_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Forward pass through BERT
+        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Apply classifier to get logits for each token
+        let logits = self.classifier.forward(&output)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Apply softmax to get probabilities
+        let probs = candle_nn::ops::softmax(&logits, 2)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Get predictions and confidence scores
+        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .to_vec2()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Extract entities with BIO decoding
+        let entities = self.decode_entities(
+            &text,
+            &tokens.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
+            offsets,
+            &probs_vec,
+            threshold
+        )?;
+
+        // Convert to Ruby array
+        let result = RArray::new();
+        for entity in entities {
+            let hash = RHash::new();
+            hash.aset("text", entity.text)?;
+            hash.aset("label", entity.label)?;
+            hash.aset("start", entity.start)?;
+            hash.aset("end", entity.end)?;
+            hash.aset("confidence", entity.confidence)?;
+            hash.aset("token_start", entity.token_start)?;
+            hash.aset("token_end", entity.token_end)?;
+            result.push(hash)?;
+        }
+
+        Ok(result)
+    }
+
+    /// Get token-level predictions with labels and confidence scores
+    pub fn predict_tokens(&self, text: String) -> Result<RArray> {
+        // Tokenize the text
+        let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+
+        let token_ids = encoding.get_ids();
+        let tokens = encoding.get_tokens();
+
+        // Convert to tensors
+        let input_ids = Tensor::new(token_ids, &self.device)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .unsqueeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        let attention_mask = Tensor::ones_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let token_type_ids = Tensor::zeros_like(&input_ids)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Forward pass
+        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let logits = self.classifier.forward(&output)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+        let probs = candle_nn::ops::softmax(&logits, 2)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Get predictions
+        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
+            .to_vec2()
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
+
+        // Build result array
+        let result = RArray::new();
+        for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
+            // Find best label
+            let (label_id, confidence) = probs.iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, conf)| (idx as i64, *conf))
+                .unwrap_or((0, 0.0));
+
+            let label = self.config.id2label.get(&label_id)
+                .unwrap_or(&"O".to_string())
+                .clone();
+
+            let token_info = RHash::new();
+            token_info.aset("token", token.to_string())?;
+            token_info.aset("label", label)?;
+            token_info.aset("confidence", confidence)?;
+            token_info.aset("index", i)?;
+
+            // Add probability distribution if needed
+            let probs_hash = RHash::new();
+            for (id, label) in &self.config.id2label {
+                if let Some(prob) = probs.get(*id as usize) {
+                    probs_hash.aset(label.as_str(), *prob)?;
+                }
+            }
+            token_info.aset("probabilities", probs_hash)?;
+
+            result.push(token_info)?;
+        }
+
+        Ok(result)
+    }
+
+    /// Decode BIO-tagged sequences into entity spans
+    fn decode_entities(
+        &self,
+        text: &str,
+        tokens: &[&str],
+        offsets: &[(usize, usize)],
+        probs: &[Vec<f32>],
+        threshold: f32,
+    ) -> Result<Vec<EntitySpan>> {
+        let mut entities = Vec::new();
+        let mut current_entity: Option<(String, usize, usize, Vec<f32>)> = None;
+
+        for (i, (token, probs_vec)) in tokens.iter().zip(probs).enumerate() {
+            // Skip special tokens
+            if token.starts_with("[") && token.ends_with("]") {
+                continue;
+            }
+
+            // Get predicted label
+            let (label_id, confidence) = probs_vec.iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, conf)| (idx as i64, *conf))
+                .unwrap_or((0, 0.0));
+
+            let label = self.config.id2label.get(&label_id)
+                .unwrap_or(&"O".to_string())
+                .clone();
+
+            // BIO decoding logic
+            if label == "O" || confidence < threshold {
+                // End current entity if exists
+                if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
+                    if let (Some(start_offset), Some(end_offset)) =
+                        (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                        let entity_text = text[start_offset.0..end_offset.1].to_string();
+                        let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                        entities.push(EntitySpan {
+                            text: entity_text,
+                            label: entity_type,
+                            start: start_offset.0,
+                            end: end_offset.1,
+                            token_start: start_idx,
+                            token_end: end_idx,
+                            confidence: avg_confidence,
+                        });
+                    }
+                }
+            } else if label.starts_with("B-") {
+                // Begin new entity
+                if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity.take() {
+                    if let (Some(start_offset), Some(end_offset)) =
+                        (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                        let entity_text = text[start_offset.0..end_offset.1].to_string();
+                        let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                        entities.push(EntitySpan {
+                            text: entity_text,
+                            label: entity_type,
+                            start: start_offset.0,
+                            end: end_offset.1,
+                            token_start: start_idx,
+                            token_end: end_idx,
+                            confidence: avg_confidence,
+                        });
+                    }
+                }
+
+                let entity_type = label[2..].to_string();
+                current_entity = Some((entity_type, i, i + 1, vec![confidence]));
+            } else if label.starts_with("I-") {
+                // Continue entity
+                if let Some((ref mut entity_type, _, ref mut end_idx, ref mut confidences)) = current_entity {
+                    let new_type = label[2..].to_string();
+                    if *entity_type == new_type {
+                        *end_idx = i + 1;
+                        confidences.push(confidence);
+                    } else {
+                        // Type mismatch, start new entity
+                        current_entity = Some((new_type, i, i + 1, vec![confidence]));
+                    }
+                } else {
+                    // I- tag without B- tag, treat as beginning
+                    let entity_type = label[2..].to_string();
+                    current_entity = Some((entity_type, i, i + 1, vec![confidence]));
+                }
+            }
+        }
+
+        // Handle final entity
+        if let Some((entity_type, start_idx, end_idx, confidences)) = current_entity {
+            if let (Some(start_offset), Some(end_offset)) =
+                (offsets.get(start_idx), offsets.get(end_idx - 1)) {
+                let entity_text = text[start_offset.0..end_offset.1].to_string();
+                let avg_confidence = confidences.iter().sum::<f32>() / confidences.len() as f32;
+
+                entities.push(EntitySpan {
+                    text: entity_text,
+                    label: entity_type,
+                    start: start_offset.0,
+                    end: end_offset.1,
+                    token_start: start_idx,
+                    token_end: end_idx,
+                    confidence: avg_confidence,
+                });
+            }
+        }
+
+        Ok(entities)
+    }
+
+    /// Get the label configuration
+    pub fn labels(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        let id2label = RHash::new();
+        for (id, label) in &self.config.id2label {
+            id2label.aset(*id, label.as_str())?;
+        }
+
+        let label2id = RHash::new();
+        for (label, id) in &self.config.label2id {
+            label2id.aset(label.as_str(), *id)?;
+        }
+
+        hash.aset("id2label", id2label)?;
+        hash.aset("label2id", label2id)?;
+        hash.aset("num_labels", self.config.id2label.len())?;
+
+        Ok(hash)
+    }
+
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
+    }
+
+    /// Get model info
+    pub fn model_info(&self) -> String {
+        format!("NER model: {}, labels: {}", self.model_id, self.config.id2label.len())
+    }
+}
+
+pub fn init(rb_candle: RModule) -> Result<()> {
+    let ner_class = rb_candle.define_class("NER", class::object())?;
+    ner_class.define_singleton_method("new", function!(NER::new, 3))?;
+    ner_class.define_method("extract_entities", method!(NER::extract_entities, 2))?;
+    ner_class.define_method("predict_tokens", method!(NER::predict_tokens, 1))?;
+    ner_class.define_method("labels", method!(NER::labels, 0))?;
+    ner_class.define_method("tokenizer", method!(NER::tokenizer, 0))?;
+    ner_class.define_method("model_info", method!(NER::model_info, 0))?;
+
+    Ok(())
+}
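The heart of ner.rs is decode_entities, which folds per-token label probabilities into character-level entity spans: a B- tag opens a span, a matching I- tag extends it, and an O label, a sub-threshold confidence, or a type switch closes it, with the span's confidence averaged over its tokens. A compact standalone sketch of the same BIO decoding over toy predictions (the labels, offsets, and sentence below are illustrative, not taken from the gem):

// Toy BIO decoding over per-token (label, confidence) predictions, following
// the same rules as NER::decode_entities above.
#[derive(Debug)]
struct Span { label: String, start: usize, end: usize, confidence: f32 }

// Flush an open (type, token_start, token_end, confidences) run into a span.
fn flush(current: Option<(String, usize, usize, Vec<f32>)>,
         offsets: &[(usize, usize)], spans: &mut Vec<Span>) {
    if let Some((ty, s, e, confs)) = current {
        spans.push(Span {
            label: ty,
            start: offsets[s].0,
            end: offsets[e - 1].1,
            confidence: confs.iter().sum::<f32>() / confs.len() as f32,
        });
    }
}

fn decode_bio(labels: &[(&str, f32)], offsets: &[(usize, usize)], threshold: f32) -> Vec<Span> {
    let mut spans = Vec::new();
    let mut current: Option<(String, usize, usize, Vec<f32>)> = None;
    for (i, &(label, conf)) in labels.iter().enumerate() {
        if label == "O" || conf < threshold {
            // O or a low-confidence prediction closes the open span.
            flush(current.take(), offsets, &mut spans);
        } else if let Some(ty) = label.strip_prefix("B-") {
            // B- closes whatever is open and begins a new span.
            flush(current.take(), offsets, &mut spans);
            current = Some((ty.to_string(), i, i + 1, vec![conf]));
        } else if let Some(ty) = label.strip_prefix("I-") {
            match current.as_mut() {
                // Matching I- extends the open span.
                Some((cur_ty, _, end, confs)) if cur_ty.as_str() == ty => {
                    *end = i + 1;
                    confs.push(conf);
                }
                // Type switch, or I- with nothing open: treat as a beginning.
                _ => current = Some((ty.to_string(), i, i + 1, vec![conf])),
            }
        }
    }
    flush(current, offsets, &mut spans);
    spans
}

fn main() {
    // "John Smith lives in Paris" with word-level character offsets.
    let text = "John Smith lives in Paris";
    let labels = [("B-PER", 0.99), ("I-PER", 0.98), ("O", 0.99), ("O", 0.99), ("B-LOC", 0.97)];
    let offsets = [(0, 4), (5, 10), (11, 16), (17, 19), (20, 25)];
    for s in decode_bio(&labels, &offsets, 0.9) {
        println!("{:?} -> {:?}", s, &text[s.start..s.end]);
    }
    // Yields PER "John Smith" (0..10) and LOC "Paris" (20..25).
}

The real decode_entities additionally skips bracketed special tokens such as [CLS] and [SEP], and works on subword offsets rather than word offsets, which is why it slices the original text by character range instead of joining token strings.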
data/ext/candle/src/reranker.rs
CHANGED
@@ -3,28 +3,29 @@ use candle_transformers::models::bert::{BertModel, Config};
 use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::{
+use tokenizers::{EncodeInput, Tokenizer};
 use std::thread;
-use crate::ruby::{Device
+use crate::ruby::{Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
 #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
 pub struct Reranker {
     model: BertModel,
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerWrapper,
     pooler: Linear,
     classifier: Linear,
     device: CoreDevice,
 }
 
 impl Reranker {
-    pub fn new(model_id: String, device: Option<
-        let device = device.unwrap_or(
+    pub fn new(model_id: String, device: Option<Device>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu).as_device()?;
         Self::new_with_core_device(model_id, device)
     }
 
-    fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
+    fn new_with_core_device(model_id: String, device: CoreDevice) -> std::result::Result<Self, Error> {
         let device_clone = device.clone();
-        let handle = thread::spawn(move || -> Result<(BertModel,
+        let handle = thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -38,12 +39,8 @@ impl Reranker {
             let config: Config = serde_json::from_str(&config)?;
 
             // Setup tokenizer with padding
-            let mut tokenizer = Tokenizer::from_file(tokenizer_filename)?;
-            let pp = tokenizers::PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
+            let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
+            let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
 
             // Load model weights
             let vb = unsafe {
@@ -59,7 +56,7 @@ impl Reranker {
             // Load classifier layer for cross-encoder (single output score)
             let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
 
-            Ok((model, tokenizer, pooler, classifier))
+            Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
         });
 
         match handle.join() {
@@ -71,12 +68,12 @@ impl Reranker {
        }
     }
 
-    pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
+    pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<magnus::RHash, Error> {
         // Create query-document pair for cross-encoder
         let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
 
-        // Tokenize
-        let encoding = self.tokenizer.encode(query_doc_pair, true)
+        // Tokenize using the inner tokenizer for detailed info
+        let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         // Get token information
@@ -95,7 +92,7 @@ impl Reranker {
         Ok(result)
     }
 
-    pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
+    pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
         let documents: Vec<String> = documents.to_vec()?;
 
         // Create query-document pairs for cross-encoder
@@ -104,8 +101,8 @@ impl Reranker {
             .map(|d| (query.clone(), d.clone()).into())
             .collect();
 
-        // Tokenize batch
-        let encodings = self.tokenizer.encode_batch(query_and_docs, true)
+        // Tokenize batch using inner tokenizer for access to token type IDs
+        let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
         // Convert to tensors
@@ -256,12 +253,18 @@ impl Reranker {
         }
         Ok(result_array)
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> std::result::Result<crate::ruby::tokenizer::Tokenizer, Error> {
+        Ok(crate::ruby::tokenizer::Tokenizer(self.tokenizer.clone()))
+    }
 }
 
-pub fn init(rb_candle: RModule) -> Result<(), Error> {
+pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
     let c_reranker = rb_candle.define_class("Reranker", class::object())?;
     c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
     c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
    c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
+    c_reranker.define_method("tokenizer", method!(Reranker::tokenizer, 0))?;
     Ok(())
 }