ocr-stringdist 0.0.7__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocr_stringdist-0.1.0/CHANGELOG.md +16 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/Cargo.lock +1 -1
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/Cargo.toml +1 -1
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/Justfile +3 -2
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/PKG-INFO +1 -2
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/README.md +0 -1
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/docs/source/api/index.rst +0 -2
- ocr_stringdist-0.1.0/docs/source/changelog.rst +1 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/docs/source/conf.py +3 -0
- ocr_stringdist-0.1.0/docs/source/examples.rst +20 -0
- ocr_stringdist-0.1.0/docs/source/getting-started.rst +10 -0
- ocr_stringdist-0.1.0/docs/source/index.rst +36 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/examples/weighted_levenshtein.py +4 -4
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/pyproject.toml +1 -0
- ocr_stringdist-0.1.0/python/ocr_stringdist/levenshtein.py +118 -0
- ocr_stringdist-0.1.0/src/cost_map.rs +306 -0
- ocr_stringdist-0.1.0/src/lib.rs +12 -0
- ocr_stringdist-0.1.0/src/rust_stringdist.rs +135 -0
- ocr_stringdist-0.1.0/src/types.rs +13 -0
- ocr_stringdist-0.1.0/src/weighted_levenshtein.rs +992 -0
- ocr_stringdist-0.0.7/tests/test_batch_functions.py → ocr_stringdist-0.1.0/tests/test_batch_weighted_levenshtein.py +13 -7
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/tests/test_matching.py +3 -1
- ocr_stringdist-0.1.0/tests/test_weighted_levenshtein.py +545 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/uv.lock +90 -0
- ocr_stringdist-0.0.7/docs/source/index.rst +0 -10
- ocr_stringdist-0.0.7/python/ocr_stringdist/levenshtein.py +0 -71
- ocr_stringdist-0.0.7/src/lib.rs +0 -10
- ocr_stringdist-0.0.7/src/longest_tokens.rs +0 -48
- ocr_stringdist-0.0.7/src/rust_stringdist.rs +0 -56
- ocr_stringdist-0.0.7/src/weighted_levenshtein.rs +0 -352
- ocr_stringdist-0.0.7/tests/test_ocr_stringdist.py +0 -106
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/.github/workflows/CI.yml +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/.github/workflows/docs.yml +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/.gitignore +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/LICENSE +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/docs/Makefile +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/docs/make.bat +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/examples/batch_processing.py +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/mypy.ini +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/python/ocr_stringdist/__init__.py +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/python/ocr_stringdist/default_ocr_distances.py +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/python/ocr_stringdist/matching.py +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/python/ocr_stringdist/py.typed +0 -0
- {ocr_stringdist-0.0.7 → ocr_stringdist-0.1.0}/ruff.toml +0 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
# Changelog
|
2
|
+
|
3
|
+
All notable changes to this project will be documented in this file.
|
4
|
+
|
5
|
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
6
|
+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
7
|
+
|
8
|
+
## [0.1.0] - 2025-04-26
|
9
|
+
|
10
|
+
### Added
|
11
|
+
|
12
|
+
- Custom insertion and deletion costs for weighted Levenshtein distance.
|
13
|
+
|
14
|
+
### Changed
|
15
|
+
|
16
|
+
- Breaking changes to the signatures of the Levenshtein distance functions.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ocr_stringdist
|
3
|
-
Version: 0.0
|
3
|
+
Version: 0.1.0
|
4
4
|
Classifier: Programming Language :: Rust
|
5
5
|
Classifier: Programming Language :: Python
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -40,7 +40,6 @@ pip install ocr-stringdist
|
|
40
40
|
- **Unicode Support**: Arbitrary unicode strings can be compared.
|
41
41
|
- **Substitution of Multiple Characters**: Not just character pairs, but string pairs may be substituted, for example the Korean syllable "이" for the two letters "OI".
|
42
42
|
- **Pre-defined OCR Distance Map**: A built-in distance map for common OCR confusions (e.g., "0" vs "O", "1" vs "l", "5" vs "S").
|
43
|
-
- **Customizable Cost Maps**: Create your own substitution cost maps for specific OCR systems or domains.
|
44
43
|
- **Best Match Finder**: Utility function `find_best_candidate` to efficiently find the best matching string from a collection of candidates using any specified distance function (including the library's OCR-aware ones).
|
45
44
|
|
46
45
|
## Usage
|
@@ -25,7 +25,6 @@ pip install ocr-stringdist
|
|
25
25
|
- **Unicode Support**: Arbitrary unicode strings can be compared.
|
26
26
|
- **Substitution of Multiple Characters**: Not just character pairs, but string pairs may be substituted, for example the Korean syllable "이" for the two letters "OI".
|
27
27
|
- **Pre-defined OCR Distance Map**: A built-in distance map for common OCR confusions (e.g., "0" vs "O", "1" vs "l", "5" vs "S").
|
28
|
-
- **Customizable Cost Maps**: Create your own substitution cost maps for specific OCR systems or domains.
|
29
28
|
- **Best Match Finder**: Utility function `find_best_candidate` to efficiently find the best matching string from a collection of candidates using any specified distance function (including the library's OCR-aware ones).
|
30
29
|
|
31
30
|
## Usage
|
@@ -0,0 +1 @@
|
|
1
|
+
.. mdinclude:: ../../CHANGELOG.md
|
@@ -4,6 +4,7 @@
|
|
4
4
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
5
5
|
|
6
6
|
|
7
|
+
import importlib.metadata
|
7
8
|
import os
|
8
9
|
import sys
|
9
10
|
|
@@ -17,6 +18,8 @@ sys.path.insert(0, os.path.abspath("../../python"))
|
|
17
18
|
project = "OCR-StringDist"
|
18
19
|
copyright = "2025, Niklas von Moers"
|
19
20
|
author = "Niklas von Moers"
|
21
|
+
release = importlib.metadata.version("ocr_stringdist")
|
22
|
+
version = release
|
20
23
|
|
21
24
|
# -- General configuration ---------------------------------------------------
|
22
25
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
@@ -0,0 +1,20 @@
|
|
1
|
+
==========
|
2
|
+
Examples
|
3
|
+
==========
|
4
|
+
|
5
|
+
.. code-block:: python
|
6
|
+
|
7
|
+
import ocr_stringdist as osd
|
8
|
+
|
9
|
+
# Using default OCR distance map
|
10
|
+
distance = osd.weighted_levenshtein_distance("OCR5", "OCRS")
|
11
|
+
print(f"Distance between 'OCR5' and 'OCRS': {distance}") # Will be less than 1.0
|
12
|
+
|
13
|
+
# Custom cost map
|
14
|
+
custom_map = {("In", "h"): 0.5}
|
15
|
+
distance = osd.weighted_levenshtein_distance(
|
16
|
+
"hi", "Ini",
|
17
|
+
substitution_costs=custom_map,
|
18
|
+
symmetric_substitution=True,
|
19
|
+
)
|
20
|
+
print(f"Distance with custom map: {distance}")
|
@@ -0,0 +1,36 @@
|
|
1
|
+
================
|
2
|
+
OCR-StringDist
|
3
|
+
================
|
4
|
+
|
5
|
+
A Python library for string distance calculations that account for common OCR (optical character recognition) errors, written in Rust.
|
6
|
+
|
7
|
+
:Repository: https://github.com/niklasvonm/ocr-stringdist
|
8
|
+
:Current version: |release|
|
9
|
+
|
10
|
+
.. image:: https://img.shields.io/badge/PyPI-Package-blue
|
11
|
+
:target: https://pypi.org/project/ocr-stringdist/
|
12
|
+
:alt: PyPI
|
13
|
+
|
14
|
+
.. image:: https://img.shields.io/badge/License-MIT-green
|
15
|
+
:target: LICENSE
|
16
|
+
:alt: License
|
17
|
+
|
18
|
+
Features
|
19
|
+
========
|
20
|
+
|
21
|
+
- **Weighted Levenshtein Distance**: An adaptation of the classic Levenshtein algorithm with custom substitution costs for character pairs that are commonly confused in OCR models, including efficient batch processing.
|
22
|
+
- **Unicode Support**: Arbitrary unicode strings can be compared.
|
23
|
+
- **Substitution of Multiple Characters**: Not just character pairs, but string pairs may be substituted, for example the Korean syllable "이" for the two letters "OI".
|
24
|
+
- **Pre-defined OCR Distance Map**: A built-in distance map for common OCR confusions (e.g., "0" vs "O", "1" vs "l", "5" vs "S").
|
25
|
+
- **Best Match Finder**: Utility function ``find_best_candidate`` to efficiently find the best matching string from a collection of candidates using any specified distance function (including the library's OCR-aware ones).
|
26
|
+
|
27
|
+
Contents
|
28
|
+
========
|
29
|
+
|
30
|
+
.. toctree::
|
31
|
+
:maxdepth: 1
|
32
|
+
|
33
|
+
getting-started
|
34
|
+
examples
|
35
|
+
api/index
|
36
|
+
changelog
|
@@ -31,13 +31,13 @@ ic(
|
|
31
31
|
weighted_levenshtein_distance(
|
32
32
|
"ABCDE",
|
33
33
|
"XBCDE",
|
34
|
-
|
35
|
-
|
34
|
+
substitution_costs={},
|
35
|
+
default_substitution_cost=0.8, # Lower default substitution cost (default is 1.0)
|
36
36
|
)
|
37
37
|
)
|
38
38
|
|
39
|
-
ic(weighted_levenshtein_distance("A", "B", {("A", "B"): 0.0},
|
40
|
-
ic(weighted_levenshtein_distance("A", "B", {("B", "A"): 0.0},
|
39
|
+
ic(weighted_levenshtein_distance("A", "B", {("A", "B"): 0.0}, symmetric_substitution=False))
|
40
|
+
ic(weighted_levenshtein_distance("A", "B", {("B", "A"): 0.0}, symmetric_substitution=False))
|
41
41
|
|
42
42
|
ic(
|
43
43
|
find_best_candidate(
|
@@ -0,0 +1,118 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from ._rust_stringdist import * # noqa: F403
|
4
|
+
from .default_ocr_distances import ocr_distance_map
|
5
|
+
|
6
|
+
|
7
|
+
def weighted_levenshtein_distance(
    s1: str,
    s2: str,
    /,
    substitution_costs: Optional[dict[tuple[str, str], float]] = None,
    insertion_costs: Optional[dict[str, float]] = None,
    deletion_costs: Optional[dict[str, float]] = None,
    *,
    symmetric_substitution: bool = True,
    default_substitution_cost: float = 1.0,
    default_insertion_cost: float = 1.0,
    default_deletion_cost: float = 1.0,
) -> float:
    """
    Compute a Levenshtein distance whose substitution, insertion and deletion
    costs can be customized per token.

    When no `substitution_costs` are given, the built-in map of common OCR
    confusions is used, see
    :py:data:`ocr_stringdist.default_ocr_distances.ocr_distance_map`.

    :param s1: First string (interpreted as the string read via OCR)
    :param s2: Second string
    :param substitution_costs: Mapping from pairs of strings ("substitution tokens") to the
                               cost of substituting one for the other. Unless
                               `symmetric_substitution` is False, configuring one direction
                               is enough. Runtime grows with the length of the longest
                               substitution token.
                               Defaults to
                               :py:data:`ocr_stringdist.default_ocr_distances.ocr_distance_map`.
    :param insertion_costs: Mapping from strings to the cost of inserting them.
    :param deletion_costs: Mapping from strings to the cost of deleting them.
    :param symmetric_substitution: Treat each key of `substitution_costs` as applying in both
                                   directions. Defaults to True.
    :param default_substitution_cost: Substitution cost for pairs missing from
                                      `substitution_costs`.
    :param default_insertion_cost: Insertion cost for strings missing from `insertion_costs`.
    :param default_deletion_cost: Deletion cost for strings missing from `deletion_costs`.
    :return: The weighted edit distance between `s1` and `s2`.
    """
    # ``None`` placeholders keep mutable objects out of the signature's defaults.
    resolved_substitutions = (
        ocr_distance_map if substitution_costs is None else substitution_costs
    )
    resolved_insertions = {} if insertion_costs is None else insertion_costs
    resolved_deletions = {} if deletion_costs is None else deletion_costs

    # The heavy lifting happens in Rust, see src/rust_stringdist.rs.
    return _weighted_levenshtein_distance(  # type: ignore # noqa: F405
        s1,
        s2,
        substitution_costs=resolved_substitutions,
        insertion_costs=resolved_insertions,
        deletion_costs=resolved_deletions,
        symmetric_substitution=symmetric_substitution,
        default_substitution_cost=default_substitution_cost,
        default_insertion_cost=default_insertion_cost,
        default_deletion_cost=default_deletion_cost,
    )
|
62
|
+
|
63
|
+
|
64
|
+
def batch_weighted_levenshtein_distance(
    s: str,
    candidates: list[str],
    /,
    substitution_costs: Optional[dict[tuple[str, str], float]] = None,
    insertion_costs: Optional[dict[str, float]] = None,
    deletion_costs: Optional[dict[str, float]] = None,
    *,
    symmetric_substitution: bool = True,
    default_substitution_cost: float = 1.0,
    default_insertion_cost: float = 1.0,
    default_deletion_cost: float = 1.0,
) -> list[float]:
    """
    Compute weighted Levenshtein distances from one string to many candidates.

    Prefer this over calling :func:`weighted_levenshtein_distance` in a loop: the
    whole batch is handed to Rust in a single call.

    :param s: The string to compare (interpreted as the string read via OCR)
    :param candidates: Candidate strings to compare `s` against
    :param substitution_costs: Mapping from pairs of strings ("substitution tokens") to the
                               cost of substituting one for the other. Unless
                               `symmetric_substitution` is False, configuring one direction
                               is enough. Runtime grows with the length of the longest
                               substitution token.
                               Defaults to
                               :py:data:`ocr_stringdist.default_ocr_distances.ocr_distance_map`.
    :param insertion_costs: Mapping from strings to the cost of inserting them.
    :param deletion_costs: Mapping from strings to the cost of deleting them.
    :param symmetric_substitution: Treat each key of `substitution_costs` as applying in both
                                   directions. Defaults to True.
    :param default_substitution_cost: Substitution cost for pairs missing from
                                      `substitution_costs`.
    :param default_insertion_cost: Insertion cost for strings missing from `insertion_costs`.
    :param default_deletion_cost: Deletion cost for strings missing from `deletion_costs`.
    :return: One distance per candidate, in the same order as `candidates`.
    """
    # ``None`` placeholders keep mutable objects out of the signature's defaults.
    resolved_substitutions = (
        ocr_distance_map if substitution_costs is None else substitution_costs
    )
    resolved_insertions = {} if insertion_costs is None else insertion_costs
    resolved_deletions = {} if deletion_costs is None else deletion_costs

    # The batch is evaluated in Rust, see src/rust_stringdist.rs.
    return _batch_weighted_levenshtein_distance(  # type: ignore # noqa: F405
        s,
        candidates,
        substitution_costs=resolved_substitutions,
        insertion_costs=resolved_insertions,
        deletion_costs=resolved_deletions,
        symmetric_substitution=symmetric_substitution,
        default_substitution_cost=default_substitution_cost,
        default_insertion_cost=default_insertion_cost,
        default_deletion_cost=default_deletion_cost,
    )
|
@@ -0,0 +1,306 @@
|
|
1
|
+
use crate::types::{SingleTokenCostMap, SingleTokenKey, SubstitutionCostMap, SubstitutionKey};
|
2
|
+
use std::collections::HashMap;
|
3
|
+
use std::fmt::Debug;
|
4
|
+
use std::hash::Hash;
|
5
|
+
|
6
|
+
#[cfg(feature = "python")]
|
7
|
+
use pyo3::prelude::*;
|
8
|
+
|
9
|
+
/// A trait for cost map keys, allowing us to constrain the generic parameter
|
10
|
+
pub trait CostKey: Clone + Debug + Eq + Hash {}
|
11
|
+
|
12
|
+
// Implement the trait for both key types
|
13
|
+
impl CostKey for SingleTokenKey {}
|
14
|
+
impl CostKey for SubstitutionKey {}
|
15
|
+
|
16
|
+
/// Generic cost map structure that works with different key types
|
17
|
+
#[derive(Clone, Debug)]
|
18
|
+
pub struct CostMap<K: CostKey> {
|
19
|
+
/// The costs map
|
20
|
+
pub costs: HashMap<K, f64>,
|
21
|
+
/// Default cost for operations not found in the map
|
22
|
+
default_cost: f64,
|
23
|
+
/// Maximum token length in the map
|
24
|
+
pub max_token_length: usize,
|
25
|
+
}
|
26
|
+
|
27
|
+
impl<K: CostKey> Default for CostMap<K>
|
28
|
+
where
|
29
|
+
K: Default,
|
30
|
+
{
|
31
|
+
/// Creates a new CostMap with default values
|
32
|
+
fn default() -> Self {
|
33
|
+
Self {
|
34
|
+
costs: HashMap::new(),
|
35
|
+
default_cost: 1.0,
|
36
|
+
max_token_length: 1,
|
37
|
+
}
|
38
|
+
}
|
39
|
+
}
|
40
|
+
|
41
|
+
// Implementation for SubstitutionKey (pair of strings)
|
42
|
+
impl CostMap<SubstitutionKey> {
|
43
|
+
/// Creates a new substitution CostMap with specified costs.
|
44
|
+
/// Ensures symmetry by adding both (a, b) and (b, a) if only one is provided when symmetric is true.
|
45
|
+
pub fn new(
|
46
|
+
custom_costs_input: SubstitutionCostMap,
|
47
|
+
default_cost: f64,
|
48
|
+
symmetric: bool,
|
49
|
+
) -> Self {
|
50
|
+
let mut costs = HashMap::with_capacity(custom_costs_input.len() * 2);
|
51
|
+
let mut max_length = 1;
|
52
|
+
|
53
|
+
for ((s1, s2), cost) in custom_costs_input {
|
54
|
+
costs.entry((s1.clone(), s2.clone())).or_insert(cost);
|
55
|
+
if symmetric {
|
56
|
+
costs.entry((s2.clone(), s1.clone())).or_insert(cost);
|
57
|
+
}
|
58
|
+
|
59
|
+
// Update max token length
|
60
|
+
max_length = max_length.max(s1.chars().count()).max(s2.chars().count());
|
61
|
+
}
|
62
|
+
|
63
|
+
CostMap {
|
64
|
+
costs,
|
65
|
+
default_cost,
|
66
|
+
max_token_length: max_length,
|
67
|
+
}
|
68
|
+
}
|
69
|
+
|
70
|
+
/// Creates a new substitution CostMap with the specified custom costs.
|
71
|
+
/// Uses default values for other parameters.
|
72
|
+
pub fn with_costs(custom_costs: SubstitutionCostMap) -> Self {
|
73
|
+
Self::new(custom_costs, 1.0, true)
|
74
|
+
}
|
75
|
+
|
76
|
+
#[cfg(feature = "python")]
|
77
|
+
/// Creates a substitution CostMap from a Python dictionary.
|
78
|
+
/// This method is only available when the "python" feature is enabled.
|
79
|
+
pub fn from_py_dict<'a, D>(py_dict: &'a D, default_cost: f64, symmetric: bool) -> Self
|
80
|
+
where
|
81
|
+
D: PyDictMethods<'a>,
|
82
|
+
{
|
83
|
+
let mut substitution_costs = SubstitutionCostMap::new();
|
84
|
+
let mut max_length = 1;
|
85
|
+
|
86
|
+
// Convert Python dictionary to Rust HashMap
|
87
|
+
for (key, value) in py_dict.iter() {
|
88
|
+
if let Ok(key_tuple) = key.extract::<(String, String)>() {
|
89
|
+
if let Ok(cost) = value.extract::<f64>() {
|
90
|
+
substitution_costs.insert((key_tuple.0.clone(), key_tuple.1.clone()), cost);
|
91
|
+
|
92
|
+
// Update max token length
|
93
|
+
max_length = max_length
|
94
|
+
.max(key_tuple.0.chars().count())
|
95
|
+
.max(key_tuple.1.chars().count());
|
96
|
+
}
|
97
|
+
}
|
98
|
+
}
|
99
|
+
|
100
|
+
// Create the CostMap
|
101
|
+
Self::new(substitution_costs, default_cost, symmetric)
|
102
|
+
}
|
103
|
+
|
104
|
+
/// Gets the substitution cost between two strings.
|
105
|
+
pub fn get_cost(&self, s1: &str, s2: &str) -> f64 {
|
106
|
+
if s1 == s2 {
|
107
|
+
0.0 // No cost if strings are identical
|
108
|
+
} else {
|
109
|
+
let key_pair = (s1.to_string(), s2.to_string());
|
110
|
+
|
111
|
+
// Lookup the pair (symmetry is handled by storage in `new`)
|
112
|
+
// Use the map's configured default_cost as the fallback.
|
113
|
+
self.costs
|
114
|
+
.get(&key_pair)
|
115
|
+
.copied()
|
116
|
+
.unwrap_or(self.default_cost)
|
117
|
+
}
|
118
|
+
}
|
119
|
+
|
120
|
+
/// Checks if the cost map contains a specific substitution
|
121
|
+
pub fn has_key(&self, s1: &str, s2: &str) -> bool {
|
122
|
+
let key_pair = (s1.to_string(), s2.to_string());
|
123
|
+
self.costs.contains_key(&key_pair)
|
124
|
+
}
|
125
|
+
}
|
126
|
+
|
127
|
+
// Implementation for SingleTokenKey (single string)
|
128
|
+
impl CostMap<SingleTokenKey> {
|
129
|
+
/// Creates a new single token CostMap for insertion or deletion operations
|
130
|
+
pub fn new(custom_costs_input: SingleTokenCostMap, default_cost: f64) -> Self {
|
131
|
+
let mut max_length = 1;
|
132
|
+
|
133
|
+
// Calculate max token length
|
134
|
+
for key in custom_costs_input.keys() {
|
135
|
+
max_length = max_length.max(key.chars().count());
|
136
|
+
}
|
137
|
+
|
138
|
+
CostMap {
|
139
|
+
costs: custom_costs_input,
|
140
|
+
default_cost,
|
141
|
+
max_token_length: max_length,
|
142
|
+
}
|
143
|
+
}
|
144
|
+
|
145
|
+
/// Creates a new single token CostMap with the specified custom costs.
|
146
|
+
/// Uses default value for default cost.
|
147
|
+
pub fn with_costs(custom_costs: SingleTokenCostMap) -> Self {
|
148
|
+
Self::new(custom_costs, 1.0)
|
149
|
+
}
|
150
|
+
|
151
|
+
#[cfg(feature = "python")]
|
152
|
+
/// Creates a single token CostMap from a Python dictionary.
|
153
|
+
/// This method is only available when the "python" feature is enabled.
|
154
|
+
pub fn from_py_dict<'a, D>(py_dict: &'a D, default_cost: f64) -> Self
|
155
|
+
where
|
156
|
+
D: PyDictMethods<'a>,
|
157
|
+
{
|
158
|
+
let mut single_token_costs = SingleTokenCostMap::new();
|
159
|
+
let mut max_length = 1;
|
160
|
+
|
161
|
+
// Convert Python dictionary to Rust HashMap
|
162
|
+
for (key, value) in py_dict.iter() {
|
163
|
+
if let Ok(token) = key.extract::<String>() {
|
164
|
+
if let Ok(cost) = value.extract::<f64>() {
|
165
|
+
single_token_costs.insert(token.clone(), cost);
|
166
|
+
|
167
|
+
// Update max token length
|
168
|
+
max_length = max_length.max(token.chars().count());
|
169
|
+
}
|
170
|
+
}
|
171
|
+
}
|
172
|
+
|
173
|
+
// Create the CostMap
|
174
|
+
Self::new(single_token_costs, default_cost)
|
175
|
+
}
|
176
|
+
|
177
|
+
/// Gets the cost for a single token (insertion or deletion).
|
178
|
+
pub fn get_cost(&self, token: &str) -> f64 {
|
179
|
+
self.costs.get(token).copied().unwrap_or(self.default_cost)
|
180
|
+
}
|
181
|
+
|
182
|
+
/// Checks if the cost map contains a specific single token
|
183
|
+
pub fn has_key(&self, token: &str) -> bool {
|
184
|
+
self.costs.contains_key(token)
|
185
|
+
}
|
186
|
+
}
|
187
|
+
|
188
|
+
// Common methods for any type of CostMap
|
189
|
+
impl<K: CostKey> CostMap<K> {
|
190
|
+
/// Returns the default cost for this cost map
|
191
|
+
pub fn default_cost(&self) -> f64 {
|
192
|
+
self.default_cost
|
193
|
+
}
|
194
|
+
}
|
195
|
+
|
196
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-token cost input from `(token, cost)` pairs.
    fn single_tokens(entries: &[(&str, f64)]) -> SingleTokenCostMap {
        entries
            .iter()
            .map(|&(token, cost)| (token.to_string(), cost))
            .collect()
    }

    /// Builds a substitution cost input from `(from, to, cost)` triples.
    fn substitutions(entries: &[(&str, &str, f64)]) -> SubstitutionCostMap {
        entries
            .iter()
            .map(|&(from, to, cost)| ((from.to_string(), to.to_string()), cost))
            .collect()
    }

    #[test]
    fn test_single_token_map_default() {
        // A default-constructed map has no entries and a fallback of 1.0.
        let map: CostMap<SingleTokenKey> = CostMap::default();

        assert_eq!(map.default_cost(), 1.0);
        assert_eq!(map.get_cost("any_token"), 1.0);
        assert!(!map.has_key("any_token"));
    }

    #[test]
    fn test_single_token_map_with_costs() {
        // `with_costs` keeps the explicit entries and falls back to 1.0.
        let map =
            CostMap::<SingleTokenKey>::with_costs(single_tokens(&[("a", 0.5), ("b", 0.8)]));

        for &(token, expected) in &[("a", 0.5), ("b", 0.8), ("c", 1.0)] {
            assert_eq!(map.get_cost(token), expected, "cost of {:?}", token);
        }
        for &(token, present) in &[("a", true), ("b", true), ("c", false)] {
            assert_eq!(map.has_key(token), present, "has_key({:?})", token);
        }
    }

    #[test]
    fn test_single_token_map_with_custom_default() {
        let map = CostMap::<SingleTokenKey>::new(single_tokens(&[("test", 0.3)]), 2.0);

        assert_eq!(map.default_cost(), 2.0);
        assert_eq!(map.get_cost("test"), 0.3);
        assert_eq!(map.get_cost("unknown"), 2.0);
    }

    #[test]
    fn test_substitution_map_default() {
        // Built field-by-field to exercise an empty substitution map directly.
        let map: CostMap<SubstitutionKey> = CostMap {
            costs: HashMap::new(),
            default_cost: 1.0,
            max_token_length: 1,
        };

        assert_eq!(map.default_cost(), 1.0);
        assert_eq!(map.get_cost("a", "b"), 1.0);
        assert!(!map.has_key("a", "b"));
    }

    #[test]
    fn test_substitution_map_with_costs() {
        // `with_costs` is symmetric, so both directions resolve to the same cost.
        let map = CostMap::<SubstitutionKey>::with_costs(substitutions(&[
            ("0", "o", 0.2),
            ("l", "1", 0.3),
        ]));

        let expected_costs = [
            ("0", "o", 0.2),
            ("o", "0", 0.2),
            ("l", "1", 0.3),
            ("1", "l", 0.3),
            ("a", "b", 1.0),
            ("a", "a", 0.0),
        ];
        for &(from, to, cost) in &expected_costs {
            assert_eq!(map.get_cost(from, to), cost, "cost of {:?} -> {:?}", from, to);
        }

        assert!(map.has_key("0", "o"));
        assert!(map.has_key("o", "0"));
        assert!(!map.has_key("a", "b"));
    }

    #[test]
    fn test_substitution_map_asymmetric() {
        // With symmetric = false only the configured direction is stored.
        let map = CostMap::<SubstitutionKey>::new(substitutions(&[("a", "b", 0.4)]), 1.5, false);

        assert_eq!(map.get_cost("a", "b"), 0.4);
        assert_eq!(map.get_cost("b", "a"), 1.5);
        assert!(map.has_key("a", "b"));
        assert!(!map.has_key("b", "a"));
    }

    #[test]
    fn test_default_cost_accessor() {
        let substitution_map = CostMap::<SubstitutionKey>::new(HashMap::new(), 2.5, true);
        let single_map = CostMap::<SingleTokenKey>::new(HashMap::new(), 3.0);

        assert_eq!(substitution_map.default_cost(), 2.5);
        assert_eq!(single_map.default_cost(), 3.0);
    }
}
|
@@ -0,0 +1,12 @@
|
|
1
|
+
mod cost_map;
|
2
|
+
mod types;
|
3
|
+
mod weighted_levenshtein;
|
4
|
+
|
5
|
+
pub use cost_map::CostMap;
|
6
|
+
pub use types::*;
|
7
|
+
pub use weighted_levenshtein::custom_levenshtein_distance_with_cost_maps;
|
8
|
+
|
9
|
+
#[cfg(feature = "python")]
|
10
|
+
mod rust_stringdist;
|
11
|
+
#[cfg(feature = "python")]
|
12
|
+
pub use rust_stringdist::_rust_stringdist;
|