spellkit 0.1.0.pre.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,255 @@
1
+ mod symspell;
2
+ mod guards;
3
+
4
+ use magnus::{class, define_module, function, method, prelude::*, Error, RArray, RHash, Ruby, Value, TryConvert};
5
+ use std::sync::{Arc, RwLock};
6
+ use symspell::SymSpell;
7
+ use guards::Guards;
8
+
9
+ use std::time::{SystemTime, UNIX_EPOCH};
10
+
11
+ #[derive(Clone)]
12
+ #[magnus::wrap(class = "SpellKit::Checker", free_immediately, size)]
13
+ struct Checker {
14
+ state: Arc<RwLock<CheckerState>>,
15
+ }
16
+
17
+ struct CheckerState {
18
+ symspell: Option<SymSpell>,
19
+ guards: Guards,
20
+ loaded: bool,
21
+ frequency_threshold: f64,
22
+ loaded_at: Option<u64>,
23
+ dictionary_size: usize,
24
+ edit_distance: usize,
25
+ }
26
+
27
+ impl CheckerState {
28
+ fn new() -> Self {
29
+ Self {
30
+ symspell: None,
31
+ guards: Guards::new(),
32
+ loaded: false,
33
+ frequency_threshold: 10.0,
34
+ loaded_at: None,
35
+ dictionary_size: 0,
36
+ edit_distance: 1,
37
+ }
38
+ }
39
+ }
40
+
41
+ impl Checker {
42
+ fn new() -> Self {
43
+ Self {
44
+ state: Arc::new(RwLock::new(CheckerState::new())),
45
+ }
46
+ }
47
+
48
+ fn load_full(&self, config: RHash) -> Result<(), Error> {
49
+ let ruby = Ruby::get().unwrap();
50
+
51
+ // Required: dictionary path
52
+ let dictionary_path: String = TryConvert::try_convert(
53
+ config.fetch::<_, Value>("dictionary_path")
54
+ .map_err(|_| Error::new(ruby.exception_arg_error(), "dictionary_path is required"))?
55
+ )?;
56
+
57
+ let content = std::fs::read_to_string(&dictionary_path)
58
+ .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to read dictionary file: {}", e)))?;
59
+
60
+ // Optional: edit distance
61
+ let edit_dist: usize = config.get("edit_distance")
62
+ .and_then(|v: Value| TryConvert::try_convert(v).ok())
63
+ .unwrap_or(1);
64
+
65
+ if edit_dist > 2 {
66
+ return Err(Error::new(ruby.exception_arg_error(), "edit_distance must be 1 or 2"));
67
+ }
68
+
69
+ let mut words = Vec::new();
70
+ for line in content.lines() {
71
+ let parts: Vec<&str> = line.split_whitespace().collect();
72
+ if parts.len() == 2 {
73
+ if let Ok(freq) = parts[1].parse::<u64>() {
74
+ words.push((parts[0].to_string(), freq));
75
+ }
76
+ }
77
+ }
78
+
79
+ let dictionary_size = words.len();
80
+ let mut symspell = SymSpell::new(edit_dist);
81
+ symspell.load_dictionary(words);
82
+
83
+ let mut guards = Guards::new();
84
+
85
+ // Load optional protected terms file
86
+ if let Some(protected_path) = config.get("protected_path") {
87
+ let path: String = TryConvert::try_convert(protected_path)?;
88
+ if let Ok(content) = std::fs::read_to_string(path) {
89
+ guards.load_protected(&content);
90
+ }
91
+ }
92
+
93
+ // Load optional protected patterns
94
+ if let Some(patterns_value) = config.get("protected_patterns") {
95
+ let patterns: RArray = TryConvert::try_convert(patterns_value)?;
96
+ for pattern_value in patterns.into_iter() {
97
+ let pattern: String = TryConvert::try_convert(pattern_value)?;
98
+ guards.add_pattern(&pattern)
99
+ .map_err(|e| Error::new(ruby.exception_arg_error(), e))?;
100
+ }
101
+ }
102
+
103
+ // Optional frequency threshold
104
+ let frequency_threshold: f64 = config.get("frequency_threshold")
105
+ .and_then(|v: Value| TryConvert::try_convert(v).ok())
106
+ .unwrap_or(10.0);
107
+
108
+ let loaded_at = SystemTime::now()
109
+ .duration_since(UNIX_EPOCH)
110
+ .ok()
111
+ .map(|d| d.as_secs());
112
+
113
+ let mut state = self.state.write().unwrap();
114
+ state.symspell = Some(symspell);
115
+ state.guards = guards;
116
+ state.frequency_threshold = frequency_threshold;
117
+ state.loaded = true;
118
+ state.loaded_at = loaded_at;
119
+ state.dictionary_size = dictionary_size;
120
+ state.edit_distance = edit_dist;
121
+
122
+ Ok(())
123
+ }
124
+
125
+ fn suggest(&self, word: String, max: Option<usize>) -> Result<RArray, Error> {
126
+ let ruby = Ruby::get().unwrap();
127
+ let max_suggestions = max.unwrap_or(5);
128
+ let state = self.state.read().unwrap();
129
+
130
+ if !state.loaded {
131
+ return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded. Call load! first"));
132
+ }
133
+
134
+ if let Some(ref symspell) = state.symspell {
135
+ let suggestions = symspell.suggest(&word, max_suggestions);
136
+ let result = RArray::new();
137
+
138
+ for suggestion in suggestions {
139
+ let hash = RHash::new();
140
+ hash.aset("term", suggestion.term)?;
141
+ hash.aset("distance", suggestion.distance)?;
142
+ hash.aset("freq", suggestion.frequency)?;
143
+ result.push(hash)?;
144
+ }
145
+
146
+ Ok(result)
147
+ } else {
148
+ Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"))
149
+ }
150
+ }
151
+
152
+ fn correct_if_unknown(&self, word: String, use_guard: Option<bool>) -> Result<String, Error> {
153
+ let ruby = Ruby::get().unwrap();
154
+ let state = self.state.read().unwrap();
155
+
156
+ if !state.loaded {
157
+ return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded. Call load! first"));
158
+ }
159
+
160
+ // Check if word is protected
161
+ if use_guard.unwrap_or(false) {
162
+ let normalized = SymSpell::normalize_word(&word);
163
+ if state.guards.is_protected_normalized(&word, &normalized) {
164
+ return Ok(word);
165
+ }
166
+ }
167
+
168
+ if let Some(ref symspell) = state.symspell {
169
+ let suggestions = symspell.suggest(&word, 5);
170
+
171
+ // If exact match exists, return original
172
+ if !suggestions.is_empty() && suggestions[0].distance == 0 {
173
+ return Ok(word);
174
+ }
175
+
176
+ // Find best correction with frequency gating
177
+ for suggestion in &suggestions {
178
+ if suggestion.distance <= 1 {
179
+ // Check frequency threshold - correction should be significantly more common
180
+ // Since we don't have the original word's frequency, we'll just take any ED=1 match
181
+ // In a full implementation, we'd check if suggestion.frequency >= threshold * original_freq
182
+ return Ok(suggestion.term.clone());
183
+ }
184
+ }
185
+
186
+ Ok(word)
187
+ } else {
188
+ Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"))
189
+ }
190
+ }
191
+
192
+ fn correct_tokens(&self, tokens: RArray, use_guard: Option<bool>) -> Result<RArray, Error> {
193
+ let result = RArray::new();
194
+ let guard = use_guard.unwrap_or(false);
195
+
196
+ for token in tokens.into_iter() {
197
+ let word: String = TryConvert::try_convert(token)?;
198
+ let corrected = self.correct_if_unknown(word, Some(guard))?;
199
+ result.push(corrected)?;
200
+ }
201
+
202
+ Ok(result)
203
+ }
204
+
205
+ fn stats(&self) -> Result<RHash, Error> {
206
+ let state = self.state.read().unwrap();
207
+ let stats = RHash::new();
208
+
209
+ if !state.loaded {
210
+ stats.aset("loaded", false)?;
211
+ return Ok(stats);
212
+ }
213
+
214
+ stats.aset("loaded", true)?;
215
+ stats.aset("dictionary_size", state.dictionary_size)?;
216
+ stats.aset("edit_distance", state.edit_distance)?;
217
+
218
+ if let Some(loaded_at) = state.loaded_at {
219
+ stats.aset("loaded_at", loaded_at)?;
220
+ }
221
+
222
+ Ok(stats)
223
+ }
224
+
225
+ fn healthcheck(&self) -> Result<(), Error> {
226
+ let ruby = Ruby::get().unwrap();
227
+ let state = self.state.read().unwrap();
228
+
229
+ if !state.loaded {
230
+ return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded"));
231
+ }
232
+
233
+ if state.symspell.is_none() {
234
+ return Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"));
235
+ }
236
+
237
+ Ok(())
238
+ }
239
+ }
240
+
241
+ #[magnus::init]
242
+ fn init(_ruby: &Ruby) -> Result<(), Error> {
243
+ let module = define_module("SpellKit")?;
244
+ let checker_class = module.define_class("Checker", class::object())?;
245
+
246
+ checker_class.define_singleton_method("new", function!(Checker::new, 0))?;
247
+ checker_class.define_method("load!", method!(Checker::load_full, 1))?;
248
+ checker_class.define_method("suggest", method!(Checker::suggest, 2))?;
249
+ checker_class.define_method("correct_if_unknown", method!(Checker::correct_if_unknown, 2))?;
250
+ checker_class.define_method("correct_tokens", method!(Checker::correct_tokens, 2))?;
251
+ checker_class.define_method("stats", method!(Checker::stats, 0))?;
252
+ checker_class.define_method("healthcheck", method!(Checker::healthcheck, 0))?;
253
+
254
+ Ok(())
255
+ }
@@ -0,0 +1,264 @@
1
+ use hashbrown::{HashMap, HashSet};
2
+ use std::cmp::Ordering;
3
+ use unicode_normalization::UnicodeNormalization;
4
+
5
/// A candidate correction for a looked-up word.
///
/// Equality is field-wise (derived). Ordering is "best first": smaller edit
/// distance, then higher frequency, then alphabetical term as a tie-breaker,
/// so sorting a list of suggestions ascending ranks them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Suggestion {
    /// The suggested dictionary term.
    pub term: String,
    /// Edit distance between the looked-up word and `term`.
    pub distance: usize,
    /// Dictionary frequency of `term`.
    pub frequency: u64,
}

impl Suggestion {
    /// Bundles a term with its edit distance and dictionary frequency.
    pub fn new(term: String, distance: usize, frequency: u64) -> Self {
        Self {
            term,
            distance,
            frequency,
        }
    }
}

impl Ord for Suggestion {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .cmp(&other.distance)
            // Operands are swapped on purpose: higher frequency sorts first.
            .then_with(|| other.frequency.cmp(&self.frequency))
            .then_with(|| self.term.cmp(&other.term))
    }
}

impl PartialOrd for Suggestion {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
44
+
45
+ pub struct SymSpell {
46
+ deletes: HashMap<String, HashSet<String>>,
47
+ words: HashMap<String, u64>,
48
+ max_edit_distance: usize,
49
+ }
50
+
51
+ impl SymSpell {
52
+ pub fn new(max_edit_distance: usize) -> Self {
53
+ Self {
54
+ deletes: HashMap::new(),
55
+ words: HashMap::new(),
56
+ max_edit_distance,
57
+ }
58
+ }
59
+
60
+ pub fn normalize_word(word: &str) -> String {
61
+ word.nfkd()
62
+ .filter(|c| !c.is_control() && !c.is_whitespace())
63
+ .collect::<String>()
64
+ .to_lowercase()
65
+ }
66
+
67
+ pub fn load_dictionary(&mut self, words: Vec<(String, u64)>) {
68
+ for (word, freq) in words {
69
+ let normalized = Self::normalize_word(&word);
70
+ self.add_word(&normalized, freq);
71
+ }
72
+ }
73
+
74
+ pub fn add_word(&mut self, word: &str, frequency: u64) {
75
+ self.words.insert(word.to_string(), frequency);
76
+
77
+ let deletes = self.get_deletes(word, self.max_edit_distance);
78
+ for delete in deletes {
79
+ self.deletes
80
+ .entry(delete)
81
+ .or_insert_with(HashSet::new)
82
+ .insert(word.to_string());
83
+ }
84
+ }
85
+
86
+ fn get_deletes(&self, word: &str, edit_distance: usize) -> HashSet<String> {
87
+ let mut deletes = HashSet::new();
88
+ if edit_distance == 0 {
89
+ return deletes;
90
+ }
91
+
92
+ let mut queue = vec![word.to_string()];
93
+ let mut processed = HashSet::new();
94
+
95
+ for _ in 0..edit_distance {
96
+ let mut temp_queue = Vec::new();
97
+ for item in queue {
98
+ if processed.contains(&item) {
99
+ continue;
100
+ }
101
+ processed.insert(item.clone());
102
+
103
+ for delete in self.generate_deletes(&item) {
104
+ if delete.len() >= 1 {
105
+ deletes.insert(delete.clone());
106
+ temp_queue.push(delete);
107
+ }
108
+ }
109
+ }
110
+ queue = temp_queue;
111
+ }
112
+
113
+ deletes
114
+ }
115
+
116
+ fn generate_deletes(&self, word: &str) -> Vec<String> {
117
+ let chars: Vec<char> = word.chars().collect();
118
+ let mut deletes = Vec::new();
119
+
120
+ for i in 0..chars.len() {
121
+ let mut new_word = String::new();
122
+ for (j, &ch) in chars.iter().enumerate() {
123
+ if j != i {
124
+ new_word.push(ch);
125
+ }
126
+ }
127
+ deletes.push(new_word);
128
+ }
129
+
130
+ deletes
131
+ }
132
+
133
+ pub fn suggest(&self, word: &str, max_suggestions: usize) -> Vec<Suggestion> {
134
+ let normalized = Self::normalize_word(word);
135
+ let mut suggestions = Vec::new();
136
+ let mut seen = HashSet::new();
137
+
138
+ if let Some(&freq) = self.words.get(&normalized) {
139
+ suggestions.push(Suggestion::new(normalized.clone(), 0, freq));
140
+ seen.insert(normalized.clone());
141
+ }
142
+
143
+ let input_deletes = self.get_deletes(&normalized, self.max_edit_distance);
144
+
145
+ for delete in &input_deletes {
146
+ // Check if this delete is itself a dictionary word (important for finding words shorter than input)
147
+ if !seen.contains(delete) {
148
+ if let Some(&freq) = self.words.get(delete) {
149
+ let distance = self.edit_distance(&normalized, delete);
150
+ if distance <= self.max_edit_distance {
151
+ suggestions.push(Suggestion::new(delete.clone(), distance, freq));
152
+ seen.insert(delete.clone());
153
+ }
154
+ }
155
+ }
156
+
157
+ // Check the deletes map for candidates
158
+ if let Some(candidates) = self.deletes.get(delete) {
159
+ for candidate in candidates {
160
+ if seen.contains(candidate) {
161
+ continue;
162
+ }
163
+
164
+ let distance = self.edit_distance(&normalized, candidate);
165
+ if distance <= self.max_edit_distance {
166
+ if let Some(&freq) = self.words.get(candidate) {
167
+ suggestions.push(Suggestion::new(candidate.clone(), distance, freq));
168
+ seen.insert(candidate.clone());
169
+ }
170
+ }
171
+ }
172
+ }
173
+ }
174
+
175
+ if let Some(candidates) = self.deletes.get(&normalized) {
176
+ for candidate in candidates {
177
+ if seen.contains(candidate) {
178
+ continue;
179
+ }
180
+
181
+ let distance = self.edit_distance(&normalized, candidate);
182
+ if distance <= self.max_edit_distance {
183
+ if let Some(&freq) = self.words.get(candidate) {
184
+ suggestions.push(Suggestion::new(candidate.clone(), distance, freq));
185
+ seen.insert(candidate.clone());
186
+ }
187
+ }
188
+ }
189
+ }
190
+
191
+ suggestions.sort();
192
+ suggestions.truncate(max_suggestions);
193
+ suggestions
194
+ }
195
+
196
+ fn edit_distance(&self, s1: &str, s2: &str) -> usize {
197
+ let len1 = s1.chars().count();
198
+ let len2 = s2.chars().count();
199
+
200
+ if len1 == 0 {
201
+ return len2;
202
+ }
203
+ if len2 == 0 {
204
+ return len1;
205
+ }
206
+
207
+ let s1_chars: Vec<char> = s1.chars().collect();
208
+ let s2_chars: Vec<char> = s2.chars().collect();
209
+
210
+ let mut prev_row: Vec<usize> = (0..=len2).collect();
211
+ let mut curr_row = vec![0; len2 + 1];
212
+
213
+ for i in 1..=len1 {
214
+ curr_row[0] = i;
215
+
216
+ for j in 1..=len2 {
217
+ let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
218
+ 0
219
+ } else {
220
+ 1
221
+ };
222
+
223
+ curr_row[j] = std::cmp::min(
224
+ std::cmp::min(
225
+ prev_row[j] + 1, // deletion
226
+ curr_row[j - 1] + 1 // insertion
227
+ ),
228
+ prev_row[j - 1] + cost // substitution
229
+ );
230
+ }
231
+
232
+ std::mem::swap(&mut prev_row, &mut curr_row);
233
+ }
234
+
235
+ prev_row[len2]
236
+ }
237
+ }
238
+
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Distances from "test" to a few near neighbours.
    #[test]
    fn test_edit_distance() {
        let index = SymSpell::new(2);
        let expectations = [("test", 0), ("tests", 1), ("tast", 1), ("toast", 2)];

        for &(other, expected) in expectations.iter() {
            assert_eq!(index.edit_distance("test", other), expected);
        }
    }

    /// "helo" is one edit from all three words; the most frequent one wins.
    #[test]
    fn test_suggestions() {
        let mut index = SymSpell::new(2);
        for &(term, freq) in [("hello", 1000), ("hell", 500), ("help", 750)].iter() {
            index.add_word(term, freq);
        }

        let ranked = index.suggest("helo", 3);
        assert!(!ranked.is_empty());

        let best = &ranked[0];
        assert_eq!(best.term, "hello");
        assert_eq!(best.distance, 1);
    }
}