ocr-stringdist 0.0.5__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/Cargo.lock +53 -1
  2. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/Cargo.toml +2 -1
  3. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/Justfile +6 -0
  4. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/PKG-INFO +2 -3
  5. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/README.md +1 -2
  6. ocr_stringdist-0.0.7/docs/source/api/index.rst +23 -0
  7. ocr_stringdist-0.0.7/examples/batch_processing.py +58 -0
  8. ocr_stringdist-0.0.5/example.py → ocr_stringdist-0.0.7/examples/weighted_levenshtein.py +1 -0
  9. ocr_stringdist-0.0.7/python/ocr_stringdist/__init__.py +10 -0
  10. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/python/ocr_stringdist/default_ocr_distances.py +4 -0
  11. ocr_stringdist-0.0.7/python/ocr_stringdist/levenshtein.py +71 -0
  12. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/src/lib.rs +2 -0
  13. ocr_stringdist-0.0.7/src/longest_tokens.rs +48 -0
  14. ocr_stringdist-0.0.7/src/rust_stringdist.rs +56 -0
  15. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/src/weighted_levenshtein.rs +46 -16
  16. ocr_stringdist-0.0.7/tests/test_batch_functions.py +131 -0
  17. ocr_stringdist-0.0.5/docs/source/api/index.rst +0 -18
  18. ocr_stringdist-0.0.5/python/ocr_stringdist/__init__.py +0 -42
  19. ocr_stringdist-0.0.5/src/rust_stringdist.rs +0 -39
  20. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/.github/workflows/CI.yml +0 -0
  21. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/.github/workflows/docs.yml +0 -0
  22. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/.gitignore +0 -0
  23. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/LICENSE +0 -0
  24. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/docs/Makefile +0 -0
  25. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/docs/make.bat +0 -0
  26. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/docs/source/conf.py +0 -0
  27. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/docs/source/index.rst +0 -0
  28. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/mypy.ini +0 -0
  29. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/pyproject.toml +0 -0
  30. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/python/ocr_stringdist/matching.py +0 -0
  31. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/python/ocr_stringdist/py.typed +0 -0
  32. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/ruff.toml +0 -0
  33. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/tests/test_matching.py +0 -0
  34. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/tests/test_ocr_stringdist.py +0 -0
  35. {ocr_stringdist-0.0.5 → ocr_stringdist-0.0.7}/uv.lock +0 -0
@@ -14,6 +14,37 @@ version = "1.0.0"
14
14
  source = "registry+https://github.com/rust-lang/crates.io-index"
15
15
  checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
16
16
 
17
+ [[package]]
18
+ name = "crossbeam-deque"
19
+ version = "0.8.6"
20
+ source = "registry+https://github.com/rust-lang/crates.io-index"
21
+ checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
22
+ dependencies = [
23
+ "crossbeam-epoch",
24
+ "crossbeam-utils",
25
+ ]
26
+
27
+ [[package]]
28
+ name = "crossbeam-epoch"
29
+ version = "0.9.18"
30
+ source = "registry+https://github.com/rust-lang/crates.io-index"
31
+ checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
32
+ dependencies = [
33
+ "crossbeam-utils",
34
+ ]
35
+
36
+ [[package]]
37
+ name = "crossbeam-utils"
38
+ version = "0.8.21"
39
+ source = "registry+https://github.com/rust-lang/crates.io-index"
40
+ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
41
+
42
+ [[package]]
43
+ name = "either"
44
+ version = "1.15.0"
45
+ source = "registry+https://github.com/rust-lang/crates.io-index"
46
+ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
47
+
17
48
  [[package]]
18
49
  name = "heck"
19
50
  version = "0.5.0"
@@ -43,9 +74,10 @@ dependencies = [
43
74
 
44
75
  [[package]]
45
76
  name = "ocr_stringdist"
46
- version = "0.0.5"
77
+ version = "0.0.7"
47
78
  dependencies = [
48
79
  "pyo3",
80
+ "rayon",
49
81
  ]
50
82
 
51
83
  [[package]]
@@ -141,6 +173,26 @@ dependencies = [
141
173
  "proc-macro2",
142
174
  ]
143
175
 
176
+ [[package]]
177
+ name = "rayon"
178
+ version = "1.10.0"
179
+ source = "registry+https://github.com/rust-lang/crates.io-index"
180
+ checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
181
+ dependencies = [
182
+ "either",
183
+ "rayon-core",
184
+ ]
185
+
186
+ [[package]]
187
+ name = "rayon-core"
188
+ version = "1.12.1"
189
+ source = "registry+https://github.com/rust-lang/crates.io-index"
190
+ checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
191
+ dependencies = [
192
+ "crossbeam-deque",
193
+ "crossbeam-utils",
194
+ ]
195
+
144
196
  [[package]]
145
197
  name = "syn"
146
198
  version = "2.0.100"
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "ocr_stringdist"
3
- version = "0.0.5"
3
+ version = "0.0.7"
4
4
  edition = "2021"
5
5
  description = "String distances considering OCR errors."
6
6
  authors = ["Niklas von Moers <niklasvmoers@protonmail.com>"]
@@ -15,6 +15,7 @@ crate-type = ["cdylib"]
15
15
 
16
16
  [dependencies]
17
17
  pyo3 = { version = "0.24.0", features = [] }
18
+ rayon = "1.10.0"
18
19
 
19
20
  [features]
20
21
  python = []
@@ -12,3 +12,9 @@ test:
12
12
 
13
13
  mypy:
14
14
  uv run mypy .
15
+
16
+ lint:
17
+ uv run ruff check . --fix
18
+
19
+ doc:
20
+ uv run make -C docs html
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocr_stringdist
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python
6
6
  Classifier: Operating System :: OS Independent
@@ -36,7 +36,7 @@ pip install ocr-stringdist
36
36
 
37
37
  ## Features
38
38
 
39
- - **Weighted Levenshtein Distance**: An adaptation of the classic Levenshtein algorithm with custom substitution costs for character pairs that are commonly confused in OCR models.
39
+ - **Weighted Levenshtein Distance**: An adaptation of the classic Levenshtein algorithm with custom substitution costs for character pairs that are commonly confused in OCR models, including efficient batch processing.
40
40
  - **Unicode Support**: Arbitrary unicode strings can be compared.
41
41
  - **Substitution of Multiple Characters**: Not just character pairs, but string pairs may be substituted, for example the Korean syllable "이" for the two letters "OI".
42
42
  - **Pre-defined OCR Distance Map**: A built-in distance map for common OCR confusions (e.g., "0" vs "O", "1" vs "l", "5" vs "S").
@@ -60,7 +60,6 @@ distance = osd.weighted_levenshtein_distance(
60
60
  "hi", "Ini",
61
61
  cost_map=custom_map,
62
62
  symmetric=True,
63
- default_cost=1.0,
64
63
  )
65
64
  print(f"Distance with custom map: {distance}")
66
65
  ```
@@ -21,7 +21,7 @@ pip install ocr-stringdist
21
21
 
22
22
  ## Features
23
23
 
24
- - **Weighted Levenshtein Distance**: An adaptation of the classic Levenshtein algorithm with custom substitution costs for character pairs that are commonly confused in OCR models.
24
+ - **Weighted Levenshtein Distance**: An adaptation of the classic Levenshtein algorithm with custom substitution costs for character pairs that are commonly confused in OCR models, including efficient batch processing.
25
25
  - **Unicode Support**: Arbitrary unicode strings can be compared.
26
26
  - **Substitution of Multiple Characters**: Not just character pairs, but string pairs may be substituted, for example the Korean syllable "이" for the two letters "OI".
27
27
  - **Pre-defined OCR Distance Map**: A built-in distance map for common OCR confusions (e.g., "0" vs "O", "1" vs "l", "5" vs "S").
@@ -45,7 +45,6 @@ distance = osd.weighted_levenshtein_distance(
45
45
  "hi", "Ini",
46
46
  cost_map=custom_map,
47
47
  symmetric=True,
48
- default_cost=1.0,
49
48
  )
50
49
  print(f"Distance with custom map: {distance}")
51
50
  ```
@@ -0,0 +1,23 @@
1
+ .. _api_reference:
2
+
3
+ API Reference
4
+ =============
5
+
6
+ This page contains the auto-generated API reference documentation.
7
+
8
+ .. automodule:: ocr_stringdist.levenshtein
9
+ :members:
10
+ :undoc-members:
11
+ :show-inheritance:
12
+
13
+ .. automodule:: ocr_stringdist.matching
14
+ :members:
15
+ :undoc-members:
16
+ :show-inheritance:
17
+
18
+ .. autodata:: ocr_stringdist.default_ocr_distances.ocr_distance_map
19
+ :annotation:
20
+ .. literalinclude:: ../../../python/ocr_stringdist/default_ocr_distances.py
21
+ :language: python
22
+ :start-after: OCR_DISTANCE_MAP_START
23
+ :end-before: OCR_DISTANCE_MAP_END
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example demonstrating the usage of the batch processing functions from ocr_stringdist.
4
+ """
5
+
6
+ import time
7
+ from typing import Any, Callable
8
+
9
+ import ocr_stringdist as osd
10
+
11
+
12
+ def benchmark(func: Callable, *args: Any, **kwargs: Any) -> tuple[Any, float]: # type: ignore
13
+ """Run a function and return the execution time in seconds."""
14
+ start = time.time()
15
+ result = func(*args, **kwargs)
16
+ end = time.time()
17
+ return result, end - start
18
+
19
+
20
+ def compare_methods() -> None:
21
+ """
22
+ Compare the performance of different methods for calculating Levenshtein distances.
23
+ """
24
+ # Example data
25
+ source = "recognition"
26
+ candidates = ["recognition", "recogmtion", "recognltlon", "recogrtition", "recognitton"] * 1000
27
+
28
+ print("\nSingle string against multiple candidates:")
29
+ print("-" * 50)
30
+
31
+ # Standard loop approach
32
+ _, time_loop = benchmark(
33
+ lambda: [osd.weighted_levenshtein_distance(source, cand) for cand in candidates]
34
+ )
35
+ print(
36
+ f"Loop of single calls: {time_loop:.6f} seconds "
37
+ f"({1000 * time_loop / len(candidates):.6f}ms each)"
38
+ )
39
+
40
+ # Batch approach
41
+ _, time_batch = benchmark(osd.batch_weighted_levenshtein_distance, source, candidates)
42
+ print(
43
+ f"Batch function: {time_batch:.6f} seconds "
44
+ f"({1000 * time_batch / len(candidates):.6f}ms each)"
45
+ )
46
+ print(f"Speedup: {time_loop / time_batch:.2f}x")
47
+
48
+
49
+ def main() -> None:
50
+ """Main function."""
51
+ print("Demonstrating batch processing functions from ocr_stringdist\n")
52
+
53
+ # Run the benchmarks
54
+ compare_methods()
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
@@ -1,3 +1,4 @@
1
+ #!/usr/bin/env python3
1
2
  from icecream import ic
2
3
  from ocr_stringdist import find_best_candidate, weighted_levenshtein_distance
3
4
 
@@ -0,0 +1,10 @@
1
+ from .default_ocr_distances import ocr_distance_map
2
+ from .levenshtein import batch_weighted_levenshtein_distance, weighted_levenshtein_distance
3
+ from .matching import find_best_candidate
4
+
5
+ __all__ = [
6
+ "ocr_distance_map",
7
+ "weighted_levenshtein_distance",
8
+ "batch_weighted_levenshtein_distance",
9
+ "find_best_candidate",
10
+ ]
@@ -1,3 +1,5 @@
1
+ # Start marker for literalinclude, see docs/source/api/index.rst.
2
+ # OCR_DISTANCE_MAP_START
1
3
  ocr_distance_map: dict[tuple[str, str], float] = {
2
4
  ("O", "0"): 0.1,
3
5
  ("l", "1"): 0.1,
@@ -31,6 +33,8 @@ ocr_distance_map: dict[tuple[str, str], float] = {
31
33
  ("é", "á"): 0.7,
32
34
  ("E", "F"): 0.8,
33
35
  }
36
+ # OCR_DISTANCE_MAP_END
37
+ # End marker for literalinclude
34
38
  """
35
39
  Pre-defined distance map between characters, considering common OCR errors.
36
40
  The distances are between 0 and 1.
@@ -0,0 +1,71 @@
1
+ from typing import Optional
2
+
3
+ from ._rust_stringdist import * # noqa: F403
4
+ from .default_ocr_distances import ocr_distance_map
5
+
6
+
7
+ def weighted_levenshtein_distance(
8
+ s1: str,
9
+ s2: str,
10
+ /,
11
+ cost_map: Optional[dict[tuple[str, str], float]] = None,
12
+ *,
13
+ symmetric: bool = True,
14
+ default_cost: float = 1.0,
15
+ ) -> float:
16
+ """
17
+ Levenshtein distance with custom substitution costs.
18
+ Insertion/deletion costs are 1.
19
+
20
+ The default `cost_map` considers common OCR errors, see
21
+ :py:data:`ocr_stringdist.default_ocr_distances.ocr_distance_map`.
22
+
23
+ :param s1: First string
24
+ :param s2: Second string
25
+ :param cost_map: Dictionary mapping tuples of strings ("substitution tokens") to their
26
+ substitution costs.
27
+ Only one direction needs to be configured unless `symmetric` is False.
28
+ Note that the runtime scales in the length of the longest substitution token.
29
+ Defaults to `ocr_stringdist.ocr_distance_map`.
30
+ :param symmetric: Should the keys of `cost_map` be considered to be symmetric? Defaults to True.
31
+ :param default_cost: The default substitution cost for character pairs not found in `cost_map`.
32
+ """
33
+ if cost_map is None:
34
+ cost_map = ocr_distance_map
35
+ # _weighted_levenshtein_distance is written in Rust, see src/rust_stringdist.rs.
36
+ return _weighted_levenshtein_distance( # type: ignore # noqa: F405
37
+ s1, s2, cost_map=cost_map, symmetric=symmetric, default_cost=default_cost
38
+ )
39
+
40
+
41
+ def batch_weighted_levenshtein_distance(
42
+ s: str,
43
+ candidates: list[str],
44
+ /,
45
+ cost_map: Optional[dict[tuple[str, str], float]] = None,
46
+ *,
47
+ symmetric: bool = True,
48
+ default_cost: float = 1.0,
49
+ ) -> list[float]:
50
+ """
51
+ Calculate weighted Levenshtein distances between a string and multiple candidates.
52
+
53
+ This is more efficient than calling :func:`weighted_levenshtein_distance` multiple times.
54
+
55
+ :param s: The string to compare
56
+ :param candidates: List of candidate strings to compare against
57
+ :param cost_map: Dictionary mapping tuples of strings ("substitution tokens") to their
58
+ substitution costs.
59
+ Only one direction needs to be configured unless `symmetric` is False.
60
+ Note that the runtime scales in the length of the longest substitution token.
61
+ Defaults to `ocr_stringdist.ocr_distance_map`.
62
+ :param symmetric: Should the keys of `cost_map` be considered to be symmetric? Defaults to True.
63
+ :param default_cost: The default substitution cost for character pairs not found in `cost_map`.
64
+ :return: A list of distances corresponding to each candidate
65
+ """
66
+ if cost_map is None:
67
+ cost_map = ocr_distance_map
68
+ # _batch_weighted_levenshtein_distance is written in Rust, see src/rust_stringdist.rs.
69
+ return _batch_weighted_levenshtein_distance( # type: ignore # noqa: F405
70
+ s, candidates, cost_map=cost_map, symmetric=symmetric, default_cost=default_cost
71
+ )
@@ -1,5 +1,7 @@
1
+ mod longest_tokens;
1
2
  mod weighted_levenshtein;
2
3
 
4
+ pub use longest_tokens::longest_key_string_length;
3
5
  pub use weighted_levenshtein::{custom_levenshtein_distance_with_cost_map, OcrCostMap};
4
6
 
5
7
  #[cfg(feature = "python")]
@@ -0,0 +1,48 @@
1
+ use std::collections::HashMap;
2
+
3
+ /// Calculates the length of the longest string found within the key tuples of a HashMap.
4
+ pub fn longest_key_string_length<V>(map: &HashMap<(String, String), V>) -> usize {
5
+ map.keys()
6
+ .flat_map(|(s1, s2)| [s1.len(), s2.len()].into_iter())
7
+ .max()
8
+ .unwrap_or(1)
9
+ }
10
+
11
+ #[cfg(test)]
12
+ mod tests {
13
+ use super::*;
14
+
15
+ #[test]
16
+ fn test_longest_key_string_length_basic() {
17
+ let mut map = HashMap::new();
18
+ map.insert(("apple".to_string(), "banana".to_string()), 1); // 5, 6
19
+ map.insert(("kiwi".to_string(), "grapefruit".to_string()), 2); // 4, 10
20
+ map.insert(("short".to_string(), "tiny".to_string()), 3); // 5, 4
21
+
22
+ assert_eq!(longest_key_string_length(&map), 10); // "grapefruit"
23
+ }
24
+
25
+ #[test]
26
+ fn test_longest_key_string_length_first_element() {
27
+ let mut map = HashMap::new();
28
+ map.insert(("a_very_long_string".to_string(), "short".to_string()), 1); // 18, 5
29
+ map.insert(("medium".to_string(), "small".to_string()), 2); // 6, 5
30
+
31
+ assert_eq!(longest_key_string_length(&map), 18);
32
+ }
33
+
34
+ #[test]
35
+ fn test_longest_key_string_length_empty_map() {
36
+ let map: HashMap<(String, String), bool> = HashMap::new();
37
+ assert_eq!(longest_key_string_length(&map), 1);
38
+ }
39
+
40
+ #[test]
41
+ fn test_longest_key_string_length_empty_strings() {
42
+ let mut map = HashMap::new();
43
+ map.insert(("".to_string(), "".to_string()), 1);
44
+ map.insert(("a".to_string(), "".to_string()), 2);
45
+
46
+ assert_eq!(longest_key_string_length(&map), 1);
47
+ }
48
+ }
@@ -0,0 +1,56 @@
1
+ use crate::custom_levenshtein_distance_with_cost_map as _weighted_lev_with_map;
2
+ use crate::longest_key_string_length;
3
+ use crate::OcrCostMap;
4
+ use pyo3::prelude::*;
5
+ use pyo3::types::PyDict;
6
+ use rayon::prelude::*;
7
+
8
+ // Calculates the weighted Levenshtein distance with a custom cost map from Python.
9
+ #[pyfunction]
10
+ #[pyo3(signature = (a, b, cost_map, symmetric = true, default_cost = 1.0))]
11
+ fn _weighted_levenshtein_distance(
12
+ a: &str,
13
+ b: &str,
14
+ cost_map: &Bound<'_, PyDict>,
15
+ symmetric: bool,
16
+ default_cost: f64,
17
+ ) -> PyResult<f64> {
18
+ let ocr_cost_map = OcrCostMap::from_py_dict(cost_map, default_cost, symmetric);
19
+ let max_token_characters = longest_key_string_length(&ocr_cost_map.costs);
20
+ Ok(_weighted_lev_with_map(
21
+ a,
22
+ b,
23
+ &ocr_cost_map,
24
+ max_token_characters,
25
+ ))
26
+ }
27
+
28
+ // Calculates the weighted Levenshtein distance between a string and a list of candidates.
29
+ #[pyfunction]
30
+ #[pyo3(signature = (s, candidates, cost_map, symmetric = true, default_cost = 1.0))]
31
+ fn _batch_weighted_levenshtein_distance(
32
+ s: &str,
33
+ candidates: Vec<String>,
34
+ cost_map: &Bound<'_, PyDict>,
35
+ symmetric: bool,
36
+ default_cost: f64,
37
+ ) -> PyResult<Vec<f64>> {
38
+ let ocr_cost_map = OcrCostMap::from_py_dict(cost_map, default_cost, symmetric);
39
+ let max_token_characters = longest_key_string_length(&ocr_cost_map.costs);
40
+
41
+ // Calculate distances for each candidate in parallel
42
+ let distances: Vec<f64> = candidates
43
+ .par_iter()
44
+ .map(|candidate| _weighted_lev_with_map(s, candidate, &ocr_cost_map, max_token_characters))
45
+ .collect();
46
+
47
+ Ok(distances)
48
+ }
49
+
50
+ /// A Python module implemented in Rust.
51
+ #[pymodule]
52
+ pub fn _rust_stringdist(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
53
+ m.add_function(wrap_pyfunction!(_weighted_levenshtein_distance, m)?)?;
54
+ m.add_function(wrap_pyfunction!(_batch_weighted_levenshtein_distance, m)?)?;
55
+ Ok(())
56
+ }
@@ -1,9 +1,12 @@
1
1
  use std::collections::HashMap;
2
2
 
3
+ #[cfg(feature = "python")]
4
+ use pyo3::prelude::*;
5
+
3
6
  #[derive(Clone, Debug)]
4
7
  pub struct OcrCostMap {
5
8
  /// Maps pairs of strings to their specific substitution cost.
6
- costs: HashMap<(String, String), f64>,
9
+ pub costs: HashMap<(String, String), f64>,
7
10
  /// Default cost for substitutions not found in the map.
8
11
  default_substitution_cost: f64,
9
12
  }
@@ -31,6 +34,28 @@ impl OcrCostMap {
31
34
  }
32
35
  }
33
36
 
37
+ #[cfg(feature = "python")]
38
+ /// Creates an OcrCostMap from a Python dictionary.
39
+ /// This method is only available when the "python" feature is enabled.
40
+ pub fn from_py_dict<'a, D>(py_dict: &'a D, default_cost: f64, symmetric: bool) -> Self
41
+ where
42
+ D: PyDictMethods<'a>,
43
+ {
44
+ let mut substitution_costs: HashMap<(String, String), f64> = HashMap::new();
45
+
46
+ // Convert Python dictionary to Rust HashMap
47
+ for (key, value) in py_dict.iter() {
48
+ if let Ok(key_tuple) = key.extract::<(String, String)>() {
49
+ if let Ok(cost) = value.extract::<f64>() {
50
+ substitution_costs.insert((key_tuple.0, key_tuple.1), cost);
51
+ }
52
+ }
53
+ }
54
+
55
+ // Create the OcrCostMap
56
+ Self::new(substitution_costs, default_cost, symmetric)
57
+ }
58
+
34
59
  /// Gets the substitution cost between two strings.
35
60
  /// Checks the custom map first, then falls back to the
36
61
  /// default substitution cost configured within this map instance.
@@ -55,7 +80,12 @@ impl OcrCostMap {
55
80
 
56
81
  /// Calculates custom Levenshtein distance between two strings using a provided cost map.
57
82
  /// This implementation considers string-to-string substitutions rather than just characters.
58
- pub fn custom_levenshtein_distance_with_cost_map(s1: &str, s2: &str, cost_map: &OcrCostMap) -> f64 {
83
+ pub fn custom_levenshtein_distance_with_cost_map(
84
+ s1: &str,
85
+ s2: &str,
86
+ cost_map: &OcrCostMap,
87
+ max_token_characters: usize,
88
+ ) -> f64 {
59
89
  if s1 == s2 {
60
90
  return 0.0;
61
91
  }
@@ -86,7 +116,7 @@ pub fn custom_levenshtein_distance_with_cost_map(s1: &str, s2: &str, cost_map: &
86
116
  }
87
117
 
88
118
  // Limit on substring lengths to check
89
- let max_substr_len = 5.min(len1.max(len2));
119
+ let max_substr_len = max_token_characters.min(len1.max(len2));
90
120
 
91
121
  // Fill the dp matrix
92
122
  for i in 1..=len1 {
@@ -176,7 +206,7 @@ mod test {
176
206
  );
177
207
 
178
208
  assert_approx_eq(
179
- custom_levenshtein_distance_with_cost_map("abc", "bbc", &cost_map),
209
+ custom_levenshtein_distance_with_cost_map("abc", "bbc", &cost_map, 3),
180
210
  0.1,
181
211
  1e-9,
182
212
  );
@@ -192,14 +222,14 @@ mod test {
192
222
 
193
223
  // Test that "hi" with "Ini" has a low cost due to the special substitution
194
224
  assert_approx_eq(
195
- custom_levenshtein_distance_with_cost_map("hi", "Ini", &cost_map),
225
+ custom_levenshtein_distance_with_cost_map("hi", "Ini", &cost_map, 2),
196
226
  0.2, // Only the h->In substitution cost
197
227
  1e-9,
198
228
  );
199
229
 
200
230
  // Test another example
201
231
  assert_approx_eq(
202
- custom_levenshtein_distance_with_cost_map("hello", "Inello", &cost_map),
232
+ custom_levenshtein_distance_with_cost_map("hello", "Inello", &cost_map, 2),
203
233
  0.2, // Only the h->In substitution cost
204
234
  1e-9,
205
235
  );
@@ -214,7 +244,7 @@ mod test {
214
244
 
215
245
  // Test multiple substitutions in the same string
216
246
  assert_approx_eq(
217
- custom_levenshtein_distance_with_cost_map("hello", "Ine11o", &cost_map),
247
+ custom_levenshtein_distance_with_cost_map("hello", "Ine11o", &cost_map, 2),
218
248
  0.8, // 0.2 for h->In and 0.3+0.3 for l->1 twice
219
249
  1e-9,
220
250
  );
@@ -229,14 +259,14 @@ mod test {
229
259
 
230
260
  // Test the rn->m substitution
231
261
  assert_approx_eq(
232
- custom_levenshtein_distance_with_cost_map("corner", "comer", &cost_map),
262
+ custom_levenshtein_distance_with_cost_map("corner", "comer", &cost_map, 2),
233
263
  0.1,
234
264
  1e-9,
235
265
  );
236
266
 
237
267
  // Test the cl->d substitution
238
268
  assert_approx_eq(
239
- custom_levenshtein_distance_with_cost_map("class", "dass", &cost_map),
269
+ custom_levenshtein_distance_with_cost_map("class", "dass", &cost_map, 2),
240
270
  0.2,
241
271
  1e-9,
242
272
  );
@@ -253,14 +283,14 @@ mod test {
253
283
 
254
284
  // Test 0->O substitution (lower cost)
255
285
  assert_approx_eq(
256
- custom_levenshtein_distance_with_cost_map("R0AD", "ROAD", &cost_map),
286
+ custom_levenshtein_distance_with_cost_map("R0AD", "ROAD", &cost_map, 1),
257
287
  0.1,
258
288
  1e-9,
259
289
  );
260
290
 
261
291
  // Test O->0 substitution (higher cost)
262
292
  assert_approx_eq(
263
- custom_levenshtein_distance_with_cost_map("rOad", "r0ad", &cost_map),
293
+ custom_levenshtein_distance_with_cost_map("rOad", "r0ad", &cost_map, 1),
264
294
  0.5,
265
295
  1e-9,
266
296
  );
@@ -274,14 +304,14 @@ mod test {
274
304
 
275
305
  // Test substitution at start of word
276
306
  assert_approx_eq(
277
- custom_levenshtein_distance_with_cost_map("rnat", "mat", &cost_map),
307
+ custom_levenshtein_distance_with_cost_map("rnat", "mat", &cost_map, 2),
278
308
  0.1,
279
309
  1e-9,
280
310
  );
281
311
 
282
312
  // Test substitution at end of word
283
313
  assert_approx_eq(
284
- custom_levenshtein_distance_with_cost_map("burn", "bum", &cost_map),
314
+ custom_levenshtein_distance_with_cost_map("burn", "bum", &cost_map, 2),
285
315
  0.1,
286
316
  1e-9,
287
317
  );
@@ -294,13 +324,13 @@ mod test {
294
324
 
295
325
  // Test that "h" -> "In" costs 2.0 (1 deletion + 1 substitution) since there's no custom mapping
296
326
  assert_approx_eq(
297
- custom_levenshtein_distance_with_cost_map("h", "In", &cost_map),
327
+ custom_levenshtein_distance_with_cost_map("h", "In", &cost_map, 1),
298
328
  2.0,
299
329
  1e-9,
300
330
  );
301
331
 
302
332
  assert_approx_eq(
303
- custom_levenshtein_distance_with_cost_map("kitten", "sitting", &cost_map),
333
+ custom_levenshtein_distance_with_cost_map("kitten", "sitting", &cost_map, 1),
304
334
  3.0,
305
335
  1e-9,
306
336
  );
@@ -314,7 +344,7 @@ mod test {
314
344
  // - Insert 'e' (1)
315
345
  // Total: 4 operations
316
346
  assert_approx_eq(
317
- custom_levenshtein_distance_with_cost_map("café", "coffee", &cost_map),
347
+ custom_levenshtein_distance_with_cost_map("café", "coffee", &cost_map, 1),
318
348
  4.0, // 4 edits required
319
349
  1e-9,
320
350
  );
@@ -0,0 +1,131 @@
1
+ """
2
+ Unit tests for the batch processing functions.
3
+ """
4
+
5
+ import pytest
6
+ from ocr_stringdist import batch_weighted_levenshtein_distance, weighted_levenshtein_distance
7
+
8
+ # Define a custom cost map with some OCR confusions for testing
9
+ OCR_COST_MAP = {
10
+ ("l", "1"): 0.2, # l to 1 is a common OCR error
11
+ ("O", "0"): 0.1, # O to 0 is a common OCR error
12
+ ("o", "0"): 0.1,
13
+ ("m", "rn"): 0.3, # m to rn is a common OCR error
14
+ }
15
+
16
+
17
+ @pytest.mark.parametrize(
18
+ ["source", "candidates", "cost_map"],
19
+ [
20
+ (
21
+ "recognition",
22
+ ["recognition", "recogmtion", "recognltlon", "recogrtition", "recognitton"],
23
+ None,
24
+ ),
25
+ (
26
+ "hello",
27
+ ["hello", "he11o", "hell0"],
28
+ OCR_COST_MAP,
29
+ ),
30
+ (
31
+ "algorithm",
32
+ ["algorithm", "algorlthm", "a1gorithm"],
33
+ OCR_COST_MAP,
34
+ ),
35
+ ],
36
+ )
37
+ def test_batch_vs_individual(
38
+ source: str, candidates: list[str], cost_map: dict[tuple[str, str], float]
39
+ ) -> None:
40
+ """Test that batch results match individual function calls."""
41
+ # Individual results
42
+ individual_results = [
43
+ weighted_levenshtein_distance(source, candidate, cost_map=cost_map)
44
+ for candidate in candidates
45
+ ]
46
+
47
+ # Batch results
48
+ batch_results = batch_weighted_levenshtein_distance(source, candidates, cost_map=cost_map)
49
+
50
+ # Compare results
51
+ for ind, batch in zip(individual_results, batch_results):
52
+ assert ind == pytest.approx(batch)
53
+
54
+
55
+ @pytest.mark.parametrize(
56
+ ["source", "candidates", "expected_indices"],
57
+ [
58
+ (
59
+ "hello",
60
+ ["hello", "he11o", "hell0", "hallo", "help"],
61
+ [0], # exact match should be the best
62
+ ),
63
+ (
64
+ "algorithm",
65
+ ["a1gorithm", "algorithm", "algorlthm", "alg0rithm"],
66
+ [1], # exact match should be the best
67
+ ),
68
+ (
69
+ "recognition",
70
+ ["wreck", "cognition", "recogmition", "wreckognition"],
71
+ [2], # "recogmtion" should be closest to "recognition"
72
+ ),
73
+ ],
74
+ )
75
+ def test_batch_finds_best_match(
76
+ source: str, candidates: list[str], expected_indices: list[int]
77
+ ) -> None:
78
+ """Test that batch processing correctly identifies the best match."""
79
+ # Using OCR cost map
80
+ distances = batch_weighted_levenshtein_distance(source, candidates, cost_map=OCR_COST_MAP)
81
+ print(f"------------------------------------distances: {distances}")
82
+
83
+ # Find the index with minimum distance
84
+ min_index = distances.index(min(distances))
85
+
86
+ # Check if the minimum index is in the expected indices
87
+ assert min_index in expected_indices
88
+
89
+
90
+ @pytest.mark.parametrize(
91
+ ["test_string", "expected_distance"],
92
+ [
93
+ ("hello", 0.0), # exact match
94
+ ("he11o", 0.4), # two l->1 substitutions at cost 0.2 each
95
+ ("hell0", 0.1), # one O->0 substitution at cost 0.1
96
+ ],
97
+ )
98
+ def test_custom_cost_map(test_string: str, expected_distance: float) -> None:
99
+ """Test using a custom cost map for specific substitution costs."""
100
+ result = weighted_levenshtein_distance("hello", test_string, cost_map=OCR_COST_MAP)
101
+ assert result == pytest.approx(expected_distance)
102
+
103
+ # Check that batch processing gives the same result
104
+ batch_result = batch_weighted_levenshtein_distance(
105
+ "hello", [test_string], cost_map=OCR_COST_MAP
106
+ )[0]
107
+ assert batch_result == pytest.approx(expected_distance)
108
+
109
+
110
+ @pytest.mark.parametrize(
111
+ ["string1", "string2", "default_map_distance", "custom_map_distance"],
112
+ [
113
+ ("hello", "he11o", 2.0, 0.4), # l->1 costs 0.2 each instead of 1.0 each
114
+ ("hello", "hell0", 1.0, 0.1), # o->0 costs 0.1 instead of 1.0
115
+ ("come", "corne", 2.0, 0.3), # rn->m costs 0.3 instead of 2.0
116
+ ],
117
+ )
118
+ def test_empty_vs_default_cost_map(
119
+ string1: str, string2: str, default_map_distance: float, custom_map_distance: float
120
+ ) -> None:
121
+ """Test that empty cost maps produce different results than default cost maps."""
122
+ # With empty cost map (all costs are 1.0)
123
+ default_result = batch_weighted_levenshtein_distance(string1, [string2], cost_map={})
124
+ assert default_result[0] == pytest.approx(default_map_distance)
125
+
126
+ # With custom cost map (OCR-specific costs)
127
+ custom_result = batch_weighted_levenshtein_distance(string1, [string2], cost_map=OCR_COST_MAP)
128
+ assert custom_result[0] == pytest.approx(custom_map_distance)
129
+
130
+ # Custom map should give lower distance for OCR errors
131
+ assert custom_result[0] < default_result[0]
@@ -1,18 +0,0 @@
1
- .. _api_reference:
2
-
3
- API Reference
4
- =============
5
-
6
- This page contains the auto-generated API reference documentation.
7
-
8
- .. autofunction:: ocr_stringdist.__init__.weighted_levenshtein_distance
9
-
10
- .. automodule:: ocr_stringdist.matching
11
- :members:
12
- :undoc-members:
13
- :show-inheritance:
14
-
15
- .. automodule:: ocr_stringdist.default_ocr_distances
16
- :members:
17
- :undoc-members:
18
- :show-inheritance:
@@ -1,42 +0,0 @@
1
- from typing import Optional
2
-
3
- from ._rust_stringdist import * # noqa: F403
4
- from .default_ocr_distances import ocr_distance_map
5
- from .matching import find_best_candidate
6
-
7
- __all__ = [
8
- "ocr_distance_map",
9
- "weighted_levenshtein_distance", # noqa: F405
10
- "find_best_candidate",
11
- ]
12
-
13
-
14
- def weighted_levenshtein_distance(
15
- s1: str,
16
- s2: str,
17
- /,
18
- cost_map: Optional[dict[tuple[str, str], float]] = None,
19
- *,
20
- symmetric: bool = True,
21
- default_cost: float = 1.0,
22
- ) -> float:
23
- """
24
- Levenshtein distance with custom substitution costs.
25
- Insertion/deletion costs are 1.
26
-
27
- The default `cost_map` considers common OCR errors, see `ocr_stringdist.ocr_distance_map`.
28
-
29
- :param s1: First string
30
- :param s2: Second string
31
- :param cost_map: Dictionary mapping tuples of characters to their substitution cost.
32
- Only one direction needs to be configured unless `symmetric` is False.
33
- Defaults to `ocr_stringdist.ocr_distance_map`.
34
- :param symmetric: Should the keys of `cost_map` be considered to be symmetric? Defaults to True.
35
- :param default_cost: The default substitution cost for character pairs not found in `cost_map`.
36
- """
37
- if cost_map is None:
38
- cost_map = ocr_distance_map
39
- # _weighted_levenshtein_distance is written in Rust, see src/rust_stringdist.rs.
40
- return _weighted_levenshtein_distance( # type: ignore # noqa: F405
41
- s1, s2, cost_map=cost_map, symmetric=symmetric, default_cost=default_cost
42
- )
@@ -1,39 +0,0 @@
1
- use crate::custom_levenshtein_distance_with_cost_map as _weighted_lev_with_map;
2
- use crate::OcrCostMap;
3
- use pyo3::prelude::*;
4
- use pyo3::types::PyDict;
5
- use std::collections::HashMap;
6
-
7
- // Calculates the weighted Levenshtein distance with a custom cost map from Python.
8
- #[pyfunction]
9
- #[pyo3(signature = (a, b, cost_map, symmetric = true, default_cost = 1.0))]
10
- fn _weighted_levenshtein_distance(
11
- a: &str,
12
- b: &str,
13
- cost_map: &Bound<'_, PyDict>,
14
- symmetric: bool,
15
- default_cost: Option<f64>,
16
- ) -> PyResult<f64> {
17
- let default_cost_value = default_cost.unwrap_or(1.0);
18
- let mut substitution_costs: HashMap<(String, String), f64> = HashMap::new();
19
-
20
- // Convert Python dictionary to Rust HashMap
21
- for (key, value) in cost_map.iter() {
22
- if let Ok(key_tuple) = key.extract::<(String, String)>() {
23
- if let Ok(cost) = value.extract::<f64>() {
24
- substitution_costs.insert((key_tuple.0, key_tuple.1), cost);
25
- }
26
- }
27
- }
28
-
29
- // Create a custom cost map and calculate the distance
30
- let custom_cost_map = OcrCostMap::new(substitution_costs, default_cost_value, symmetric);
31
- Ok(_weighted_lev_with_map(a, b, &custom_cost_map))
32
- }
33
-
34
- /// A Python module implemented in Rust.
35
- #[pymodule]
36
- pub fn _rust_stringdist(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
37
- m.add_function(wrap_pyfunction!(_weighted_levenshtein_distance, m)?)?;
38
- Ok(())
39
- }
File without changes
File without changes
File without changes
File without changes