phrasekit 0.2.0-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,320 @@
1
+ use serde::{Deserialize, Serialize};
2
+ use std::collections::HashMap;
3
+ use std::env;
4
+ use std::fs::File;
5
+ use std::io::{BufRead, BufReader, BufWriter, Write};
6
+ use std::path::Path;
7
+ use std::process;
8
+
9
+ #[path = "../payload.rs"]
10
+ mod payload;
11
+
12
+ use payload::Payload;
13
+
14
/// Tagging configuration, parsed from the JSON file given as the second CLI argument.
///
/// The four `*_path` fields are required; `policy`, `max_spans` and `label` fall
/// back to the `default_*` functions when absent.
#[derive(Debug, Deserialize)]
struct TagConfig {
    /// Serialized daachorse automaton file, read by `tag_corpus`.
    automaton_path: String,
    /// Phrase payloads file, parsed with `payload::load_payloads`.
    payloads_path: String,
    /// Manifest JSON; this tool only reads its `separator_id`.
    manifest_path: String,
    /// Vocabulary JSON; lookups are done with lowercased tokens.
    vocab_path: String,
    /// Overlap-resolution policy. `tag_corpus` recognises "leftmost_longest"
    /// (the default) and "leftmost_first"; any other value leaves overlapping
    /// matches unresolved.
    #[serde(default = "default_policy")]
    policy: String,
    /// Upper bound on spans emitted per document (default 100).
    #[serde(default = "default_max_spans")]
    max_spans: usize,
    /// Label attached to every emitted span (default "PHRASE").
    #[serde(default = "default_label")]
    label: String,
}
27
+
28
/// Serde default for `TagConfig::policy`.
fn default_policy() -> String {
    String::from("leftmost_longest")
}

/// Serde default for `TagConfig::max_spans`.
fn default_max_spans() -> usize {
    const DEFAULT_MAX_SPANS: usize = 100;
    DEFAULT_MAX_SPANS
}

/// Serde default for `TagConfig::label`.
fn default_label() -> String {
    "PHRASE".to_owned()
}
39
+
40
/// One line of the input corpus: a pre-tokenized document.
#[derive(Debug, Deserialize)]
struct InputDocument {
    /// Caller-supplied identifier, echoed unchanged into the output.
    doc_id: String,
    /// Raw tokens; `encode_tokens` lowercases them before vocabulary lookup.
    tokens: Vec<String>,
}

/// One line of the output corpus: the input document plus its resolved spans.
#[derive(Debug, Serialize)]
struct OutputDocument {
    doc_id: String,
    tokens: Vec<String>,
    spans: Vec<Span>,
}

/// A tagged phrase occurrence, in token (not byte) coordinates.
#[derive(Debug, Serialize)]
struct Span {
    /// Index of the first token of the phrase.
    start: usize,
    /// Exclusive end token index (byte end of the match rounded up to a token).
    end: usize,
    /// `phrase_id` of the payload associated with the matched pattern.
    phrase_id: u32,
    /// Copied from `TagConfig::label`.
    label: String,
}

/// Contents of the vocabulary JSON file.
#[derive(Debug, Deserialize)]
struct Vocabulary {
    /// Token -> id map, queried with lowercased tokens.
    tokens: HashMap<String, u32>,
    /// Special-token map; this tool only reads `<UNK>` (the fallback id).
    special_tokens: HashMap<String, u32>,
}

/// Counters accumulated by `tag_corpus` and printed in its summary.
#[derive(Debug)]
struct TaggingStats {
    /// Non-blank corpus lines processed.
    documents: usize,
    /// Spans written across all documents (after truncation to `max_spans`).
    total_spans: usize,
    /// Documents that received at least one span.
    docs_with_spans: usize,
}
73
+
74
+ fn encode_tokens(tokens: &[String], vocab: &Vocabulary) -> Vec<u32> {
75
+ let unk_id = vocab.special_tokens.get("<UNK>").copied().unwrap_or(0);
76
+
77
+ tokens
78
+ .iter()
79
+ .map(|token| {
80
+ let normalized = token.to_lowercase();
81
+ vocab.tokens.get(&normalized).copied().unwrap_or(unk_id)
82
+ })
83
+ .collect()
84
+ }
85
+
86
+ fn tag_corpus(
87
+ corpus_path: &str,
88
+ config: &TagConfig,
89
+ output_path: &str,
90
+ ) -> Result<TaggingStats, Box<dyn std::error::Error>> {
91
+ println!("🏷️ PhraseKit Corpus Tagging");
92
+ println!("════════════════════════════════════════");
93
+ println!("Corpus: {}", corpus_path);
94
+ println!("Config: <config>");
95
+ println!("Output: {}", output_path);
96
+ println!();
97
+
98
+ println!("📚 Loading matcher artifacts...");
99
+
100
+ let vocab_data = std::fs::read_to_string(&config.vocab_path)?;
101
+ let vocab: Vocabulary = serde_json::from_str(&vocab_data)?;
102
+ println!(" ✓ Loaded vocabulary ({} tokens)", vocab.tokens.len());
103
+
104
+ use daachorse::DoubleArrayAhoCorasick;
105
+ let automaton_bytes = std::fs::read(&config.automaton_path)?;
106
+ let (automaton, _): (DoubleArrayAhoCorasick<u32>, _) = unsafe {
107
+ DoubleArrayAhoCorasick::deserialize_unchecked(&automaton_bytes)
108
+ };
109
+ println!(" ✓ Loaded automaton");
110
+
111
+ let payloads_file = File::open(&config.payloads_path)?;
112
+ let payloads_reader = BufReader::new(payloads_file);
113
+ let payloads = payload::load_payloads(payloads_reader)?;
114
+ println!(" ✓ Loaded {} phrase payloads", payloads.len());
115
+
116
+ #[derive(Debug, Deserialize)]
117
+ struct Manifest {
118
+ separator_id: u32,
119
+ }
120
+
121
+ let manifest_data = std::fs::read_to_string(&config.manifest_path)?;
122
+ let manifest: Manifest = serde_json::from_str(&manifest_data)?;
123
+ println!(" ✓ Loaded manifest");
124
+ println!();
125
+
126
+ println!("🔍 Tagging documents...");
127
+
128
+ let corpus_file = File::open(corpus_path)?;
129
+ let corpus_reader = BufReader::new(corpus_file);
130
+
131
+ let output_file = File::create(output_path)?;
132
+ let mut output_writer = BufWriter::new(output_file);
133
+
134
+ let mut stats = TaggingStats {
135
+ documents: 0,
136
+ total_spans: 0,
137
+ docs_with_spans: 0,
138
+ };
139
+
140
+ for line in corpus_reader.lines() {
141
+ let line = line?;
142
+ if line.trim().is_empty() {
143
+ continue;
144
+ }
145
+
146
+ let doc: InputDocument = serde_json::from_str(&line)?;
147
+
148
+ let token_ids = encode_tokens(&doc.tokens, &vocab);
149
+
150
+ let separator = manifest.separator_id;
151
+ let mut bytes = Vec::with_capacity(token_ids.len() * 5);
152
+ for &token_id in &token_ids {
153
+ bytes.extend_from_slice(&token_id.to_le_bytes());
154
+ bytes.extend_from_slice(&separator.to_le_bytes());
155
+ }
156
+
157
+ #[derive(Debug, Clone, Copy)]
158
+ struct Match {
159
+ start: usize,
160
+ end: usize,
161
+ phrase_id: u32,
162
+ }
163
+
164
+ let mut matches: Vec<Match> = automaton
165
+ .find_overlapping_iter(&bytes)
166
+ .filter_map(|m| {
167
+ let pattern_id = m.value() as usize;
168
+ let start_token = m.start() / 8;
169
+ let end_token = (m.end() + 7) / 8;
170
+
171
+ payloads.get(pattern_id).map(|payload| Match {
172
+ start: start_token,
173
+ end: end_token,
174
+ phrase_id: payload.phrase_id,
175
+ })
176
+ })
177
+ .collect();
178
+
179
+ if config.policy == "leftmost_longest" {
180
+ matches.sort_by_key(|m| (m.start, std::cmp::Reverse(m.end)));
181
+
182
+ let mut resolved = Vec::new();
183
+ let mut covered_end = 0;
184
+
185
+ for m in matches {
186
+ if m.start >= covered_end {
187
+ resolved.push(m);
188
+ covered_end = m.end;
189
+ }
190
+ }
191
+
192
+ matches = resolved;
193
+ } else if config.policy == "leftmost_first" {
194
+ matches.sort_by_key(|m| m.start);
195
+
196
+ let mut resolved = Vec::new();
197
+ let mut covered_end = 0;
198
+
199
+ for m in matches {
200
+ if m.start >= covered_end {
201
+ resolved.push(m);
202
+ covered_end = m.end;
203
+ }
204
+ }
205
+
206
+ matches = resolved;
207
+ }
208
+
209
+ if matches.len() > config.max_spans {
210
+ matches.truncate(config.max_spans);
211
+ }
212
+
213
+ let spans: Vec<Span> = matches
214
+ .into_iter()
215
+ .map(|m| Span {
216
+ start: m.start,
217
+ end: m.end,
218
+ phrase_id: m.phrase_id,
219
+ label: config.label.clone(),
220
+ })
221
+ .collect();
222
+
223
+ stats.total_spans += spans.len();
224
+ if !spans.is_empty() {
225
+ stats.docs_with_spans += 1;
226
+ }
227
+
228
+ let output_doc = OutputDocument {
229
+ doc_id: doc.doc_id,
230
+ tokens: doc.tokens,
231
+ spans,
232
+ };
233
+
234
+ serde_json::to_writer(&mut output_writer, &output_doc)?;
235
+ writeln!(&mut output_writer)?;
236
+
237
+ stats.documents += 1;
238
+
239
+ if stats.documents % 1000 == 0 {
240
+ print!("\r Processed {} documents...", stats.documents);
241
+ std::io::stdout().flush()?;
242
+ }
243
+ }
244
+
245
+ if stats.documents % 1000 != 0 {
246
+ println!("\r ✓ Processed {} documents", stats.documents);
247
+ } else {
248
+ println!();
249
+ println!(" ✓ Processed {} documents", stats.documents);
250
+ }
251
+
252
+ output_writer.flush()?;
253
+
254
+ println!();
255
+ println!("✅ Tagging complete!");
256
+ println!();
257
+ println!("📈 Statistics:");
258
+ println!(" Documents: {}", stats.documents);
259
+ println!(" Total spans: {}", stats.total_spans);
260
+ println!(" Documents with spans: {}", stats.docs_with_spans);
261
+ println!(
262
+ " Avg spans per document: {:.2}",
263
+ if stats.documents > 0 {
264
+ stats.total_spans as f64 / stats.documents as f64
265
+ } else {
266
+ 0.0
267
+ }
268
+ );
269
+
270
+ Ok(stats)
271
+ }
272
+
273
+ fn main() {
274
+ let args: Vec<String> = env::args().collect();
275
+
276
+ if args.len() != 4 {
277
+ eprintln!("Usage: {} <corpus.jsonl> <config.json> <output.jsonl>", args[0]);
278
+ eprintln!();
279
+ eprintln!("Arguments:");
280
+ eprintln!(" corpus.jsonl - Input corpus with pre-tokenized documents");
281
+ eprintln!(" config.json - Tagging configuration");
282
+ eprintln!(" output.jsonl - Output path for tagged corpus");
283
+ process::exit(1);
284
+ }
285
+
286
+ let corpus_path = &args[1];
287
+ let config_path = &args[2];
288
+ let output_path = &args[3];
289
+
290
+ if !Path::new(corpus_path).exists() {
291
+ eprintln!("Error: Corpus file not found: {}", corpus_path);
292
+ process::exit(1);
293
+ }
294
+
295
+ if !Path::new(config_path).exists() {
296
+ eprintln!("Error: Config file not found: {}", config_path);
297
+ process::exit(1);
298
+ }
299
+
300
+ let config_data = match std::fs::read_to_string(config_path) {
301
+ Ok(data) => data,
302
+ Err(e) => {
303
+ eprintln!("Error: Failed to read config file: {}", e);
304
+ process::exit(1);
305
+ }
306
+ };
307
+
308
+ let config: TagConfig = match serde_json::from_str(&config_data) {
309
+ Ok(cfg) => cfg,
310
+ Err(e) => {
311
+ eprintln!("Error: Failed to parse config: {}", e);
312
+ process::exit(1);
313
+ }
314
+ };
315
+
316
+ if let Err(e) = tag_corpus(corpus_path, &config, output_path) {
317
+ eprintln!("Error: Tagging failed: {}", e);
318
+ process::exit(1);
319
+ }
320
+ }
@@ -0,0 +1,104 @@
1
+ mod manifest;
2
+ mod matcher;
3
+ mod payload;
4
+ mod policy;
5
+
6
+ use magnus::{define_module, function, method, prelude::*, Error, RArray, RHash, Ruby};
7
+ use matcher::{Matcher as RustMatcher, Stats};
8
+ use parking_lot::RwLock;
9
+ use policy::MatchPolicy;
10
+ use std::sync::Arc;
11
+
12
/// Swappable slot holding the currently loaded matcher (`None` until `load` succeeds).
///
/// The inner `Arc` makes it possible to clone the current matcher out of the lock.
/// NOTE(review): the outer `Arc` is never cloned in this file, so a plain
/// `RwLock<Option<Arc<RustMatcher>>>` would appear to suffice — confirm no other
/// code shares it.
type SharedMatcher = Arc<RwLock<Option<Arc<RustMatcher>>>>;
13
+
14
/// Native object exposed to Ruby as `PhraseKit::NativeMatcher`.
///
/// The matcher is held behind a read/write lock so that `load` can install or
/// replace it after the object has been constructed.
#[magnus::wrap(class = "PhraseKit::NativeMatcher", free_immediately, size)]
struct MatcherWrapper {
    /// `None` until `load` succeeds.
    matcher: SharedMatcher,
}
18
+
19
+ impl MatcherWrapper {
20
+ fn new() -> Self {
21
+ Self {
22
+ matcher: Arc::new(RwLock::new(None)),
23
+ }
24
+ }
25
+
26
+ fn load(&self, automaton_path: String, payloads_path: String, manifest_path: String) -> Result<(), Error> {
27
+ let matcher = RustMatcher::load(&automaton_path, &payloads_path, &manifest_path)
28
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load matcher: {}", e)))?;
29
+
30
+ let mut guard = self.matcher.write();
31
+ *guard = Some(Arc::new(matcher));
32
+
33
+ Ok(())
34
+ }
35
+
36
+ fn match_tokens(&self, token_ids: Vec<u32>, policy: String, max: usize) -> Result<RArray, Error> {
37
+ let guard = self.matcher.read();
38
+ let matcher = guard
39
+ .as_ref()
40
+ .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "Matcher not loaded"))?;
41
+
42
+ let match_policy = MatchPolicy::from_str(&policy)
43
+ .ok_or_else(|| Error::new(magnus::exception::arg_error(), format!("Invalid policy: {}", policy)))?;
44
+
45
+ let matches = matcher.match_tokens(&token_ids, match_policy, max);
46
+
47
+ let result = RArray::new();
48
+ for m in matches {
49
+ let hash = RHash::new();
50
+ hash.aset("start", m.start)?;
51
+ hash.aset("end", m.end)?;
52
+ hash.aset("phrase_id", m.payload.phrase_id)?;
53
+ hash.aset("salience", m.payload.salience)?;
54
+ hash.aset("count", m.payload.count)?;
55
+ hash.aset("n", m.payload.n)?;
56
+ result.push(hash)?;
57
+ }
58
+
59
+ Ok(result)
60
+ }
61
+
62
+ fn stats(&self) -> Result<RHash, Error> {
63
+ let guard = self.matcher.read();
64
+ let matcher = guard
65
+ .as_ref()
66
+ .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "Matcher not loaded"))?;
67
+
68
+ let stats = Stats::from_matcher(matcher);
69
+ let hash = RHash::new();
70
+
71
+ hash.aset("version", stats.version)?;
72
+ hash.aset("loaded_at", stats.loaded_at.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64)?;
73
+ hash.aset("num_patterns", stats.num_patterns)?;
74
+ hash.aset("heap_mb", stats.heap_mb)?;
75
+ hash.aset("hits_total", stats.hits_total)?;
76
+ hash.aset("p50_us", stats.p50_us)?;
77
+ hash.aset("p95_us", stats.p95_us)?;
78
+ hash.aset("p99_us", stats.p99_us)?;
79
+
80
+ Ok(hash)
81
+ }
82
+
83
+ fn healthcheck(&self) -> Result<bool, Error> {
84
+ let guard = self.matcher.read();
85
+ guard
86
+ .as_ref()
87
+ .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "Matcher not loaded"))?;
88
+ Ok(true)
89
+ }
90
+ }
91
+
92
+ #[magnus::init]
93
+ fn init(ruby: &Ruby) -> Result<(), Error> {
94
+ let module = define_module("PhraseKit")?;
95
+ let class = module.define_class("NativeMatcher", ruby.class_object())?;
96
+
97
+ class.define_singleton_method("new", function!(MatcherWrapper::new, 0))?;
98
+ class.define_method("load", method!(MatcherWrapper::load, 3))?;
99
+ class.define_method("match_tokens", method!(MatcherWrapper::match_tokens, 3))?;
100
+ class.define_method("stats", method!(MatcherWrapper::stats, 0))?;
101
+ class.define_method("healthcheck", method!(MatcherWrapper::healthcheck, 0))?;
102
+
103
+ Ok(())
104
+ }
@@ -0,0 +1,88 @@
1
+ use serde::{Deserialize, Serialize};
2
+ use std::fs::File;
3
+ use std::io::BufReader;
4
+ use std::path::Path;
5
+ use thiserror::Error;
6
+
7
/// Metadata manifest accompanying a set of matcher artifacts.
///
/// Read from JSON by [`Manifest::load`]; `min_count` and `salience_threshold` may be
/// omitted from the JSON, in which case they deserialize as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
    /// Artifact version string (the tests use "pk-2025-09-25-01").
    pub version: String,
    /// Tokenizer identifier; compared by `validate_compatible`.
    pub tokenizer: String,
    /// Pattern count — presumably the number of patterns in the automaton; TODO confirm.
    pub num_patterns: usize,
    /// Optional count threshold. NOTE(review): presumably a build-time frequency
    /// cutoff; not read anywhere in this file.
    pub min_count: Option<u32>,
    /// Optional salience threshold. NOTE(review): presumably a build-time cutoff;
    /// not read anywhere in this file.
    pub salience_threshold: Option<f32>,
    /// Build timestamp, kept as an unparsed string (RFC 3339-style in the tests).
    pub built_at: String,
    /// Separator token id; `load` rejects 0 and `validate_compatible` compares it.
    pub separator_id: u32,
}
17
+
18
+ #[derive(Error, Debug)]
19
+ pub enum ManifestError {
20
+ #[error("IO error: {0}")]
21
+ Io(#[from] std::io::Error),
22
+
23
+ #[error("JSON parse error: {0}")]
24
+ Json(#[from] serde_json::Error),
25
+
26
+ #[error("Invalid manifest: {0}")]
27
+ #[allow(dead_code)]
28
+ Invalid(String),
29
+ }
30
+
31
+ impl Manifest {
32
+ pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ManifestError> {
33
+ let file = File::open(path)?;
34
+ let reader = BufReader::new(file);
35
+ let manifest: Manifest = serde_json::from_reader(reader)?;
36
+
37
+ if manifest.separator_id == 0 {
38
+ return Err(ManifestError::Invalid(
39
+ "separator_id must be non-zero".to_string(),
40
+ ));
41
+ }
42
+
43
+ Ok(manifest)
44
+ }
45
+
46
+ #[allow(dead_code)]
47
+ pub fn validate_compatible(&self, other: &Manifest) -> Result<(), ManifestError> {
48
+ if self.tokenizer != other.tokenizer {
49
+ return Err(ManifestError::Invalid(format!(
50
+ "Tokenizer mismatch: expected {}, got {}",
51
+ self.tokenizer, other.tokenizer
52
+ )));
53
+ }
54
+
55
+ if self.separator_id != other.separator_id {
56
+ return Err(ManifestError::Invalid(format!(
57
+ "Separator ID mismatch: expected {}, got {}",
58
+ self.separator_id, other.separator_id
59
+ )));
60
+ }
61
+
62
+ Ok(())
63
+ }
64
+ }
65
+
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative manifest shared by several tests.
    const SAMPLE_JSON: &str = r#"{
        "version": "pk-2025-09-25-01",
        "tokenizer": "scientist-v1",
        "num_patterns": 1287345,
        "min_count": 20,
        "salience_threshold": 1.0,
        "built_at": "2025-09-25T18:44:00Z",
        "separator_id": 4294967294
    }"#;

    fn sample() -> Manifest {
        serde_json::from_str(SAMPLE_JSON).expect("sample manifest must parse")
    }

    /// Writes `contents` to a uniquely named temp file and returns its path.
    fn write_temp(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "phrasekit-manifest-{}-{}.json",
            std::process::id(),
            name
        ));
        std::fs::write(&path, contents).expect("temp manifest must be writable");
        path
    }

    #[test]
    fn test_manifest_deserialize() {
        let manifest = sample();
        assert_eq!(manifest.version, "pk-2025-09-25-01");
        assert_eq!(manifest.tokenizer, "scientist-v1");
        assert_eq!(manifest.num_patterns, 1287345);
        assert_eq!(manifest.min_count, Some(20));
        assert_eq!(manifest.built_at, "2025-09-25T18:44:00Z");
        assert_eq!(manifest.separator_id, 4294967294);
    }

    #[test]
    fn test_optional_fields_may_be_absent() {
        let json = r#"{
            "version": "v",
            "tokenizer": "t",
            "num_patterns": 0,
            "built_at": "now",
            "separator_id": 7
        }"#;
        let manifest: Manifest = serde_json::from_str(json).unwrap();
        assert!(manifest.min_count.is_none());
        assert!(manifest.salience_threshold.is_none());
    }

    #[test]
    fn test_load_round_trip() {
        let path = write_temp("ok", SAMPLE_JSON);
        let loaded = Manifest::load(&path);
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded.unwrap().separator_id, 4294967294);
    }

    #[test]
    fn test_load_rejects_zero_separator() {
        let json = SAMPLE_JSON.replace("4294967294", "0");
        let path = write_temp("zero-sep", &json);
        let result = Manifest::load(&path);
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(ManifestError::Invalid(_))));
    }

    #[test]
    fn test_validate_compatible() {
        let base = sample();
        assert!(base.validate_compatible(&base.clone()).is_ok());

        let mut other_tokenizer = base.clone();
        other_tokenizer.tokenizer = "other-v2".to_string();
        let err = base.validate_compatible(&other_tokenizer).unwrap_err();
        assert!(err.to_string().contains("Tokenizer mismatch"));

        let mut other_separator = base.clone();
        other_separator.separator_id = 1;
        let err = base.validate_compatible(&other_separator).unwrap_err();
        assert!(err.to_string().contains("Separator ID mismatch"));
    }
}