red-candle 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,245 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_transformers::models::qwen2::{Config, Model as QwenModel};
+ use hf_hub::api::tokio::Api;
+ use tokenizers::Tokenizer;
+
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ /// Qwen model wrapper for text generation
+ #[derive(Debug)]
+ pub struct Qwen {
+     model: QwenModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ impl Qwen {
+     pub fn eos_token_id(&self) -> u32 {
+         self.eos_token_id
+     }
+
+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         self.model.clear_kv_cache();
+     }
+
+     /// Load a Qwen model from HuggingFace
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.model(model_id.to_string());
+
+         // Download configuration
+         let config_filename = repo.get("config.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+         let config_str = std::fs::read_to_string(config_filename)?;
+         let config: Config = serde_json::from_str(&config_str)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+         // Download tokenizer
+         let tokenizer_filename = repo.get("tokenizer.json").await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Determine EOS token
+         let vocab = tokenizer.get_vocab(true);
+         let eos_token_id = vocab.get("<|endoftext|>")
+             .or_else(|| vocab.get("<|im_end|>"))
+             .or_else(|| vocab.get("</s>"))
+             .copied()
+             .unwrap_or(151643); // Default Qwen3 EOS token
+
+         // Download model weights
+         // NOTE: Qwen uses hardcoded shard counts based on model size rather than
+         // reading model.safetensors.index.json. This works for official Qwen models
+         // but may fail for custom configurations with different shard counts.
+         let mut filenames = vec![];
+         let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
+             else if model_id.contains("14b") || model_id.contains("14B") { 3 }
+             else { 1 };
+
+         if num_shards == 1 {
+             // Single file model
+             let filename = repo.get("model.safetensors").await
+                 .map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
+             filenames.push(filename);
+         } else {
+             // Sharded model
+             for shard_idx in 1..=num_shards {
+                 let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
+                     .map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
+                 filenames.push(filename);
+             }
+         }
+
+         // Load the model
+         let vb = unsafe {
+             candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
+         };
+
+         let model = QwenModel::new(&config, vb)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Apply Qwen chat template to messages
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+                 }
+                 "user" => {
+                     prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+                 }
+                 "assistant" => {
+                     prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         // Add generation prompt
+         prompt.push_str("<|im_start|>assistant\n");
+
+         Ok(prompt)
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::new(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let logits = self.model.forward(&input, start_pos, None)?;
+             let logits = logits.squeeze(0)?;
+
+             // Handle different output shapes
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(&logits)?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 if config.debug_tokens {
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check if constraint is satisfied (early stopping)
+             if config.stop_on_constraint_satisfaction {
+                 let satisfied = if config.stop_on_match {
+                     text_gen.is_constraint_satisfied_stop_on_match()
+                 } else {
+                     text_gen.is_constraint_satisfied()
+                 };
+                 if satisfied {
+                     break;
+                 }
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Qwen {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
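For orientation, here is a minimal caller-side sketch of the new Qwen wrapper (not part of the diff). It assumes a Tokio runtime drives the async download, that `GenerationConfig` provides a `Default` implementation (not shown in this diff), and uses an illustrative model id; the `TextGenerator` trait must be in scope for `generate`.

```rust
use candle_core::Device;
use crate::llm::{GenerationConfig, TextGenerator};

async fn demo() -> candle_core::Result<String> {
    // Downloads config.json, tokenizer.json, and the safetensors weights, then builds the model.
    let mut qwen = Qwen::from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", Device::Cpu).await?;

    // ChatML formatting: yields "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n".
    let messages = vec![serde_json::json!({ "role": "user", "content": "Hello!" })];
    let prompt = qwen.apply_chat_template(&messages)?;

    qwen.clear_kv_cache(); // reset cached KV state between independent generations
    let config = GenerationConfig::default(); // assumption: Default is implemented
    qwen.generate(&prompt, &config)
}
```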
@@ -1,42 +1,45 @@
  use candle_core::{Result as CandleResult, Tensor};
  use candle_transformers::generation::LogitsProcessor;
+ use std::sync::Arc;
 
  use super::GenerationConfig;
+ use crate::structured::Index;
 
  /// Helper struct for text generation process
  pub struct TextGeneration {
      logits_processor: LogitsProcessor,
      tokens: Vec<u32>,
      eos_token_id: Option<u32>,
+     repetition_penalty: f32,
+     repetition_penalty_last_n: usize,
+     constraint: Option<Arc<Index>>,
+     constraint_state: Option<u32>,
+     constraint_completed: bool,
+     tokens_since_constraint_start: usize,
  }
 
  impl TextGeneration {
-     pub fn new(
-         seed: u64,
-         temperature: Option<f64>,
-         top_p: Option<f64>,
-         _top_k: Option<usize>,
-         _repetition_penalty: f32,
-         _repetition_penalty_last_n: usize,
-     ) -> Self {
-         let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
-
-         Self {
+     pub fn new(config: &GenerationConfig) -> Self {
+         let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
+
+         let mut text_gen = Self {
              logits_processor,
              tokens: Vec::new(),
              eos_token_id: None,
+             repetition_penalty: config.repetition_penalty,
+             repetition_penalty_last_n: config.repetition_penalty_last_n,
+             constraint: None,
+             constraint_state: None,
+             constraint_completed: false,
+             tokens_since_constraint_start: 0,
+         };
+
+         // Set constraint if provided
+         if let Some(ref constraint) = config.constraint {
+             text_gen.set_constraint(Arc::clone(constraint));
          }
-     }
-
-     pub fn from_config(config: &GenerationConfig) -> Self {
-         Self::new(
-             config.seed,
-             Some(config.temperature),
-             config.top_p,
-             config.top_k,
-             config.repetition_penalty,
-             config.repetition_penalty_last_n,
-         )
+
+         text_gen
      }
 
      pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
@@ -55,6 +58,38 @@ impl TextGeneration {
          self.tokens.push(token);
      }
 
+     pub fn set_constraint(&mut self, constraint: Arc<Index>) {
+         // Initialize with the first state
+         self.constraint_state = Some(constraint.initial_state());
+         self.constraint = Some(constraint);
+         self.constraint_completed = false;
+         self.tokens_since_constraint_start = self.tokens.len();
+     }
+
+     /// Apply constraints to logits by masking disallowed tokens
+     fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             let device = logits.device();
+             let vocab_size = logits.dims1()?;
+
+             // Get allowed tokens from the constraint index for current state
+             if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
+                 // Create a mask where allowed tokens have value 0 and others have -inf
+                 let mut mask = vec![f32::NEG_INFINITY; vocab_size];
+                 for &token_id in &allowed_tokens {
+                     if (token_id as usize) < vocab_size {
+                         mask[token_id as usize] = 0.0;
+                     }
+                 }
+
+                 // Apply mask to logits
+                 let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
+                 *logits = logits.add(&mask_tensor)?;
+             }
+         }
+         Ok(())
+     }
+
      /// Apply repetition penalty to logits
      pub fn apply_repetition_penalty(
          &self,
@@ -94,22 +129,131 @@ impl TextGeneration {
      pub fn sample_next_token(
          &mut self,
          logits: &Tensor,
-         repetition_penalty: Option<(f32, usize)>,
      ) -> CandleResult<u32> {
          let mut logits = logits.clone();
 
-         // Apply repetition penalty if specified
-         if let Some((penalty, last_n)) = repetition_penalty {
-             self.apply_repetition_penalty(&mut logits, penalty, last_n)?;
+         // Apply repetition penalty using stored parameters
+         if self.repetition_penalty != 1.0 {
+             self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
          }
 
+         // Apply constraints if active
+         self.apply_constraints(&mut logits)?;
+
          // Sample token
         let next_token = self.logits_processor.sample(&logits)?;
          self.tokens.push(next_token);
 
+         // Update constraint state if active
+         if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
+             // Get the next state
+             let next_state = constraint_index.next_state(&current_state, &next_token);
+
+             // Check if we're transitioning to a state with no allowed tokens (completion)
+             if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
+                 // Check if we've transitioned from a constrained state to an unconstrained state
+                 // This happens when the pattern is complete and the FSM allows "anything"
+
+                 let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
+                     // Consider it constrained if we have a limited set of allowed tokens
+                     allowed.len() < 1000 // Arbitrary threshold for "constrained"
+                 } else {
+                     true // No tokens allowed is definitely constrained
+                 };
+
+                 let next_constrained = if let Some(next_state_val) = next_state {
+                     if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                         allowed.is_empty() || allowed.len() < 1000
+                     } else {
+                         true
+                     }
+                 } else {
+                     true
+                 };
+
+                 // If we're transitioning from constrained to unconstrained, we've completed the pattern
+                 if current_constrained && !next_constrained {
+                     self.constraint_completed = true;
+                 }
+
+                 // Also check if next state has no allowed tokens at all
+                 if let Some(next_state_val) = next_state {
+                     if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
+                         if allowed.is_empty() {
+                             self.constraint_completed = true;
+                         }
+                     } else {
+                         // None means no tokens allowed - constraint is complete
+                         self.constraint_completed = true;
+                     }
+                 }
+             }
+
+             self.constraint_state = next_state;
+         }
+
          Ok(next_token)
      }
 
+     /// Check if the constraint is satisfied (reached a valid completion state)
+     pub fn is_constraint_satisfied(&self) -> bool {
+         // If we've explicitly marked the constraint as completed, return true
+         if self.constraint_completed {
+             return true;
+         }
+
+         // Also check the current state
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             // Check if the constraint has reached a state where it could validly end
+             // This happens when:
+             // 1. We have no more allowed tokens (constraint fully satisfied)
+             // 2. The EOS token is in the allowed tokens (optional ending)
+             if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                 // If no tokens are allowed, the constraint is fully satisfied
+                 if allowed.is_empty() {
+                     return true;
+                 }
+
+                 // If EOS token is allowed, we've reached an optional completion point
+                 if let Some(eos) = self.eos_token_id {
+                     if allowed.contains(&eos) {
+                         return true;
+                     }
+                 }
+             } else {
+                 // None means no tokens allowed - constraint is satisfied
+                 return true;
+             }
+         }
+         false
+     }
+
+     /// Check if the constraint is satisfied when stop_on_match is true
+     pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
+         // When stop_on_match is true, we stop as soon as the constraint is completed
+         if self.constraint_completed {
+             return true;
+         }
+
+         // Also check if we're currently in a state that could be a valid end
+         // This is important for patterns like phone numbers where after matching
+         // the pattern, the FSM might allow any token (including more numbers)
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             // Check if we've generated at least one token since constraint start
+             if self.tokens.len() > self.tokens_since_constraint_start {
+                 if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                     // If the allowed tokens set is very large (unconstrained),
+                     // it means the pattern has been satisfied
+                     if allowed.len() > 1000 {
+                         return true;
+                     }
+                 }
+             }
+         }
+
+         false
+     }
+
      /// Check if we should stop generation
      pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
          if self.tokens.len() >= max_length {
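Taken together, the `TextGeneration` changes in this file move configuration off the call sites and into the struct; the new Qwen backend above uses exactly this shape. A hedged caller-side sketch (not part of the diff; `config`, `eos_token_id`, and `logits` stand in for values a backend already has):

```rust
use candle_core::{Result as CandleResult, Tensor};

fn sample_once(config: &GenerationConfig, eos_token_id: u32, logits: &Tensor) -> CandleResult<u32> {
    // 1.0.2 shape (removed): TextGeneration::from_config(&config) plus
    //   sample_next_token(&logits, Some((penalty, last_n))).
    // 1.1.1 shape: the constructor takes &GenerationConfig, stores the repetition-penalty
    // settings and any structured-output constraint, and applies both internally.
    let mut text_gen = TextGeneration::new(config);
    text_gen.set_eos_token_id(eos_token_id);
    text_gen.sample_next_token(logits)
}
```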
@@ -122,6 +266,19 @@ impl TextGeneration {
              }
          }
 
+         // Check if we've reached a final state in constraint
+         // A state is considered final if it has no allowed tokens
+         if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
+             if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+                 if allowed.is_empty() {
+                     return true;
+                 }
+             } else {
+                 // None means no tokens allowed - we're done
+                 return true;
+             }
+         }
+
          false
      }
 
@@ -39,13 +39,9 @@ impl NER {
      pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
          let device = device.unwrap_or(Device::Cpu).as_device()?;
 
-         // Load model in a separate thread to avoid blocking
-         let device_clone = device.clone();
-         let model_id_clone = model_id.clone();
-
-         let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+         let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
              let api = Api::new()?;
-             let repo = api.repo(Repo::new(model_id_clone.clone(), RepoType::Model));
+             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
              // Download model files
              let config_filename = repo.get("config.json")?;
@@ -92,7 +88,7 @@
 
              // Load model weights
              let vb = unsafe {
-                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
              };
 
              // Load BERT model
@@ -106,10 +102,10 @@
              )?;
 
              Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-         });
+         })();
 
-         match handle.join() {
-             Ok(Ok((model, tokenizer, classifier, config))) => {
+         match result {
+             Ok((model, tokenizer, classifier, config)) => {
                  Ok(Self {
                      model,
                      tokenizer,
@@ -119,28 +115,20 @@
                      model_id,
                  })
              }
-             Ok(Err(e)) => Err(Error::new(
+             Err(e) => Err(Error::new(
                  magnus::exception::runtime_error(),
                  format!("Failed to load NER model: {}", e)
              )),
-             Err(_) => Err(Error::new(
-                 magnus::exception::runtime_error(),
-                 "Thread panicked while loading NER model"
-             )),
          }
      }
 
-     /// Extract entities from text with confidence scores
-     pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-         let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+     /// Common tokenization and prediction logic
+     fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
          // Tokenize the text
-         let encoding = self.tokenizer.inner().encode(text.as_str(), true)
+         let encoding = self.tokenizer.inner().encode(text, true)
              .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
 
          let token_ids = encoding.get_ids();
-         let tokens = encoding.get_tokens();
-         let offsets = encoding.get_offsets();
 
          // Convert to tensors
          let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@
              .to_vec2()
              .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
 
+         Ok((encoding, probs_vec))
+     }
+
+     /// Extract entities from text with confidence scores
+     pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+         let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+         // Use common tokenization and prediction logic
+         let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+         let tokens = encoding.get_tokens();
+         let offsets = encoding.get_offsets();
+
          // Extract entities with BIO decoding
          let entities = self.decode_entities(
              &text,
@@ -199,38 +200,11 @@
 
      /// Get token-level predictions with labels and confidence scores
      pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-         // Tokenize the text
-         let encoding = self.tokenizer.inner().encode(text.as_str(), true)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+         // Use common tokenization and prediction logic
+         let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
 
-         let token_ids = encoding.get_ids();
          let tokens = encoding.get_tokens();
 
-         // Convert to tensors
-         let input_ids = Tensor::new(token_ids, &self.device)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-             .unsqueeze(0)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-         let attention_mask = Tensor::ones_like(&input_ids)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-         let token_type_ids = Tensor::zeros_like(&input_ids)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-         // Forward pass
-         let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-         let logits = self.classifier.forward(&output)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-         let probs = candle_nn::ops::softmax(&logits, 2)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-         // Get predictions
-         let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-             .to_vec2()
-             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
          // Build result array
          let result = RArray::new();
          for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
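The NER loader change above replaces `std::thread::spawn` + `handle.join()` with an immediately invoked closure, keeping a scope where `?` can use the boxed error type without spawning a thread or handling a separate panic branch. A generic, standalone sketch of that pattern (illustrative only, not taken from the package):

```rust
// An immediately-invoked closure gives `?` a local error type to convert into,
// and the single Result is then mapped to the caller's error type.
fn load() -> Result<String, String> {
    let result = (|| -> Result<String, Box<dyn std::error::Error>> {
        let n: i32 = "42".parse()?; // `?` boxes the ParseIntError
        Ok(format!("parsed {}", n))
    })();

    result.map_err(|e| format!("failed to load: {}", e))
}
```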