red-candle 1.0.2 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +244 -6
- data/README.md +38 -3
- data/Rakefile +46 -1
- data/ext/candle/Cargo.toml +2 -0
- data/ext/candle/src/lib.rs +2 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +316 -0
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +11 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/mod.rs +5 -0
- data/ext/candle/src/llm/phi.rs +301 -0
- data/ext/candle/src/llm/quantized_gguf.rs +173 -9
- data/ext/candle/src/llm/qwen.rs +245 -0
- data/ext/candle/src/llm/text_generation.rs +183 -26
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +5 -0
- data/ext/candle/src/ruby/llm.rs +119 -55
- data/ext/candle/src/ruby/mod.rs +1 -0
- data/ext/candle/src/ruby/structured.rs +47 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/lib/candle/llm.rb +203 -2
- data/lib/candle/version.rb +1 -1
- metadata +14 -4
data/ext/candle/src/llm/qwen.rs
ADDED
@@ -0,0 +1,245 @@
+use candle_core::{DType, Device, Result as CandleResult, Tensor};
+use candle_transformers::models::qwen2::{Config, Model as QwenModel};
+use hf_hub::api::tokio::Api;
+use tokenizers::Tokenizer;
+
+use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+/// Qwen model wrapper for text generation
+#[derive(Debug)]
+pub struct Qwen {
+    model: QwenModel,
+    tokenizer: TokenizerWrapper,
+    device: Device,
+    model_id: String,
+    eos_token_id: u32,
+}
+
+impl Qwen {
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
+    /// Get the tokenizer
+    pub fn tokenizer(&self) -> &TokenizerWrapper {
+        &self.tokenizer
+    }
+
+    /// Clear the KV cache between generations
+    pub fn clear_kv_cache(&mut self) {
+        self.model.clear_kv_cache();
+    }
+
+    /// Load a Qwen model from HuggingFace
+    pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+        let api = Api::new()
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+        let repo = api.model(model_id.to_string());
+
+        // Download configuration
+        let config_filename = repo.get("config.json").await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+        let config_str = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config_str)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+        // Download tokenizer
+        let tokenizer_filename = repo.get("tokenizer.json").await
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)
+            .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        // Determine EOS token
+        let vocab = tokenizer.get_vocab(true);
+        let eos_token_id = vocab.get("<|endoftext|>")
+            .or_else(|| vocab.get("<|im_end|>"))
+            .or_else(|| vocab.get("</s>"))
+            .copied()
+            .unwrap_or(151643); // Default Qwen3 EOS token
+
+        // Download model weights
+        // NOTE: Qwen uses hardcoded shard counts based on model size rather than
+        // reading model.safetensors.index.json. This works for official Qwen models
+        // but may fail for custom configurations with different shard counts.
+        let mut filenames = vec![];
+        let num_shards = if model_id.contains("72b") || model_id.contains("72B") { 8 }
+            else if model_id.contains("14b") || model_id.contains("14B") { 3 }
+            else { 1 };
+
+        if num_shards == 1 {
+            // Single file model
+            let filename = repo.get("model.safetensors").await
+                .map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
+            filenames.push(filename);
+        } else {
+            // Sharded model
+            for shard_idx in 1..=num_shards {
+                let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
+                    .map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
+                filenames.push(filename);
+            }
+        }
+
+        // Load the model
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
+        };
+
+        let model = QwenModel::new(&config, vb)?;
+
+        Ok(Self {
+            model,
+            tokenizer: TokenizerWrapper::new(tokenizer),
+            device,
+            model_id: model_id.to_string(),
+            eos_token_id,
+        })
+    }
+
+    /// Apply Qwen chat template to messages
+    pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+        let mut prompt = String::new();
+
+        for message in messages {
+            let role = message["role"].as_str().unwrap_or("");
+            let content = message["content"].as_str().unwrap_or("");
+
+            match role {
+                "system" => {
+                    prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", content));
+                }
+                "user" => {
+                    prompt.push_str(&format!("<|im_start|>user\n{}<|im_end|>\n", content));
+                }
+                "assistant" => {
+                    prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
+                }
+                _ => {}
+            }
+        }
+
+        // Add generation prompt
+        prompt.push_str("<|im_start|>assistant\n");
+
+        Ok(prompt)
+    }
+
+    fn generate_tokens(
+        &mut self,
+        prompt_tokens: Vec<u32>,
+        config: &GenerationConfig,
+        mut callback: Option<impl FnMut(&str)>,
+    ) -> CandleResult<Vec<u32>> {
+        let mut text_gen = TextGeneration::new(config);
+        text_gen.set_eos_token_id(self.eos_token_id);
+        text_gen.set_tokens(prompt_tokens.clone());
+
+        let mut all_tokens = prompt_tokens.clone();
+        let start_gen = all_tokens.len();
+
+        for index in 0..config.max_length {
+            let context_size = if index > 0 { 1 } else { all_tokens.len() };
+            let start_pos = all_tokens.len().saturating_sub(context_size);
+            let ctxt = &all_tokens[start_pos..];
+
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos, None)?;
+            let logits = logits.squeeze(0)?;
+
+            // Handle different output shapes
+            let logits = if logits.dims().len() == 2 {
+                let seq_len = logits.dim(0)?;
+                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let logits = logits.to_dtype(DType::F32)?;
+
+            let next_token = text_gen.sample_next_token(&logits)?;
+
+            all_tokens.push(next_token);
+
+            // Stream callback
+            if let Some(ref mut cb) = callback {
+                if config.debug_tokens {
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
+                }
+            }
+
+            // Check stop conditions
+            if text_gen.should_stop(next_token, config.max_length) {
+                break;
+            }
+
+            // Check if constraint is satisfied (early stopping)
+            if config.stop_on_constraint_satisfaction {
+                let satisfied = if config.stop_on_match {
+                    text_gen.is_constraint_satisfied_stop_on_match()
+                } else {
+                    text_gen.is_constraint_satisfied()
+                };
+                if satisfied {
+                    break;
+                }
+            }
+
+            // Check stop sequences
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                break;
+            }
+        }
+
+        Ok(if config.include_prompt {
+            all_tokens
+        } else {
+            all_tokens[start_gen..].to_vec()
+        })
+    }
+}
+
+impl TextGenerator for Qwen {
+    fn generate(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
+    }
+
+    fn generate_stream(
+        &mut self,
+        prompt: &str,
+        config: &GenerationConfig,
+        mut callback: impl FnMut(&str),
+    ) -> CandleResult<String> {
+        let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+        self.tokenizer.decode(&output_tokens, true)
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_id
+    }
+
+    fn device(&self) -> &Device {
+        &self.device
+    }
+
+    fn clear_cache(&mut self) {
+        self.clear_kv_cache();
+    }
+}
@@ -1,42 +1,45 @@
|
|
1
1
|
use candle_core::{Result as CandleResult, Tensor};
|
2
2
|
use candle_transformers::generation::LogitsProcessor;
|
3
|
+
use std::sync::Arc;
|
3
4
|
|
4
5
|
use super::GenerationConfig;
|
6
|
+
use crate::structured::Index;
|
5
7
|
|
6
8
|
/// Helper struct for text generation process
|
7
9
|
pub struct TextGeneration {
|
8
10
|
logits_processor: LogitsProcessor,
|
9
11
|
tokens: Vec<u32>,
|
10
12
|
eos_token_id: Option<u32>,
|
13
|
+
repetition_penalty: f32,
|
14
|
+
repetition_penalty_last_n: usize,
|
15
|
+
constraint: Option<Arc<Index>>,
|
16
|
+
constraint_state: Option<u32>,
|
17
|
+
constraint_completed: bool,
|
18
|
+
tokens_since_constraint_start: usize,
|
11
19
|
}
|
12
20
|
|
13
21
|
impl TextGeneration {
|
14
|
-
pub fn new(
|
15
|
-
seed
|
16
|
-
|
17
|
-
|
18
|
-
_top_k: Option<usize>,
|
19
|
-
_repetition_penalty: f32,
|
20
|
-
_repetition_penalty_last_n: usize,
|
21
|
-
) -> Self {
|
22
|
-
let logits_processor = LogitsProcessor::new(seed, temperature, top_p);
|
23
|
-
|
24
|
-
Self {
|
22
|
+
pub fn new(config: &GenerationConfig) -> Self {
|
23
|
+
let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
|
24
|
+
|
25
|
+
let mut text_gen = Self {
|
25
26
|
logits_processor,
|
26
27
|
tokens: Vec::new(),
|
27
28
|
eos_token_id: None,
|
29
|
+
repetition_penalty: config.repetition_penalty,
|
30
|
+
repetition_penalty_last_n: config.repetition_penalty_last_n,
|
31
|
+
constraint: None,
|
32
|
+
constraint_state: None,
|
33
|
+
constraint_completed: false,
|
34
|
+
tokens_since_constraint_start: 0,
|
35
|
+
};
|
36
|
+
|
37
|
+
// Set constraint if provided
|
38
|
+
if let Some(ref constraint) = config.constraint {
|
39
|
+
text_gen.set_constraint(Arc::clone(constraint));
|
28
40
|
}
|
29
|
-
|
30
|
-
|
31
|
-
pub fn from_config(config: &GenerationConfig) -> Self {
|
32
|
-
Self::new(
|
33
|
-
config.seed,
|
34
|
-
Some(config.temperature),
|
35
|
-
config.top_p,
|
36
|
-
config.top_k,
|
37
|
-
config.repetition_penalty,
|
38
|
-
config.repetition_penalty_last_n,
|
39
|
-
)
|
41
|
+
|
42
|
+
text_gen
|
40
43
|
}
|
41
44
|
|
42
45
|
pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
|
@@ -55,6 +58,38 @@ impl TextGeneration {
|
|
55
58
|
self.tokens.push(token);
|
56
59
|
}
|
57
60
|
|
61
|
+
pub fn set_constraint(&mut self, constraint: Arc<Index>) {
|
62
|
+
// Initialize with the first state
|
63
|
+
self.constraint_state = Some(constraint.initial_state());
|
64
|
+
self.constraint = Some(constraint);
|
65
|
+
self.constraint_completed = false;
|
66
|
+
self.tokens_since_constraint_start = self.tokens.len();
|
67
|
+
}
|
68
|
+
|
69
|
+
/// Apply constraints to logits by masking disallowed tokens
|
70
|
+
fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
|
71
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
72
|
+
let device = logits.device();
|
73
|
+
let vocab_size = logits.dims1()?;
|
74
|
+
|
75
|
+
// Get allowed tokens from the constraint index for current state
|
76
|
+
if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
|
77
|
+
// Create a mask where allowed tokens have value 0 and others have -inf
|
78
|
+
let mut mask = vec![f32::NEG_INFINITY; vocab_size];
|
79
|
+
for &token_id in &allowed_tokens {
|
80
|
+
if (token_id as usize) < vocab_size {
|
81
|
+
mask[token_id as usize] = 0.0;
|
82
|
+
}
|
83
|
+
}
|
84
|
+
|
85
|
+
// Apply mask to logits
|
86
|
+
let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
|
87
|
+
*logits = logits.add(&mask_tensor)?;
|
88
|
+
}
|
89
|
+
}
|
90
|
+
Ok(())
|
91
|
+
}
|
92
|
+
|
58
93
|
/// Apply repetition penalty to logits
|
59
94
|
pub fn apply_repetition_penalty(
|
60
95
|
&self,
|
@@ -94,22 +129,131 @@ impl TextGeneration {
|
|
94
129
|
pub fn sample_next_token(
|
95
130
|
&mut self,
|
96
131
|
logits: &Tensor,
|
97
|
-
repetition_penalty: Option<(f32, usize)>,
|
98
132
|
) -> CandleResult<u32> {
|
99
133
|
let mut logits = logits.clone();
|
100
134
|
|
101
|
-
// Apply repetition penalty
|
102
|
-
if
|
103
|
-
self.apply_repetition_penalty(&mut logits,
|
135
|
+
// Apply repetition penalty using stored parameters
|
136
|
+
if self.repetition_penalty != 1.0 {
|
137
|
+
self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
|
104
138
|
}
|
105
139
|
|
140
|
+
// Apply constraints if active
|
141
|
+
self.apply_constraints(&mut logits)?;
|
142
|
+
|
106
143
|
// Sample token
|
107
144
|
let next_token = self.logits_processor.sample(&logits)?;
|
108
145
|
self.tokens.push(next_token);
|
109
146
|
|
147
|
+
// Update constraint state if active
|
148
|
+
if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
|
149
|
+
// Get the next state
|
150
|
+
let next_state = constraint_index.next_state(¤t_state, &next_token);
|
151
|
+
|
152
|
+
// Check if we're transitioning to a state with no allowed tokens (completion)
|
153
|
+
if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
|
154
|
+
// Check if we've transitioned from a constrained state to an unconstrained state
|
155
|
+
// This happens when the pattern is complete and the FSM allows "anything"
|
156
|
+
|
157
|
+
let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(¤t_state) {
|
158
|
+
// Consider it constrained if we have a limited set of allowed tokens
|
159
|
+
allowed.len() < 1000 // Arbitrary threshold for "constrained"
|
160
|
+
} else {
|
161
|
+
true // No tokens allowed is definitely constrained
|
162
|
+
};
|
163
|
+
|
164
|
+
let next_constrained = if let Some(next_state_val) = next_state {
|
165
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
|
166
|
+
allowed.is_empty() || allowed.len() < 1000
|
167
|
+
} else {
|
168
|
+
true
|
169
|
+
}
|
170
|
+
} else {
|
171
|
+
true
|
172
|
+
};
|
173
|
+
|
174
|
+
// If we're transitioning from constrained to unconstrained, we've completed the pattern
|
175
|
+
if current_constrained && !next_constrained {
|
176
|
+
self.constraint_completed = true;
|
177
|
+
}
|
178
|
+
|
179
|
+
// Also check if next state has no allowed tokens at all
|
180
|
+
if let Some(next_state_val) = next_state {
|
181
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
|
182
|
+
if allowed.is_empty() {
|
183
|
+
self.constraint_completed = true;
|
184
|
+
}
|
185
|
+
} else {
|
186
|
+
// None means no tokens allowed - constraint is complete
|
187
|
+
self.constraint_completed = true;
|
188
|
+
}
|
189
|
+
}
|
190
|
+
}
|
191
|
+
|
192
|
+
self.constraint_state = next_state;
|
193
|
+
}
|
194
|
+
|
110
195
|
Ok(next_token)
|
111
196
|
}
|
112
197
|
|
198
|
+
/// Check if the constraint is satisfied (reached a valid completion state)
|
199
|
+
pub fn is_constraint_satisfied(&self) -> bool {
|
200
|
+
// If we've explicitly marked the constraint as completed, return true
|
201
|
+
if self.constraint_completed {
|
202
|
+
return true;
|
203
|
+
}
|
204
|
+
|
205
|
+
// Also check the current state
|
206
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
207
|
+
// Check if the constraint has reached a state where it could validly end
|
208
|
+
// This happens when:
|
209
|
+
// 1. We have no more allowed tokens (constraint fully satisfied)
|
210
|
+
// 2. The EOS token is in the allowed tokens (optional ending)
|
211
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
212
|
+
// If no tokens are allowed, the constraint is fully satisfied
|
213
|
+
if allowed.is_empty() {
|
214
|
+
return true;
|
215
|
+
}
|
216
|
+
|
217
|
+
// If EOS token is allowed, we've reached an optional completion point
|
218
|
+
if let Some(eos) = self.eos_token_id {
|
219
|
+
if allowed.contains(&eos) {
|
220
|
+
return true;
|
221
|
+
}
|
222
|
+
}
|
223
|
+
} else {
|
224
|
+
// None means no tokens allowed - constraint is satisfied
|
225
|
+
return true;
|
226
|
+
}
|
227
|
+
}
|
228
|
+
false
|
229
|
+
}
|
230
|
+
|
231
|
+
/// Check if the constraint is satisfied when stop_on_match is true
|
232
|
+
pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
|
233
|
+
// When stop_on_match is true, we stop as soon as the constraint is completed
|
234
|
+
if self.constraint_completed {
|
235
|
+
return true;
|
236
|
+
}
|
237
|
+
|
238
|
+
// Also check if we're currently in a state that could be a valid end
|
239
|
+
// This is important for patterns like phone numbers where after matching
|
240
|
+
// the pattern, the FSM might allow any token (including more numbers)
|
241
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
242
|
+
// Check if we've generated at least one token since constraint start
|
243
|
+
if self.tokens.len() > self.tokens_since_constraint_start {
|
244
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
245
|
+
// If the allowed tokens set is very large (unconstrained),
|
246
|
+
// it means the pattern has been satisfied
|
247
|
+
if allowed.len() > 1000 {
|
248
|
+
return true;
|
249
|
+
}
|
250
|
+
}
|
251
|
+
}
|
252
|
+
}
|
253
|
+
|
254
|
+
false
|
255
|
+
}
|
256
|
+
|
113
257
|
/// Check if we should stop generation
|
114
258
|
pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
|
115
259
|
if self.tokens.len() >= max_length {
|
@@ -122,6 +266,19 @@ impl TextGeneration {
|
|
122
266
|
}
|
123
267
|
}
|
124
268
|
|
269
|
+
// Check if we've reached a final state in constraint
|
270
|
+
// A state is considered final if it has no allowed tokens
|
271
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
272
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
273
|
+
if allowed.is_empty() {
|
274
|
+
return true;
|
275
|
+
}
|
276
|
+
} else {
|
277
|
+
// None means no tokens allowed - we're done
|
278
|
+
return true;
|
279
|
+
}
|
280
|
+
}
|
281
|
+
|
125
282
|
false
|
126
283
|
}
|
127
284
|
|
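The heart of the text_generation.rs change is the masking step in apply_constraints: tokens that the structured Index does not allow in the current FSM state get negative infinity added to their logits, so the sampler can never pick them. A minimal sketch of that idea with plain vectors in place of candle tensors (the function and values below are illustrative, not part of the crate):

fn mask_logits(logits: &mut [f32], allowed: &[u32]) {
    // Disallowed tokens get -inf added, so they can never win the sampling step
    let vocab_size = logits.len();
    let mut mask = vec![f32::NEG_INFINITY; vocab_size];
    for &token_id in allowed {
        if (token_id as usize) < vocab_size {
            mask[token_id as usize] = 0.0;
        }
    }
    for (logit, m) in logits.iter_mut().zip(mask) {
        *logit += m;
    }
}

fn main() {
    let mut logits = vec![1.0_f32, 2.0, 3.0, 4.0];
    mask_logits(&mut logits, &[1, 3]); // only tokens 1 and 3 stay finite
    println!("{:?}", logits);          // [-inf, 2.0, -inf, 4.0]
}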
data/ext/candle/src/ner.rs
CHANGED
@@ -39,13 +39,9 @@ impl NER {
     pub fn new(model_id: String, device: Option<Device>, tokenizer_id: Option<String>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;

-
-        let device_clone = device.clone();
-        let model_id_clone = model_id.clone();
-
-        let handle = std::thread::spawn(move || -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
+        let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, NERConfig), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
-            let repo = api.repo(Repo::new(model_id_clone, RepoType::Model));
+            let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));

             // Download model files
             let config_filename = repo.get("config.json")?;
@@ -92,7 +88,7 @@ impl NER {

             // Load model weights
             let vb = unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };

             // Load BERT model
@@ -106,10 +102,10 @@ impl NER {
             )?;

             Ok((model, TokenizerWrapper::new(tokenizer), classifier, ner_config))
-        });
+        })();

-        match handle.join() {
-            Ok(Ok((model, tokenizer, classifier, config))) => {
+        match result {
+            Ok((model, tokenizer, classifier, config)) => {
                 Ok(Self {
                     model,
                     tokenizer,
@@ -119,28 +115,20 @@ impl NER {
                     model_id,
                 })
             }
-            Ok(Err(e)) => Err(Error::new(
+            Err(e) => Err(Error::new(
                 magnus::exception::runtime_error(),
                 format!("Failed to load NER model: {}", e)
             )),
-            Err(_) => Err(Error::new(
-                magnus::exception::runtime_error(),
-                "Thread panicked while loading NER model"
-            )),
         }
     }

-    /// Extract entities from text with confidence scores
-    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
-        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
-
+    /// Common tokenization and prediction logic
+    fn tokenize_and_predict(&self, text: &str) -> Result<(tokenizers::Encoding, Vec<Vec<f32>>)> {
         // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text
+        let encoding = self.tokenizer.inner().encode(text, true)
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;

         let token_ids = encoding.get_ids();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();

         // Convert to tensors
         let input_ids = Tensor::new(token_ids, &self.device)
@@ -171,6 +159,19 @@ impl NER {
             .to_vec2()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;

+        Ok((encoding, probs_vec))
+    }
+
+    /// Extract entities from text with confidence scores
+    pub fn extract_entities(&self, text: String, confidence_threshold: Option<f64>) -> Result<RArray> {
+        let threshold = confidence_threshold.unwrap_or(0.9) as f32;
+
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
+
+        let tokens = encoding.get_tokens();
+        let offsets = encoding.get_offsets();
+
         // Extract entities with BIO decoding
         let entities = self.decode_entities(
             &text,
@@ -199,38 +200,11 @@ impl NER {

     /// Get token-level predictions with labels and confidence scores
     pub fn predict_tokens(&self, text: String) -> Result<RArray> {
-        // Tokenize the text
-        let encoding = self.tokenizer.inner().encode(text
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
+        // Use common tokenization and prediction logic
+        let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;

-        let token_ids = encoding.get_ids();
         let tokens = encoding.get_tokens();

-        // Convert to tensors
-        let input_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .unsqueeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        let attention_mask = Tensor::ones_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let token_type_ids = Tensor::zeros_like(&input_ids)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Forward pass
-        let output = self.model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let logits = self.classifier.forward(&output)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-        let probs = candle_nn::ops::softmax(&logits, 2)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
-        // Get predictions
-        let probs_vec: Vec<Vec<f32>> = probs.squeeze(0)
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?
-            .to_vec2()
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?;
-
         // Build result array
         let result = RArray::new();
         for (i, (token, probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
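The refactored tokenize_and_predict above hands each caller a per-token probability row (softmax over the label dimension); predict_tokens then turns each row into a label and confidence. A rough standalone sketch of that last step, using plain vectors and illustrative BIO label names rather than the gem's actual result-building code:

fn softmax(xs: &[f32]) -> Vec<f32> {
    // Numerically stable softmax over one token's label scores
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let labels = ["O", "B-PER", "I-PER"];
    let token_logits = vec![2.0_f32, 0.5, -1.0]; // one token's scores over the label set
    let probs = softmax(&token_logits);
    // Argmax gives the predicted label; its probability is the confidence score
    let (best, conf) = probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, p)| (labels[i], *p))
        .unwrap();
    println!("label = {best}, confidence = {conf:.2}");
}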