spellkit 0.1.0.pre.1 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +259 -33
- data/ext/spellkit/Cargo.lock +0 -57
- data/ext/spellkit/Cargo.toml +0 -2
- data/ext/spellkit/src/guards.rs +21 -3
- data/ext/spellkit/src/lib.rs +213 -75
- data/ext/spellkit/src/symspell.rs +115 -30
- data/ext/spellkit/target/debug/build/rb-sys-ead65721880de65e/out/bindings-0.9.117-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/ext/spellkit/target/debug/incremental/spellkit-07yduakb6espe/s-hbic3f250f-1cel1lt.lock +0 -0
- data/ext/spellkit/target/debug/incremental/spellkit-1d3zzknqc98bj/s-hbic3f250l-011iykk.lock +0 -0
- data/ext/spellkit/target/debug/incremental/spellkit-1pt6om2w642b5/s-hbihepi6zy-1r3p88g.lock +0 -0
- data/ext/spellkit/target/release/build/clang-sys-523e86284ef4dd76/out/common.rs +355 -0
- data/ext/spellkit/target/release/build/clang-sys-523e86284ef4dd76/out/dynamic.rs +276 -0
- data/ext/spellkit/target/release/build/clang-sys-523e86284ef4dd76/out/macros.rs +49 -0
- data/ext/spellkit/target/release/build/rb-sys-7d03ffe964952311/out/bindings-0.9.117-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/spellkit/version.rb +1 -1
- data/lib/spellkit.rb +176 -31
- metadata +97 -6
- data/LICENSE +0 -21
data/ext/spellkit/src/lib.rs
CHANGED
|
@@ -22,6 +22,10 @@ struct CheckerState {
|
|
|
22
22
|
loaded_at: Option<u64>,
|
|
23
23
|
dictionary_size: usize,
|
|
24
24
|
edit_distance: usize,
|
|
25
|
+
skipped_malformed: usize,
|
|
26
|
+
skipped_multiword: usize,
|
|
27
|
+
skipped_invalid_freq: usize,
|
|
28
|
+
skipped_duplicates: usize,
|
|
25
29
|
}
|
|
26
30
|
|
|
27
31
|
impl CheckerState {
|
|
@@ -34,10 +38,60 @@ impl CheckerState {
|
|
|
34
38
|
loaded_at: None,
|
|
35
39
|
dictionary_size: 0,
|
|
36
40
|
edit_distance: 1,
|
|
41
|
+
skipped_malformed: 0,
|
|
42
|
+
skipped_multiword: 0,
|
|
43
|
+
skipped_invalid_freq: 0,
|
|
44
|
+
skipped_duplicates: 0,
|
|
37
45
|
}
|
|
38
46
|
}
|
|
39
47
|
}
|
|
40
48
|
|
|
49
|
+
// Helper function to correct a single word
|
|
50
|
+
// Returns the corrected word or the original if no correction is appropriate
|
|
51
|
+
fn correct_word(
|
|
52
|
+
state: &CheckerState,
|
|
53
|
+
symspell: &SymSpell,
|
|
54
|
+
word: &str,
|
|
55
|
+
) -> String {
|
|
56
|
+
// Always check if word is protected
|
|
57
|
+
let normalized = SymSpell::normalize_word(word);
|
|
58
|
+
if state.guards.is_protected_normalized(word, &normalized) {
|
|
59
|
+
return word.to_string();
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
let suggestions = symspell.suggestions(word, 5);
|
|
63
|
+
|
|
64
|
+
// If exact match exists, return canonical form from dictionary
|
|
65
|
+
if !suggestions.is_empty() && suggestions[0].distance == 0 {
|
|
66
|
+
return suggestions[0].term.clone();
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
// Get original word's frequency (if it exists in dictionary)
|
|
70
|
+
let original_freq = symspell.get_frequency(word);
|
|
71
|
+
|
|
72
|
+
// Find best correction with frequency threshold
|
|
73
|
+
for suggestion in &suggestions {
|
|
74
|
+
if suggestion.distance <= state.edit_distance {
|
|
75
|
+
// Apply frequency threshold
|
|
76
|
+
let passes_threshold = match original_freq {
|
|
77
|
+
// Word not in dictionary: require suggestion frequency >= absolute threshold
|
|
78
|
+
None => suggestion.frequency as f64 >= state.frequency_threshold,
|
|
79
|
+
// Word in dictionary: require suggestion frequency >= threshold * original frequency
|
|
80
|
+
Some(orig_freq) => {
|
|
81
|
+
suggestion.frequency as f64 >= state.frequency_threshold * orig_freq as f64
|
|
82
|
+
}
|
|
83
|
+
};
|
|
84
|
+
|
|
85
|
+
if passes_threshold {
|
|
86
|
+
return suggestion.term.clone();
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
// No suggestions passed the threshold
|
|
92
|
+
word.to_string()
|
|
93
|
+
}
|
|
94
|
+
|
|
41
95
|
impl Checker {
|
|
42
96
|
fn new() -> Self {
|
|
43
97
|
Self {
|
|
@@ -54,56 +108,129 @@ impl Checker {
|
|
|
54
108
|
.map_err(|_| Error::new(ruby.exception_arg_error(), "dictionary_path is required"))?
|
|
55
109
|
)?;
|
|
56
110
|
|
|
57
|
-
|
|
58
|
-
|
|
111
|
+
// Optional: edit distance
|
|
112
|
+
let edit_dist: usize = config.get("edit_distance")
|
|
113
|
+
.and_then(|v: Value| TryConvert::try_convert(v).ok())
|
|
114
|
+
.unwrap_or(1);
|
|
59
115
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
.unwrap_or(1);
|
|
116
|
+
if edit_dist > 2 {
|
|
117
|
+
return Err(Error::new(ruby.exception_arg_error(), "edit_distance must be 1 or 2"));
|
|
118
|
+
}
|
|
64
119
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
120
|
+
// Stream dictionary loading: read line-by-line and add directly to SymSpell
|
|
121
|
+
// This avoids buffering the entire file and intermediate Vec allocation
|
|
122
|
+
let file = std::fs::File::open(&dictionary_path)
|
|
123
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to open dictionary file: {}", e)))?;
|
|
124
|
+
|
|
125
|
+
let reader = std::io::BufReader::new(file);
|
|
126
|
+
let mut symspell = SymSpell::new(edit_dist);
|
|
127
|
+
let mut dictionary_size = 0;
|
|
128
|
+
let mut skipped_malformed = 0;
|
|
129
|
+
let mut skipped_multiword = 0;
|
|
130
|
+
let mut skipped_invalid_freq = 0;
|
|
131
|
+
let mut skipped_duplicates = 0;
|
|
132
|
+
|
|
133
|
+
use std::io::BufRead;
|
|
134
|
+
for line in reader.lines() {
|
|
135
|
+
let line = line.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to read line: {}", e)))?;
|
|
136
|
+
|
|
137
|
+
// Try tab-separated first (allows multi-word terms), then space-separated (SymSpell format)
|
|
138
|
+
let parts: Vec<&str> = if line.contains('\t') {
|
|
139
|
+
line.split('\t').collect()
|
|
140
|
+
} else {
|
|
141
|
+
line.split_whitespace().collect()
|
|
142
|
+
};
|
|
143
|
+
|
|
144
|
+
// Validate we have exactly 2 columns (term and frequency)
|
|
145
|
+
if parts.len() != 2 {
|
|
146
|
+
skipped_malformed += 1;
|
|
147
|
+
continue;
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
let term = parts[0].trim();
|
|
151
|
+
let freq_str = parts[1].trim();
|
|
68
152
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
if let Ok(freq) = parts[1].parse::<u64>() {
|
|
74
|
-
words.push((parts[0].to_string(), freq));
|
|
153
|
+
// Skip empty terms or frequencies
|
|
154
|
+
if term.is_empty() || freq_str.is_empty() {
|
|
155
|
+
skipped_malformed += 1;
|
|
156
|
+
continue;
|
|
75
157
|
}
|
|
76
|
-
}
|
|
77
|
-
}
|
|
78
158
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
159
|
+
// Check for multi-word terms (SymSpell algorithm doesn't support phrases)
|
|
160
|
+
if term.contains(char::is_whitespace) {
|
|
161
|
+
skipped_multiword += 1;
|
|
162
|
+
continue;
|
|
163
|
+
}
|
|
82
164
|
|
|
83
|
-
|
|
165
|
+
// Parse frequency
|
|
166
|
+
match freq_str.parse::<u64>() {
|
|
167
|
+
Ok(freq) => {
|
|
168
|
+
let normalized = SymSpell::normalize_word(term);
|
|
169
|
+
let was_new = symspell.add_word(&normalized, term, freq);
|
|
170
|
+
if was_new {
|
|
171
|
+
dictionary_size += 1;
|
|
172
|
+
} else {
|
|
173
|
+
skipped_duplicates += 1;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
Err(_) => {
|
|
177
|
+
skipped_invalid_freq += 1;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
}
|
|
84
181
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if let
|
|
182
|
+
let mut guards = Guards::new();
|
|
183
|
+
|
|
184
|
+
// Load optional protected terms file
|
|
185
|
+
if let Some(protected_path) = config.get("protected_path") {
|
|
186
|
+
let path: String = TryConvert::try_convert(protected_path)?;
|
|
187
|
+
let content = std::fs::read_to_string(&path)
|
|
188
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(),
|
|
189
|
+
format!("Failed to read protected terms file '{}': {}", path, e)))?;
|
|
89
190
|
guards.load_protected(&content);
|
|
90
191
|
}
|
|
91
|
-
}
|
|
92
192
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
193
|
+
// Load optional protected patterns
|
|
194
|
+
if let Some(patterns_value) = config.get("protected_patterns") {
|
|
195
|
+
let patterns: RArray = TryConvert::try_convert(patterns_value)?;
|
|
196
|
+
for pattern_value in patterns.into_iter() {
|
|
197
|
+
let pattern_hash: RHash = TryConvert::try_convert(pattern_value)?;
|
|
198
|
+
|
|
199
|
+
let source: String = TryConvert::try_convert(
|
|
200
|
+
pattern_hash.fetch::<_, Value>("source")
|
|
201
|
+
.map_err(|_| Error::new(ruby.exception_arg_error(), "pattern hash missing 'source' key"))?
|
|
202
|
+
)?;
|
|
203
|
+
|
|
204
|
+
let case_insensitive: bool = pattern_hash.get("case_insensitive")
|
|
205
|
+
.and_then(|v: Value| TryConvert::try_convert(v).ok())
|
|
206
|
+
.unwrap_or(false);
|
|
207
|
+
|
|
208
|
+
let multiline: bool = pattern_hash.get("multiline")
|
|
209
|
+
.and_then(|v: Value| TryConvert::try_convert(v).ok())
|
|
210
|
+
.unwrap_or(false);
|
|
211
|
+
|
|
212
|
+
let extended: bool = pattern_hash.get("extended")
|
|
213
|
+
.and_then(|v: Value| TryConvert::try_convert(v).ok())
|
|
214
|
+
.unwrap_or(false);
|
|
215
|
+
|
|
216
|
+
guards.add_pattern_with_flags(&source, case_insensitive, multiline, extended)
|
|
217
|
+
.map_err(|e| Error::new(ruby.exception_arg_error(), e))?;
|
|
218
|
+
}
|
|
100
219
|
}
|
|
101
|
-
}
|
|
102
220
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
221
|
+
// Optional frequency threshold
|
|
222
|
+
let frequency_threshold: f64 = config.get("frequency_threshold")
|
|
223
|
+
.and_then(|v: Value| TryConvert::try_convert(v).ok())
|
|
224
|
+
.unwrap_or(10.0);
|
|
225
|
+
|
|
226
|
+
// Validate frequency threshold
|
|
227
|
+
if !frequency_threshold.is_finite() {
|
|
228
|
+
return Err(Error::new(ruby.exception_arg_error(), "frequency_threshold must be finite (not NaN or Infinity)"));
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
if frequency_threshold < 0.0 {
|
|
232
|
+
return Err(Error::new(ruby.exception_arg_error(), format!("frequency_threshold must be non-negative, got: {}", frequency_threshold)));
|
|
233
|
+
}
|
|
107
234
|
|
|
108
235
|
let loaded_at = SystemTime::now()
|
|
109
236
|
.duration_since(UNIX_EPOCH)
|
|
@@ -118,11 +245,15 @@ impl Checker {
|
|
|
118
245
|
state.loaded_at = loaded_at;
|
|
119
246
|
state.dictionary_size = dictionary_size;
|
|
120
247
|
state.edit_distance = edit_dist;
|
|
248
|
+
state.skipped_malformed = skipped_malformed;
|
|
249
|
+
state.skipped_multiword = skipped_multiword;
|
|
250
|
+
state.skipped_invalid_freq = skipped_invalid_freq;
|
|
251
|
+
state.skipped_duplicates = skipped_duplicates;
|
|
121
252
|
|
|
122
253
|
Ok(())
|
|
123
254
|
}
|
|
124
255
|
|
|
125
|
-
fn
|
|
256
|
+
fn suggestions(&self, word: String, max: Option<usize>) -> Result<RArray, Error> {
|
|
126
257
|
let ruby = Ruby::get().unwrap();
|
|
127
258
|
let max_suggestions = max.unwrap_or(5);
|
|
128
259
|
let state = self.state.read().unwrap();
|
|
@@ -132,7 +263,7 @@ impl Checker {
|
|
|
132
263
|
}
|
|
133
264
|
|
|
134
265
|
if let Some(ref symspell) = state.symspell {
|
|
135
|
-
let suggestions = symspell.
|
|
266
|
+
let suggestions = symspell.suggestions(&word, max_suggestions);
|
|
136
267
|
let result = RArray::new();
|
|
137
268
|
|
|
138
269
|
for suggestion in suggestions {
|
|
@@ -149,7 +280,7 @@ impl Checker {
|
|
|
149
280
|
}
|
|
150
281
|
}
|
|
151
282
|
|
|
152
|
-
fn
|
|
283
|
+
fn correct(&self, word: String) -> Result<bool, Error> {
|
|
153
284
|
let ruby = Ruby::get().unwrap();
|
|
154
285
|
let state = self.state.read().unwrap();
|
|
155
286
|
|
|
@@ -157,49 +288,51 @@ impl Checker {
|
|
|
157
288
|
return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded. Call load! first"));
|
|
158
289
|
}
|
|
159
290
|
|
|
160
|
-
// Check if word is protected
|
|
161
|
-
if use_guard.unwrap_or(false) {
|
|
162
|
-
let normalized = SymSpell::normalize_word(&word);
|
|
163
|
-
if state.guards.is_protected_normalized(&word, &normalized) {
|
|
164
|
-
return Ok(word);
|
|
165
|
-
}
|
|
166
|
-
}
|
|
167
|
-
|
|
168
291
|
if let Some(ref symspell) = state.symspell {
|
|
169
|
-
|
|
292
|
+
Ok(symspell.contains(&word))
|
|
293
|
+
} else {
|
|
294
|
+
Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"))
|
|
295
|
+
}
|
|
296
|
+
}
|
|
170
297
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
}
|
|
298
|
+
fn correct_if_unknown(&self, word: String) -> Result<String, Error> {
|
|
299
|
+
let ruby = Ruby::get().unwrap();
|
|
300
|
+
let state = self.state.read().unwrap();
|
|
175
301
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
// Check frequency threshold - correction should be significantly more common
|
|
180
|
-
// Since we don't have the original word's frequency, we'll just take any ED=1 match
|
|
181
|
-
// In a full implementation, we'd check if suggestion.frequency >= threshold * original_freq
|
|
182
|
-
return Ok(suggestion.term.clone());
|
|
183
|
-
}
|
|
184
|
-
}
|
|
302
|
+
if !state.loaded {
|
|
303
|
+
return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded. Call load! first"));
|
|
304
|
+
}
|
|
185
305
|
|
|
186
|
-
|
|
306
|
+
if let Some(ref symspell) = state.symspell {
|
|
307
|
+
Ok(correct_word(&state, symspell, &word))
|
|
187
308
|
} else {
|
|
188
309
|
Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"))
|
|
189
310
|
}
|
|
190
311
|
}
|
|
191
312
|
|
|
192
|
-
fn correct_tokens(&self, tokens: RArray
|
|
193
|
-
|
|
194
|
-
|
|
313
|
+
fn correct_tokens(&self, tokens: RArray) -> Result<RArray, Error> {
|
|
314
|
+
// Optimize batch correction by acquiring lock once for all tokens
|
|
315
|
+
// instead of calling correct_if_unknown per token (which re-locks each time)
|
|
316
|
+
let ruby = Ruby::get().unwrap();
|
|
317
|
+
let state = self.state.read().unwrap();
|
|
195
318
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
let corrected = self.correct_if_unknown(word, Some(guard))?;
|
|
199
|
-
result.push(corrected)?;
|
|
319
|
+
if !state.loaded {
|
|
320
|
+
return Err(Error::new(ruby.exception_runtime_error(), "Dictionary not loaded. Call load! first"));
|
|
200
321
|
}
|
|
201
322
|
|
|
202
|
-
|
|
323
|
+
let result = RArray::new();
|
|
324
|
+
|
|
325
|
+
if let Some(ref symspell) = state.symspell {
|
|
326
|
+
for token in tokens.into_iter() {
|
|
327
|
+
let word: String = TryConvert::try_convert(token)?;
|
|
328
|
+
let corrected = correct_word(&state, symspell, &word);
|
|
329
|
+
result.push(corrected)?;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
Ok(result)
|
|
333
|
+
} else {
|
|
334
|
+
Err(Error::new(ruby.exception_runtime_error(), "SymSpell not initialized"))
|
|
335
|
+
}
|
|
203
336
|
}
|
|
204
337
|
|
|
205
338
|
fn stats(&self) -> Result<RHash, Error> {
|
|
@@ -214,6 +347,10 @@ impl Checker {
|
|
|
214
347
|
stats.aset("loaded", true)?;
|
|
215
348
|
stats.aset("dictionary_size", state.dictionary_size)?;
|
|
216
349
|
stats.aset("edit_distance", state.edit_distance)?;
|
|
350
|
+
stats.aset("skipped_malformed", state.skipped_malformed)?;
|
|
351
|
+
stats.aset("skipped_multiword", state.skipped_multiword)?;
|
|
352
|
+
stats.aset("skipped_invalid_freq", state.skipped_invalid_freq)?;
|
|
353
|
+
stats.aset("skipped_duplicates", state.skipped_duplicates)?;
|
|
217
354
|
|
|
218
355
|
if let Some(loaded_at) = state.loaded_at {
|
|
219
356
|
stats.aset("loaded_at", loaded_at)?;
|
|
@@ -245,9 +382,10 @@ fn init(_ruby: &Ruby) -> Result<(), Error> {
|
|
|
245
382
|
|
|
246
383
|
checker_class.define_singleton_method("new", function!(Checker::new, 0))?;
|
|
247
384
|
checker_class.define_method("load!", method!(Checker::load_full, 1))?;
|
|
248
|
-
checker_class.define_method("
|
|
249
|
-
checker_class.define_method("
|
|
250
|
-
checker_class.define_method("
|
|
385
|
+
checker_class.define_method("suggestions", method!(Checker::suggestions, 2))?;
|
|
386
|
+
checker_class.define_method("correct?", method!(Checker::correct, 1))?;
|
|
387
|
+
checker_class.define_method("correct", method!(Checker::correct_if_unknown, 1))?;
|
|
388
|
+
checker_class.define_method("correct_tokens", method!(Checker::correct_tokens, 1))?;
|
|
251
389
|
checker_class.define_method("stats", method!(Checker::stats, 0))?;
|
|
252
390
|
checker_class.define_method("healthcheck", method!(Checker::healthcheck, 0))?;
|
|
253
391
|
|
|
@@ -2,6 +2,12 @@ use hashbrown::{HashMap, HashSet};
|
|
|
2
2
|
use std::cmp::Ordering;
|
|
3
3
|
use unicode_normalization::UnicodeNormalization;
|
|
4
4
|
|
|
5
|
+
#[derive(Debug, Clone)]
|
|
6
|
+
pub struct WordEntry {
|
|
7
|
+
pub canonical: String,
|
|
8
|
+
pub frequency: u64,
|
|
9
|
+
}
|
|
10
|
+
|
|
5
11
|
#[derive(Debug, Clone)]
|
|
6
12
|
pub struct Suggestion {
|
|
7
13
|
pub term: String,
|
|
@@ -44,7 +50,7 @@ impl Eq for Suggestion {}
|
|
|
44
50
|
|
|
45
51
|
pub struct SymSpell {
|
|
46
52
|
deletes: HashMap<String, HashSet<String>>,
|
|
47
|
-
words: HashMap<String,
|
|
53
|
+
words: HashMap<String, WordEntry>,
|
|
48
54
|
max_edit_distance: usize,
|
|
49
55
|
}
|
|
50
56
|
|
|
@@ -64,23 +70,44 @@ impl SymSpell {
|
|
|
64
70
|
.to_lowercase()
|
|
65
71
|
}
|
|
66
72
|
|
|
67
|
-
pub fn
|
|
68
|
-
|
|
69
|
-
let normalized = Self::normalize_word(&word);
|
|
70
|
-
self.add_word(&normalized, freq);
|
|
71
|
-
}
|
|
72
|
-
}
|
|
73
|
+
pub fn add_word(&mut self, normalized: &str, canonical: &str, frequency: u64) -> bool {
|
|
74
|
+
let normalized_key = normalized.to_string();
|
|
73
75
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
+
let was_new = if let Some(existing) = self.words.get_mut(&normalized_key) {
|
|
77
|
+
// Duplicate: sum frequencies and keep highest-frequency canonical form
|
|
78
|
+
let new_total_freq = existing.frequency + frequency;
|
|
79
|
+
|
|
80
|
+
// Keep the canonical form from the higher-frequency variant
|
|
81
|
+
if frequency > existing.frequency {
|
|
82
|
+
existing.canonical = canonical.to_string();
|
|
83
|
+
}
|
|
76
84
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
.
|
|
85
|
+
existing.frequency = new_total_freq;
|
|
86
|
+
false
|
|
87
|
+
} else {
|
|
88
|
+
// New entry
|
|
89
|
+
self.words.insert(
|
|
90
|
+
normalized_key.clone(),
|
|
91
|
+
WordEntry {
|
|
92
|
+
canonical: canonical.to_string(),
|
|
93
|
+
frequency,
|
|
94
|
+
},
|
|
95
|
+
);
|
|
96
|
+
true
|
|
97
|
+
};
|
|
98
|
+
|
|
99
|
+
// Only generate deletes for new entries (avoid redundant work)
|
|
100
|
+
if was_new {
|
|
101
|
+
let deletes = self.get_deletes(normalized, self.max_edit_distance);
|
|
102
|
+
for delete in deletes {
|
|
103
|
+
self.deletes
|
|
104
|
+
.entry(delete)
|
|
105
|
+
.or_insert_with(HashSet::new)
|
|
106
|
+
.insert(normalized_key.clone());
|
|
107
|
+
}
|
|
83
108
|
}
|
|
109
|
+
|
|
110
|
+
was_new
|
|
84
111
|
}
|
|
85
112
|
|
|
86
113
|
fn get_deletes(&self, word: &str, edit_distance: usize) -> HashSet<String> {
|
|
@@ -101,8 +128,10 @@ impl SymSpell {
|
|
|
101
128
|
processed.insert(item.clone());
|
|
102
129
|
|
|
103
130
|
for delete in self.generate_deletes(&item) {
|
|
104
|
-
|
|
105
|
-
|
|
131
|
+
deletes.insert(delete.clone());
|
|
132
|
+
|
|
133
|
+
// Only continue processing non-empty strings to avoid infinite loops
|
|
134
|
+
if !delete.is_empty() {
|
|
106
135
|
temp_queue.push(delete);
|
|
107
136
|
}
|
|
108
137
|
}
|
|
@@ -130,13 +159,23 @@ impl SymSpell {
|
|
|
130
159
|
deletes
|
|
131
160
|
}
|
|
132
161
|
|
|
133
|
-
pub fn
|
|
162
|
+
pub fn contains(&self, word: &str) -> bool {
|
|
163
|
+
let normalized = Self::normalize_word(word);
|
|
164
|
+
self.words.contains_key(&normalized)
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
pub fn get_frequency(&self, word: &str) -> Option<u64> {
|
|
168
|
+
let normalized = Self::normalize_word(word);
|
|
169
|
+
self.words.get(&normalized).map(|entry| entry.frequency)
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
pub fn suggestions(&self, word: &str, max_suggestions: usize) -> Vec<Suggestion> {
|
|
134
173
|
let normalized = Self::normalize_word(word);
|
|
135
174
|
let mut suggestions = Vec::new();
|
|
136
175
|
let mut seen = HashSet::new();
|
|
137
176
|
|
|
138
|
-
if let Some(
|
|
139
|
-
suggestions.push(Suggestion::new(
|
|
177
|
+
if let Some(entry) = self.words.get(&normalized) {
|
|
178
|
+
suggestions.push(Suggestion::new(entry.canonical.clone(), 0, entry.frequency));
|
|
140
179
|
seen.insert(normalized.clone());
|
|
141
180
|
}
|
|
142
181
|
|
|
@@ -145,10 +184,10 @@ impl SymSpell {
|
|
|
145
184
|
for delete in &input_deletes {
|
|
146
185
|
// Check if this delete is itself a dictionary word (important for finding words shorter than input)
|
|
147
186
|
if !seen.contains(delete) {
|
|
148
|
-
if let Some(
|
|
187
|
+
if let Some(entry) = self.words.get(delete) {
|
|
149
188
|
let distance = self.edit_distance(&normalized, delete);
|
|
150
189
|
if distance <= self.max_edit_distance {
|
|
151
|
-
suggestions.push(Suggestion::new(
|
|
190
|
+
suggestions.push(Suggestion::new(entry.canonical.clone(), distance, entry.frequency));
|
|
152
191
|
seen.insert(delete.clone());
|
|
153
192
|
}
|
|
154
193
|
}
|
|
@@ -163,8 +202,8 @@ impl SymSpell {
|
|
|
163
202
|
|
|
164
203
|
let distance = self.edit_distance(&normalized, candidate);
|
|
165
204
|
if distance <= self.max_edit_distance {
|
|
166
|
-
if let Some(
|
|
167
|
-
suggestions.push(Suggestion::new(
|
|
205
|
+
if let Some(entry) = self.words.get(candidate) {
|
|
206
|
+
suggestions.push(Suggestion::new(entry.canonical.clone(), distance, entry.frequency));
|
|
168
207
|
seen.insert(candidate.clone());
|
|
169
208
|
}
|
|
170
209
|
}
|
|
@@ -180,8 +219,8 @@ impl SymSpell {
|
|
|
180
219
|
|
|
181
220
|
let distance = self.edit_distance(&normalized, candidate);
|
|
182
221
|
if distance <= self.max_edit_distance {
|
|
183
|
-
if let Some(
|
|
184
|
-
suggestions.push(Suggestion::new(
|
|
222
|
+
if let Some(entry) = self.words.get(candidate) {
|
|
223
|
+
suggestions.push(Suggestion::new(entry.canonical.clone(), distance, entry.frequency));
|
|
185
224
|
seen.insert(candidate.clone());
|
|
186
225
|
}
|
|
187
226
|
}
|
|
@@ -252,13 +291,59 @@ mod tests {
|
|
|
252
291
|
#[test]
|
|
253
292
|
fn test_suggestions() {
|
|
254
293
|
let mut symspell = SymSpell::new(2);
|
|
255
|
-
symspell.add_word("hello", 1000);
|
|
256
|
-
symspell.add_word("hell", 500);
|
|
257
|
-
symspell.add_word("help", 750);
|
|
294
|
+
symspell.add_word("hello", "hello", 1000);
|
|
295
|
+
symspell.add_word("hell", "hell", 500);
|
|
296
|
+
symspell.add_word("help", "help", 750);
|
|
258
297
|
|
|
259
|
-
let suggestions = symspell.
|
|
298
|
+
let suggestions = symspell.suggestions("helo", 3);
|
|
260
299
|
assert!(!suggestions.is_empty());
|
|
261
300
|
assert_eq!(suggestions[0].term, "hello");
|
|
262
301
|
assert_eq!(suggestions[0].distance, 1);
|
|
263
302
|
}
|
|
303
|
+
|
|
304
|
+
#[test]
|
|
305
|
+
fn test_single_character_corrections() {
|
|
306
|
+
let mut symspell = SymSpell::new(1);
|
|
307
|
+
symspell.add_word("a", "a", 10000);
|
|
308
|
+
symspell.add_word("i", "I", 8000);
|
|
309
|
+
symspell.add_word("o", "o", 6000);
|
|
310
|
+
|
|
311
|
+
let suggestions = symspell.suggestions("x", 5);
|
|
312
|
+
assert!(!suggestions.is_empty(), "Single-character corrections should work");
|
|
313
|
+
assert!(suggestions.iter().any(|s| s.term == "a"), "Should suggest 'a' for 'x'");
|
|
314
|
+
|
|
315
|
+
let suggestions_for_j = symspell.suggestions("j", 5);
|
|
316
|
+
assert!(!suggestions_for_j.is_empty(), "Should find suggestions for 'j'");
|
|
317
|
+
assert!(suggestions_for_j.iter().any(|s| s.term == "I"), "Should suggest canonical 'I' (not 'i')");
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
#[test]
|
|
321
|
+
fn test_duplicate_entries_keep_highest_frequency_canonical() {
|
|
322
|
+
let mut symspell = SymSpell::new(1);
|
|
323
|
+
|
|
324
|
+
// Add high-frequency lowercase variant
|
|
325
|
+
symspell.add_word("hello", "hello", 10000);
|
|
326
|
+
|
|
327
|
+
// Add low-frequency uppercase variant (should not replace canonical)
|
|
328
|
+
symspell.add_word("hello", "HELLO", 100);
|
|
329
|
+
|
|
330
|
+
let suggestions = symspell.suggestions("hello", 1);
|
|
331
|
+
assert_eq!(suggestions.len(), 1);
|
|
332
|
+
assert_eq!(suggestions[0].term, "hello", "Should keep high-frequency 'hello' as canonical, not 'HELLO'");
|
|
333
|
+
assert_eq!(suggestions[0].frequency, 10100, "Should sum frequencies: 10000 + 100 = 10100");
|
|
334
|
+
|
|
335
|
+
// Verify reverse order also works
|
|
336
|
+
let mut symspell2 = SymSpell::new(1);
|
|
337
|
+
|
|
338
|
+
// Add low-frequency first
|
|
339
|
+
symspell2.add_word("world", "WORLD", 100);
|
|
340
|
+
|
|
341
|
+
// Add high-frequency second (should replace canonical)
|
|
342
|
+
symspell2.add_word("world", "world", 10000);
|
|
343
|
+
|
|
344
|
+
let suggestions2 = symspell2.suggestions("world", 1);
|
|
345
|
+
assert_eq!(suggestions2.len(), 1);
|
|
346
|
+
assert_eq!(suggestions2[0].term, "world", "Should update to high-frequency 'world' canonical");
|
|
347
|
+
assert_eq!(suggestions2[0].frequency, 10100, "Should sum frequencies");
|
|
348
|
+
}
|
|
264
349
|
}
|