nomenklatura_mpt-4.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. nomenklatura/__init__.py +11 -0
  2. nomenklatura/cache.py +194 -0
  3. nomenklatura/cli.py +260 -0
  4. nomenklatura/conflicting_match.py +80 -0
  5. nomenklatura/data/er-unstable.pkl +0 -0
  6. nomenklatura/data/regression-v1.pkl +0 -0
  7. nomenklatura/db.py +139 -0
  8. nomenklatura/delta.py +4 -0
  9. nomenklatura/enrich/__init__.py +94 -0
  10. nomenklatura/enrich/aleph.py +141 -0
  11. nomenklatura/enrich/common.py +219 -0
  12. nomenklatura/enrich/nominatim.py +72 -0
  13. nomenklatura/enrich/opencorporates.py +233 -0
  14. nomenklatura/enrich/openfigi.py +124 -0
  15. nomenklatura/enrich/permid.py +201 -0
  16. nomenklatura/enrich/wikidata.py +268 -0
  17. nomenklatura/enrich/yente.py +116 -0
  18. nomenklatura/exceptions.py +9 -0
  19. nomenklatura/index/__init__.py +5 -0
  20. nomenklatura/index/common.py +24 -0
  21. nomenklatura/index/entry.py +89 -0
  22. nomenklatura/index/index.py +170 -0
  23. nomenklatura/index/tokenizer.py +92 -0
  24. nomenklatura/judgement.py +21 -0
  25. nomenklatura/kv.py +40 -0
  26. nomenklatura/matching/__init__.py +47 -0
  27. nomenklatura/matching/bench.py +32 -0
  28. nomenklatura/matching/compare/__init__.py +0 -0
  29. nomenklatura/matching/compare/addresses.py +71 -0
  30. nomenklatura/matching/compare/countries.py +15 -0
  31. nomenklatura/matching/compare/dates.py +83 -0
  32. nomenklatura/matching/compare/gender.py +15 -0
  33. nomenklatura/matching/compare/identifiers.py +30 -0
  34. nomenklatura/matching/compare/names.py +157 -0
  35. nomenklatura/matching/compare/util.py +51 -0
  36. nomenklatura/matching/compat.py +66 -0
  37. nomenklatura/matching/erun/__init__.py +0 -0
  38. nomenklatura/matching/erun/countries.py +42 -0
  39. nomenklatura/matching/erun/identifiers.py +64 -0
  40. nomenklatura/matching/erun/misc.py +71 -0
  41. nomenklatura/matching/erun/model.py +110 -0
  42. nomenklatura/matching/erun/names.py +126 -0
  43. nomenklatura/matching/erun/train.py +135 -0
  44. nomenklatura/matching/erun/util.py +28 -0
  45. nomenklatura/matching/logic_v1/__init__.py +0 -0
  46. nomenklatura/matching/logic_v1/identifiers.py +104 -0
  47. nomenklatura/matching/logic_v1/model.py +76 -0
  48. nomenklatura/matching/logic_v1/multi.py +21 -0
  49. nomenklatura/matching/logic_v1/phonetic.py +142 -0
  50. nomenklatura/matching/logic_v2/__init__.py +0 -0
  51. nomenklatura/matching/logic_v2/identifiers.py +124 -0
  52. nomenklatura/matching/logic_v2/model.py +98 -0
  53. nomenklatura/matching/logic_v2/names/__init__.py +3 -0
  54. nomenklatura/matching/logic_v2/names/analysis.py +51 -0
  55. nomenklatura/matching/logic_v2/names/distance.py +181 -0
  56. nomenklatura/matching/logic_v2/names/magic.py +60 -0
  57. nomenklatura/matching/logic_v2/names/match.py +195 -0
  58. nomenklatura/matching/logic_v2/names/pairing.py +81 -0
  59. nomenklatura/matching/logic_v2/names/util.py +89 -0
  60. nomenklatura/matching/name_based/__init__.py +4 -0
  61. nomenklatura/matching/name_based/misc.py +86 -0
  62. nomenklatura/matching/name_based/model.py +59 -0
  63. nomenklatura/matching/name_based/names.py +59 -0
  64. nomenklatura/matching/pairs.py +42 -0
  65. nomenklatura/matching/regression_v1/__init__.py +0 -0
  66. nomenklatura/matching/regression_v1/misc.py +75 -0
  67. nomenklatura/matching/regression_v1/model.py +110 -0
  68. nomenklatura/matching/regression_v1/names.py +63 -0
  69. nomenklatura/matching/regression_v1/train.py +87 -0
  70. nomenklatura/matching/regression_v1/util.py +31 -0
  71. nomenklatura/matching/svm_v1/__init__.py +5 -0
  72. nomenklatura/matching/svm_v1/misc.py +94 -0
  73. nomenklatura/matching/svm_v1/model.py +168 -0
  74. nomenklatura/matching/svm_v1/names.py +81 -0
  75. nomenklatura/matching/svm_v1/train.py +186 -0
  76. nomenklatura/matching/svm_v1/util.py +30 -0
  77. nomenklatura/matching/types.py +227 -0
  78. nomenklatura/matching/util.py +62 -0
  79. nomenklatura/publish/__init__.py +0 -0
  80. nomenklatura/publish/dates.py +49 -0
  81. nomenklatura/publish/edges.py +32 -0
  82. nomenklatura/py.typed +0 -0
  83. nomenklatura/resolver/__init__.py +6 -0
  84. nomenklatura/resolver/common.py +2 -0
  85. nomenklatura/resolver/edge.py +107 -0
  86. nomenklatura/resolver/identifier.py +60 -0
  87. nomenklatura/resolver/linker.py +101 -0
  88. nomenklatura/resolver/resolver.py +565 -0
  89. nomenklatura/settings.py +17 -0
  90. nomenklatura/store/__init__.py +41 -0
  91. nomenklatura/store/base.py +130 -0
  92. nomenklatura/store/level.py +272 -0
  93. nomenklatura/store/memory.py +102 -0
  94. nomenklatura/store/redis_.py +131 -0
  95. nomenklatura/store/sql.py +219 -0
  96. nomenklatura/store/util.py +48 -0
  97. nomenklatura/store/versioned.py +371 -0
  98. nomenklatura/tui/__init__.py +17 -0
  99. nomenklatura/tui/app.py +294 -0
  100. nomenklatura/tui/app.tcss +52 -0
  101. nomenklatura/tui/comparison.py +81 -0
  102. nomenklatura/tui/util.py +35 -0
  103. nomenklatura/util.py +26 -0
  104. nomenklatura/versions.py +119 -0
  105. nomenklatura/wikidata/__init__.py +14 -0
  106. nomenklatura/wikidata/client.py +122 -0
  107. nomenklatura/wikidata/lang.py +94 -0
  108. nomenklatura/wikidata/model.py +139 -0
  109. nomenklatura/wikidata/props.py +70 -0
  110. nomenklatura/wikidata/qualified.py +49 -0
  111. nomenklatura/wikidata/query.py +66 -0
  112. nomenklatura/wikidata/value.py +87 -0
  113. nomenklatura/xref.py +125 -0
  114. nomenklatura_mpt-4.1.9.dist-info/METADATA +159 -0
  115. nomenklatura_mpt-4.1.9.dist-info/RECORD +118 -0
  116. nomenklatura_mpt-4.1.9.dist-info/WHEEL +4 -0
  117. nomenklatura_mpt-4.1.9.dist-info/entry_points.txt +3 -0
  118. nomenklatura_mpt-4.1.9.dist-info/licenses/LICENSE +21 -0
nomenklatura/matching/svm_v1/misc.py
@@ -0,0 +1,94 @@
+ from followthemoney.proxy import E
+ from followthemoney.types import registry
+
+ from .util import tokenize_pair, compare_levenshtein
+ from nomenklatura.matching.compare.util import has_overlap, extract_numbers, is_disjoint
+ from nomenklatura.matching.util import props_pair, type_pair
+ from nomenklatura.matching.util import max_in_sets, has_schema
+ from nomenklatura.matching.compat import clean_name_ascii
+
+
+ def birth_place(query: E, result: E) -> float:
+     """Same place of birth."""
+     lv, rv = tokenize_pair(props_pair(query, result, ["birthPlace"]))
+     tokens = min(len(lv), len(rv))
+     return float(len(lv.intersection(rv))) / float(max(2.0, tokens))
+
+
+ def address_match(query: E, result: E) -> float:
+     """Text similarity between addresses."""
+     lv, rv = type_pair(query, result, registry.address)
+     lvn = [n for n in (clean_name_ascii(v) for v in lv) if n is not None]
+     rvn = [n for n in (clean_name_ascii(v) for v in rv) if n is not None]
+     return max_in_sets(lvn, rvn, compare_levenshtein)
+
+
+ def address_numbers(query: E, result: E) -> float:
+     """Find if names contain numbers, score if the numbers are different."""
+     lv, rv = type_pair(query, result, registry.address)
+     lvn = extract_numbers(lv)
+     rvn = extract_numbers(rv)
+     common = len(lvn.intersection(rvn))
+     disjoint = len(lvn.difference(rvn))
+     return float(common - disjoint)
+
+
+ def phone_match(query: E, result: E) -> float:
+     """Matching phone numbers between the two entities."""
+     lv, rv = type_pair(query, result, registry.phone)
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def email_match(query: E, result: E) -> float:
+     """Matching email addresses between the two entities."""
+     lv, rv = type_pair(query, result, registry.email)
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def identifier_match(query: E, result: E) -> float:
+     """Matching identifiers (e.g. passports, national ID cards, registration or
+     tax numbers) between the two entities."""
+     if has_schema(query, result, "Organization"):
+         return 0.0
+     lv, rv = type_pair(query, result, registry.identifier)
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def org_identifier_match(query: E, result: E) -> float:
+     """Matching identifiers (e.g. registration or tax numbers) between two
+     organizations or companies."""
+     if not has_schema(query, result, "Organization"):
+         return 0.0
+     lv, rv = type_pair(query, result, registry.identifier)
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def gender_mismatch(query: E, result: E) -> float:
+     """Both entities have a different gender associated with them."""
+     qv, rv = props_pair(query, result, ["gender"])
+     return 1.0 if is_disjoint(qv, rv) else 0.0
+
+
+ def country_mismatch(query: E, result: E) -> float:
+     """Both entities are linked to different countries."""
+     qv, rv = type_pair(query, result, registry.country)
+     return 1.0 if is_disjoint(qv, rv) else 0.0
+
+
+ def schema_match(query: E, result: E) -> float:
+     """Both entities have the same schema."""
+     return 1.0 if query.schema == result.schema else 0.0
+
+
+ def property_count_similarity(query: E, result: E) -> float:
+     """Similarity in the number of properties."""
+     query_props = sum(1 for _ in query.itervalues())
+     result_props = sum(1 for _ in result.itervalues())
+
+     if query_props == 0 and result_props == 0:
+         return 1.0
+
+     max_props = max(query_props, result_props)
+     min_props = min(query_props, result_props)
+
+     return float(min_props) / float(max_props) if max_props > 0 else 0.0
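
Each of the functions above maps a pair of FollowTheMoney entity proxies to a float feature value. As a quick orientation, a minimal usage sketch; the entities and their values are invented for illustration:

    from followthemoney import model

    query = model.make_entity("Person")
    query.add("phone", "+12025550179")
    query.add("gender", "female")

    result = model.make_entity("Person")
    result.add("phone", "+12025550179")
    result.add("gender", "male")

    print(phone_match(query, result))      # 1.0: shared phone number
    print(gender_mismatch(query, result))  # 1.0: disjoint gender values
    print(schema_match(query, result))     # 1.0: both are Person entities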
nomenklatura/matching/svm_v1/model.py
@@ -0,0 +1,168 @@
+ import pickle
+ import numpy as np
+ from typing import Any, List, Dict, Tuple, cast, Optional
+ from functools import lru_cache as cache
+ from sklearn.pipeline import Pipeline  # type: ignore
+ from sklearn.svm import SVC  # type: ignore
+ from sklearn.preprocessing import StandardScaler  # type: ignore
+ from followthemoney.proxy import E
+
+ from .names import first_name_match
+ from .names import family_name_match
+ from .names import name_levenshtein, name_match
+ from .names import name_token_overlap, name_numbers
+ from .names import name_length_similarity
+ from .misc import phone_match, email_match
+ from .misc import address_match, address_numbers
+ from .misc import identifier_match, birth_place
+ from .misc import org_identifier_match
+ from .misc import gender_mismatch
+ from .misc import country_mismatch
+ from .misc import schema_match, property_count_similarity
+ from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches
+ from nomenklatura.matching.compare.dates import dob_year_disjoint
+ from nomenklatura.matching.types import (
+     FeatureDocs,
+     FeatureDoc,
+     MatchingResult,
+     ScoringConfig,
+ )
+ from nomenklatura.matching.types import CompareFunction, FtResult
+ from nomenklatura.matching.types import Encoded, ScoringAlgorithm
+ from nomenklatura.matching.util import make_github_url
+ from nomenklatura.util import DATA_PATH
+
+
+ class SVMV1(ScoringAlgorithm):
+     """A Support Vector Machine-based matching algorithm with RBF kernel."""
+
+     NAME = "svm-v1"
+     MODEL_PATH = DATA_PATH.joinpath(f"{NAME}.pkl")
+     FEATURES: List[CompareFunction] = [
+         name_match,
+         name_token_overlap,
+         name_numbers,
+         name_levenshtein,
+         name_length_similarity,
+         phone_match,
+         email_match,
+         identifier_match,
+         dob_matches,
+         dob_year_matches,
+         FtResult.unwrap(dob_year_disjoint),
+         first_name_match,
+         family_name_match,
+         birth_place,
+         gender_mismatch,
+         country_mismatch,
+         org_identifier_match,
+         address_match,
+         address_numbers,
+         schema_match,
+         property_count_similarity,
+     ]
+
+     @classmethod
+     def save(
+         cls,
+         pipeline: Pipeline,
+         coefficients: Dict[str, float],
+         kernel: str = "rbf",
+         support_vectors_count: Optional[int] = None,
+     ) -> None:
+         """Store the SVM model after training."""
+         mdl = pickle.dumps({
+             "pipeline": pipeline,
+             "coefficients": coefficients,
+             "kernel": kernel,
+             "support_vectors_count": support_vectors_count,
+         })
+         with open(cls.MODEL_PATH, "wb") as fh:
+             fh.write(mdl)
+         cls.load.cache_clear()
+
+     @classmethod
+     @cache
+     def load(cls) -> Tuple[Pipeline, Dict[str, float], str, Optional[int]]:
+         """Load a pre-trained SVM model for ad-hoc use."""
+         with open(cls.MODEL_PATH, "rb") as fh:
+             model_data = pickle.loads(fh.read())
+
+         pipeline = cast(Pipeline, model_data["pipeline"])
+         coefficients = cast(Dict[str, float], model_data["coefficients"])
+         kernel = cast(str, model_data.get("kernel", "rbf"))
+         support_vectors_count = cast(
+             Optional[int], model_data.get("support_vectors_count")
+         )
+
+         current = [f.__name__ for f in cls.FEATURES]
+         if list(coefficients.keys()) != current:
+             raise RuntimeError("Model was not trained on identical features!")
+
+         return pipeline, coefficients, kernel, support_vectors_count
+
+     @classmethod
+     def get_feature_docs(cls) -> FeatureDocs:
+         """Return an explanation of the features and their coefficients."""
+         features: FeatureDocs = {}
+         _, coefficients, _, _ = cls.load()
+         for func in cls.FEATURES:
+             name = func.__name__
+             features[name] = FeatureDoc(
+                 description=func.__doc__,
+                 coefficient=float(coefficients.get(name, 0.0)),
+                 url=make_github_url(func),
+             )
+         return features
+
+     @classmethod
+     def compare(cls, query: E, result: E, config: ScoringConfig) -> MatchingResult:
+         """Use an SVM model to compare two entities."""
+         pipeline, _, kernel, _ = cls.load()
+
+         encoded = cls.encode_pair(query, result)
+         npfeat = np.array([encoded])
+
+         # Get probability predictions
+         if hasattr(pipeline.named_steps['svc'], 'predict_proba'):
+             pred_proba = pipeline.predict_proba(npfeat)
+             score = float(pred_proba[0][1])  # Positive class probability
+         else:
+             # Fall back to the decision function if predict_proba is unavailable
+             decision = pipeline.decision_function(npfeat)
+             # Convert the decision value to a probability-like score
+             score = float(1.0 / (1.0 + np.exp(-decision[0])))  # Sigmoid transformation
+
+         # Create per-feature explanations from the encoded feature values
+         explanations: Dict[str, FtResult] = {}
+         for feature, value in zip(cls.FEATURES, encoded):
+             name = feature.__name__
+             explanations[name] = FtResult(score=float(value), detail=None)
+
+         # Add model-specific information
+         explanations["svm_kernel"] = FtResult(score=0.0, detail=kernel)
+
+         return MatchingResult.make(score=score, explanations=explanations)
+
+     @classmethod
+     def encode_pair(cls, left: E, right: E) -> Encoded:
+         """Encode the comparison between two entities as a set of feature values."""
+         return [f(left, right) for f in cls.FEATURES]
+
+     @classmethod
+     def get_decision_function(cls, features: np.ndarray) -> np.ndarray:
+         """Get the SVM decision function values."""
+         pipeline, _, _, _ = cls.load()
+         return pipeline.decision_function(features)
+
+     @classmethod
+     def get_support_vectors_info(cls) -> Dict[str, Any]:
+         """Get information about the support vectors."""
+         pipeline, _, kernel, sv_count = cls.load()
+         svm = pipeline.named_steps['svc']
+
+         info = {
+             "kernel": kernel,
+             "support_vectors_count": sv_count
+             or (len(svm.support_vectors_) if hasattr(svm, 'support_vectors_') else None),
+             "n_support": svm.n_support_.tolist() if hasattr(svm, 'n_support_') else None,
+             "gamma": svm.gamma if hasattr(svm, 'gamma') else None,
+             "C": svm.C if hasattr(svm, 'C') else None,
+         }
+
+         return info
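
Note on scoring: when the pickled SVC was trained with probability=True, compare() uses its Platt-scaled probabilities directly; otherwise it squashes the raw SVM margin through a logistic sigmoid. A standalone sketch of that fallback transformation (example values invented; unlike Platt scaling, this plain sigmoid is uncalibrated):

    import numpy as np

    def margin_to_score(decision_value: float) -> float:
        # Logistic sigmoid: maps the signed distance from the separating
        # hyperplane onto (0, 1); a margin of 0.0 becomes a score of 0.5.
        return float(1.0 / (1.0 + np.exp(-decision_value)))

    print(margin_to_score(-2.0))  # ~0.12, likely a non-match
    print(margin_to_score(0.0))   # 0.5, on the decision boundary
    print(margin_to_score(2.0))   # ~0.88, likely a match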
nomenklatura/matching/svm_v1/names.py
@@ -0,0 +1,81 @@
+ from typing import Iterable, Set
+ from followthemoney.proxy import E
+ from followthemoney.types import registry
+
+ from .util import tokenize_pair, compare_levenshtein
+ from nomenklatura.matching.compare.util import is_disjoint, has_overlap, extract_numbers
+ from nomenklatura.matching.util import props_pair, type_pair
+ from nomenklatura.matching.util import max_in_sets
+ from nomenklatura.matching.compare.names import fingerprint_name
+
+
+ def normalize_names(raws: Iterable[str]) -> Set[str]:
+     names = set()
+     for raw in raws:
+         name = fingerprint_name(raw)
+         if name is not None:
+             names.add(name[:128])
+     return names
+
+
+ def name_levenshtein(left: E, right: E) -> float:
+     """Consider the edit distance (as a fraction of name length) between the two most
+     similar names linked to both entities."""
+     lv, rv = type_pair(left, right, registry.name)
+     lvn, rvn = normalize_names(lv), normalize_names(rv)
+     return max_in_sets(lvn, rvn, compare_levenshtein)
+
+
+ def first_name_match(left: E, right: E) -> float:
+     """Matching first/given name between the two entities."""
+     lv, rv = tokenize_pair(props_pair(left, right, ["firstName"]))
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def family_name_match(left: E, right: E) -> float:
+     """Matching family name between the two entities."""
+     lv, rv = tokenize_pair(props_pair(left, right, ["lastName"]))
+     return 1.0 if has_overlap(lv, rv) else 0.0
+
+
+ def name_match(left: E, right: E) -> float:
+     """Check for exact name matches between the two entities."""
+     lv, rv = type_pair(left, right, registry.name)
+     lvn, rvn = normalize_names(lv), normalize_names(rv)
+     common = [len(n) for n in lvn.intersection(rvn)]
+     max_common = max(common, default=0)
+     if max_common == 0:
+         return 0.0
+     return float(max_common)
+
+
+ def name_token_overlap(left: E, right: E) -> float:
+     """Evaluate the proportion of identical words in each name."""
+     lv, rv = tokenize_pair(type_pair(left, right, registry.name))
+     common = lv.intersection(rv)
+     tokens = min(len(lv), len(rv))
+     return float(len(common)) / float(max(2.0, tokens))
+
+
+ def name_numbers(left: E, right: E) -> float:
+     """Find if names contain numbers, score if the numbers are different."""
+     lv, rv = type_pair(left, right, registry.name)
+     return 1.0 if is_disjoint(extract_numbers(lv), extract_numbers(rv)) else 0.0
+
+
+ def name_length_similarity(left: E, right: E) -> float:
+     """Similarity in name lengths."""
+     lv, rv = type_pair(left, right, registry.name)
+     if not lv or not rv:
+         return 0.0
+
+     max_left = max(len(name) for name in lv)
+     max_right = max(len(name) for name in rv)
+
+     if max_left == 0 and max_right == 0:
+         return 1.0
+
+     max_len = max(max_left, max_right)
+     min_len = min(max_left, max_right)
+
+     return float(min_len) / float(max_len) if max_len > 0 else 0.0
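
The overlap features divide the number of shared tokens by the smaller token count, floored at 2 so that single-token names cannot reach a full score. A plain-set sketch of the arithmetic in name_token_overlap (names invented):

    lv = {"maria", "lopez"}
    rv = {"maria", "lopez", "garcia"}
    common = lv.intersection(rv)    # {"maria", "lopez"}
    tokens = min(len(lv), len(rv))  # 2
    score = float(len(common)) / float(max(2.0, tokens))
    print(score)  # 1.0: every token of the shorter name is matched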
nomenklatura/matching/svm_v1/train.py
@@ -0,0 +1,186 @@
+ import logging
+ import numpy as np
+ import multiprocessing
+ from typing import Iterable, List, Tuple
+ from pprint import pprint
+ from numpy.typing import NDArray
+ from sklearn.pipeline import make_pipeline  # type: ignore
+ from sklearn.preprocessing import StandardScaler  # type: ignore
+ from sklearn.model_selection import train_test_split, GridSearchCV  # type: ignore
+ from sklearn.svm import SVC  # type: ignore
+ from sklearn import metrics  # type: ignore
+ from concurrent.futures import ProcessPoolExecutor
+ from followthemoney.util import PathLike
+
+ from nomenklatura.judgement import Judgement
+ from nomenklatura.matching.pairs import read_pairs, JudgedPair
+ from .model import SVMV1
+
+ log = logging.getLogger(__name__)
+
+
+ def pair_convert(pair: JudgedPair) -> Tuple[List[float], int]:
+     """Encode a pair of training data into features and target."""
+     judgement = 1 if pair.judgement == Judgement.POSITIVE else 0
+     features = SVMV1.encode_pair(pair.left, pair.right)
+     return features, judgement
+
+
+ def pairs_to_arrays(
+     pairs: Iterable[JudgedPair],
+ ) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
+     """Parallelize feature computation for training data."""
+     xrows = []
+     yrows = []
+     workers = multiprocessing.cpu_count()
+     log.info("Using %d processes for feature computation...", workers)
+     with ProcessPoolExecutor(max_workers=workers) as executor:
+         results = executor.map(pair_convert, pairs)
+         for idx, (x, y) in enumerate(results):
+             if idx > 0 and idx % 10000 == 0:
+                 log.info("Computing features: %s....", idx)
+             xrows.append(x)
+             yrows.append(y)
+
+     return np.array(xrows), np.array(yrows)
+
+
+ def train_svm_matcher(
+     pairs_file: PathLike,
+     kernel: str = "rbf",
+     optimize_hyperparameters: bool = True,
+     probability: bool = True,
+ ) -> PathLike:
+     """Train an SVM matching model."""
+     pairs = []
+     for pair in read_pairs(pairs_file):
+         if pair.judgement == Judgement.UNSURE:
+             pair.judgement = Judgement.NEGATIVE
+         pairs.append(pair)
+
+     positive = len([p for p in pairs if p.judgement == Judgement.POSITIVE])
+     negative = len([p for p in pairs if p.judgement == Judgement.NEGATIVE])
+     log.info("Total pairs loaded: %d (%d pos/%d neg)", len(pairs), positive, negative)
+
+     X, y = pairs_to_arrays(pairs)
+     X_train, X_test, y_train, y_test = train_test_split(
+         X, y, test_size=0.33, random_state=42
+     )
+
+     log.info("Training SVM model with %s kernel...", kernel)
+
+     # Create pipeline: feature scaling, then the SVM
+     base_svm = SVC(kernel=kernel, probability=probability, random_state=42)
+     pipeline = make_pipeline(StandardScaler(), base_svm)
+
+     if optimize_hyperparameters:
+         log.info("Optimizing hyperparameters with Grid Search...")
+
+         # Define parameter grid
+         if kernel == "rbf":
+             param_grid = {
+                 'svc__C': [0.1, 1, 10, 100],
+                 'svc__gamma': ['scale', 'auto', 0.001, 0.01, 0.1, 1],
+             }
+         elif kernel == "linear":
+             param_grid = {'svc__C': [0.1, 1, 10, 100]}
+         elif kernel == "poly":
+             param_grid = {
+                 'svc__C': [0.1, 1, 10, 100],
+                 'svc__degree': [2, 3, 4],
+                 'svc__gamma': ['scale', 'auto', 0.001, 0.01, 0.1],
+             }
+         else:
+             param_grid = {'svc__C': [0.1, 1, 10, 100]}
+
+         # Grid search
+         grid_search = GridSearchCV(
+             pipeline,
+             param_grid,
+             cv=5,
+             scoring='roc_auc',
+             n_jobs=-1,
+             verbose=1,
+         )
+
+         grid_search.fit(X_train, y_train)
+         pipeline = grid_search.best_estimator_
+
+         log.info("Best parameters: %s", grid_search.best_params_)
+         log.info("Best cross-validation score: %.4f", grid_search.best_score_)
+
+     else:
+         # Use default parameters
+         pipeline.fit(X_train, y_train)
+
+     # Get the SVM step for coefficient extraction
+     svm_model = pipeline.named_steps['svc']
+
+     # Create feature coefficients (for the linear kernel, use actual coefficients)
+     coefficients = {}
+     if kernel == "linear" and hasattr(svm_model, 'coef_'):
+         # For a linear SVM, we have actual feature coefficients
+         coef_values = svm_model.coef_[0]
+         for i, feature in enumerate(SVMV1.FEATURES):
+             coefficients[feature.__name__] = float(coef_values[i])
+     else:
+         # Non-linear kernels have no per-feature coefficients; store a
+         # placeholder (in practice, SHAP or similar could be used instead)
+         for feature in SVMV1.FEATURES:
+             coefficients[feature.__name__] = 1.0  # Placeholder
+
+     # Get support vector count
+     support_vectors_count = (
+         len(svm_model.support_vectors_) if hasattr(svm_model, 'support_vectors_') else None
+     )
+
+     # Save the model
+     SVMV1.save(pipeline, coefficients, kernel, support_vectors_count)
+
+     # Evaluation
+     log.info("Evaluating SVM model...")
+
+     # Predictions
+     y_pred = pipeline.predict(X_test)
+
+     if probability:
+         y_pred_proba = pipeline.predict_proba(X_test)[:, 1]
+     else:
+         # Use the decision function and apply a sigmoid
+         decision_scores = pipeline.decision_function(X_test)
+         y_pred_proba = 1.0 / (1.0 + np.exp(-decision_scores))
+
+     # Print results
+     print("\nSVM Results:")
+     print(f"Kernel: {kernel}")
+     print(f"Support vectors count: {support_vectors_count}")
+
+     if kernel == "linear":
+         print("Feature Coefficients:")
+         pprint(coefficients)
+
+     cnf_matrix = metrics.confusion_matrix(y_test, y_pred)
+     print("Confusion matrix:\n", cnf_matrix)
+     print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
+     print("Precision:", metrics.precision_score(y_test, y_pred))
+     print("Recall:", metrics.recall_score(y_test, y_pred))
+     print("F1-score:", metrics.f1_score(y_test, y_pred))
+
+     auc = metrics.roc_auc_score(y_test, y_pred_proba)
+     print("Area under curve:", auc)
+
+     # Additional SVM-specific metrics
+     if hasattr(svm_model, 'support_'):
+         print(f"Number of support vectors: {len(svm_model.support_)}")
+         print(f"Support vectors per class: {svm_model.n_support_}")
+     return SVMV1.MODEL_PATH
+
+
+ def train_matcher(
+     pairs_file: PathLike,
+     kernel: str = "rbf",
+     optimize: bool = True,
+ ) -> None:
+     """Wrapper function to maintain compatibility with the existing training interface."""
+     train_svm_matcher(pairs_file, kernel=kernel, optimize_hyperparameters=optimize)
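
train_svm_matcher expects a judged-pairs file in whatever format nomenklatura.matching.pairs.read_pairs consumes, and writes the pickled pipeline to SVMV1.MODEL_PATH. A hedged invocation sketch; the file path is an assumption:

    from pathlib import Path

    # Hypothetical pairs file with POSITIVE/NEGATIVE/UNSURE judgements.
    pairs_path = Path("pairs.json")

    # A linear kernel without grid search trains fastest and is the only
    # configuration that yields real per-feature coefficients.
    model_path = train_svm_matcher(
        pairs_path, kernel="linear", optimize_hyperparameters=False
    )
    print(f"Model written to {model_path}")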
nomenklatura/matching/svm_v1/util.py
@@ -0,0 +1,30 @@
+ from normality.constants import WS
+ from typing import Iterable, Set, Tuple
+ from rigour.text.distance import levenshtein
+
+ from nomenklatura.matching.compat import clean_name_ascii
+
+
+ def tokenize(texts: Iterable[str]) -> Set[str]:
+     tokens: Set[str] = set()
+     for text in texts:
+         cleaned = clean_name_ascii(text)
+         if cleaned is None:
+             continue
+         for token in cleaned.split(WS):
+             token = token.strip()
+             if len(token) > 2:
+                 tokens.add(token)
+     return tokens
+
+
+ def tokenize_pair(
+     pair: Tuple[Iterable[str], Iterable[str]],
+ ) -> Tuple[Set[str], Set[str]]:
+     return tokenize(pair[0]), tokenize(pair[1])
+
+
+ def compare_levenshtein(left: str, right: str) -> float:
+     distance = levenshtein(left, right)
+     base = max((1, len(left), len(right)))
+     return 1.0 - (distance / float(base))
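
compare_levenshtein normalizes the edit distance by the longer string's length (with a floor of 1 to avoid division by zero), so 1.0 means identical strings and 0.0 means entirely different. Worked examples (strings invented):

    print(compare_levenshtein("smith", "smith"))  # 1.0: distance 0
    print(compare_levenshtein("smith", "smyth"))  # 0.8: distance 1 over base 5
    print(compare_levenshtein("ab", "xy"))        # 0.0: distance 2 over base 2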