nomenklatura_mpt-4.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nomenklatura/__init__.py +11 -0
- nomenklatura/cache.py +194 -0
- nomenklatura/cli.py +260 -0
- nomenklatura/conflicting_match.py +80 -0
- nomenklatura/data/er-unstable.pkl +0 -0
- nomenklatura/data/regression-v1.pkl +0 -0
- nomenklatura/db.py +139 -0
- nomenklatura/delta.py +4 -0
- nomenklatura/enrich/__init__.py +94 -0
- nomenklatura/enrich/aleph.py +141 -0
- nomenklatura/enrich/common.py +219 -0
- nomenklatura/enrich/nominatim.py +72 -0
- nomenklatura/enrich/opencorporates.py +233 -0
- nomenklatura/enrich/openfigi.py +124 -0
- nomenklatura/enrich/permid.py +201 -0
- nomenklatura/enrich/wikidata.py +268 -0
- nomenklatura/enrich/yente.py +116 -0
- nomenklatura/exceptions.py +9 -0
- nomenklatura/index/__init__.py +5 -0
- nomenklatura/index/common.py +24 -0
- nomenklatura/index/entry.py +89 -0
- nomenklatura/index/index.py +170 -0
- nomenklatura/index/tokenizer.py +92 -0
- nomenklatura/judgement.py +21 -0
- nomenklatura/kv.py +40 -0
- nomenklatura/matching/__init__.py +47 -0
- nomenklatura/matching/bench.py +32 -0
- nomenklatura/matching/compare/__init__.py +0 -0
- nomenklatura/matching/compare/addresses.py +71 -0
- nomenklatura/matching/compare/countries.py +15 -0
- nomenklatura/matching/compare/dates.py +83 -0
- nomenklatura/matching/compare/gender.py +15 -0
- nomenklatura/matching/compare/identifiers.py +30 -0
- nomenklatura/matching/compare/names.py +157 -0
- nomenklatura/matching/compare/util.py +51 -0
- nomenklatura/matching/compat.py +66 -0
- nomenklatura/matching/erun/__init__.py +0 -0
- nomenklatura/matching/erun/countries.py +42 -0
- nomenklatura/matching/erun/identifiers.py +64 -0
- nomenklatura/matching/erun/misc.py +71 -0
- nomenklatura/matching/erun/model.py +110 -0
- nomenklatura/matching/erun/names.py +126 -0
- nomenklatura/matching/erun/train.py +135 -0
- nomenklatura/matching/erun/util.py +28 -0
- nomenklatura/matching/logic_v1/__init__.py +0 -0
- nomenklatura/matching/logic_v1/identifiers.py +104 -0
- nomenklatura/matching/logic_v1/model.py +76 -0
- nomenklatura/matching/logic_v1/multi.py +21 -0
- nomenklatura/matching/logic_v1/phonetic.py +142 -0
- nomenklatura/matching/logic_v2/__init__.py +0 -0
- nomenklatura/matching/logic_v2/identifiers.py +124 -0
- nomenklatura/matching/logic_v2/model.py +98 -0
- nomenklatura/matching/logic_v2/names/__init__.py +3 -0
- nomenklatura/matching/logic_v2/names/analysis.py +51 -0
- nomenklatura/matching/logic_v2/names/distance.py +181 -0
- nomenklatura/matching/logic_v2/names/magic.py +60 -0
- nomenklatura/matching/logic_v2/names/match.py +195 -0
- nomenklatura/matching/logic_v2/names/pairing.py +81 -0
- nomenklatura/matching/logic_v2/names/util.py +89 -0
- nomenklatura/matching/name_based/__init__.py +4 -0
- nomenklatura/matching/name_based/misc.py +86 -0
- nomenklatura/matching/name_based/model.py +59 -0
- nomenklatura/matching/name_based/names.py +59 -0
- nomenklatura/matching/pairs.py +42 -0
- nomenklatura/matching/regression_v1/__init__.py +0 -0
- nomenklatura/matching/regression_v1/misc.py +75 -0
- nomenklatura/matching/regression_v1/model.py +110 -0
- nomenklatura/matching/regression_v1/names.py +63 -0
- nomenklatura/matching/regression_v1/train.py +87 -0
- nomenklatura/matching/regression_v1/util.py +31 -0
- nomenklatura/matching/svm_v1/__init__.py +5 -0
- nomenklatura/matching/svm_v1/misc.py +94 -0
- nomenklatura/matching/svm_v1/model.py +168 -0
- nomenklatura/matching/svm_v1/names.py +81 -0
- nomenklatura/matching/svm_v1/train.py +186 -0
- nomenklatura/matching/svm_v1/util.py +30 -0
- nomenklatura/matching/types.py +227 -0
- nomenklatura/matching/util.py +62 -0
- nomenklatura/publish/__init__.py +0 -0
- nomenklatura/publish/dates.py +49 -0
- nomenklatura/publish/edges.py +32 -0
- nomenklatura/py.typed +0 -0
- nomenklatura/resolver/__init__.py +6 -0
- nomenklatura/resolver/common.py +2 -0
- nomenklatura/resolver/edge.py +107 -0
- nomenklatura/resolver/identifier.py +60 -0
- nomenklatura/resolver/linker.py +101 -0
- nomenklatura/resolver/resolver.py +565 -0
- nomenklatura/settings.py +17 -0
- nomenklatura/store/__init__.py +41 -0
- nomenklatura/store/base.py +130 -0
- nomenklatura/store/level.py +272 -0
- nomenklatura/store/memory.py +102 -0
- nomenklatura/store/redis_.py +131 -0
- nomenklatura/store/sql.py +219 -0
- nomenklatura/store/util.py +48 -0
- nomenklatura/store/versioned.py +371 -0
- nomenklatura/tui/__init__.py +17 -0
- nomenklatura/tui/app.py +294 -0
- nomenklatura/tui/app.tcss +52 -0
- nomenklatura/tui/comparison.py +81 -0
- nomenklatura/tui/util.py +35 -0
- nomenklatura/util.py +26 -0
- nomenklatura/versions.py +119 -0
- nomenklatura/wikidata/__init__.py +14 -0
- nomenklatura/wikidata/client.py +122 -0
- nomenklatura/wikidata/lang.py +94 -0
- nomenklatura/wikidata/model.py +139 -0
- nomenklatura/wikidata/props.py +70 -0
- nomenklatura/wikidata/qualified.py +49 -0
- nomenklatura/wikidata/query.py +66 -0
- nomenklatura/wikidata/value.py +87 -0
- nomenklatura/xref.py +125 -0
- nomenklatura_mpt-4.1.9.dist-info/METADATA +159 -0
- nomenklatura_mpt-4.1.9.dist-info/RECORD +118 -0
- nomenklatura_mpt-4.1.9.dist-info/WHEEL +4 -0
- nomenklatura_mpt-4.1.9.dist-info/entry_points.txt +3 -0
- nomenklatura_mpt-4.1.9.dist-info/licenses/LICENSE +21 -0
nomenklatura/matching/svm_v1/misc.py
@@ -0,0 +1,94 @@
```python
from followthemoney.proxy import E
from followthemoney.types import registry

from .util import tokenize_pair, compare_levenshtein
from nomenklatura.matching.compare.util import has_overlap, extract_numbers, is_disjoint
from nomenklatura.matching.util import props_pair, type_pair
from nomenklatura.matching.util import max_in_sets, has_schema
from nomenklatura.matching.compat import clean_name_ascii


def birth_place(query: E, result: E) -> float:
    """Same place of birth."""
    lv, rv = tokenize_pair(props_pair(query, result, ["birthPlace"]))
    tokens = min(len(lv), len(rv))
    return float(len(lv.intersection(rv))) / float(max(2.0, tokens))


def address_match(query: E, result: E) -> float:
    """Text similarity between addresses."""
    lv, rv = type_pair(query, result, registry.address)
    # Clean each value once and drop values that fail normalization:
    lvn = [n for n in (clean_name_ascii(v) for v in lv) if n is not None]
    rvn = [n for n in (clean_name_ascii(v) for v in rv) if n is not None]
    return max_in_sets(lvn, rvn, compare_levenshtein)


def address_numbers(query: E, result: E) -> float:
    """Find if names contain numbers, score if the numbers are different."""
    lv, rv = type_pair(query, result, registry.address)
    lvn = extract_numbers(lv)
    rvn = extract_numbers(rv)
    common = len(lvn.intersection(rvn))
    disjoint = len(lvn.difference(rvn))
    return float(common - disjoint)


def phone_match(query: E, result: E) -> float:
    """Matching phone numbers between the two entities."""
    lv, rv = type_pair(query, result, registry.phone)
    return 1.0 if has_overlap(lv, rv) else 0.0


def email_match(query: E, result: E) -> float:
    """Matching email addresses between the two entities."""
    lv, rv = type_pair(query, result, registry.email)
    return 1.0 if has_overlap(lv, rv) else 0.0


def identifier_match(query: E, result: E) -> float:
    """Matching identifiers (e.g. passports, national ID cards, registration or
    tax numbers) between the two entities."""
    if has_schema(query, result, "Organization"):
        return 0.0
    lv, rv = type_pair(query, result, registry.identifier)
    return 1.0 if has_overlap(lv, rv) else 0.0


def org_identifier_match(query: E, result: E) -> float:
    """Matching identifiers (e.g. registration or tax numbers) between two
    organizations or companies."""
    if not has_schema(query, result, "Organization"):
        return 0.0
    lv, rv = type_pair(query, result, registry.identifier)
    return 1.0 if has_overlap(lv, rv) else 0.0


def gender_mismatch(query: E, result: E) -> float:
    """Both entities have a different gender associated with them."""
    qv, rv = props_pair(query, result, ["gender"])
    return 1.0 if is_disjoint(qv, rv) else 0.0


def country_mismatch(query: E, result: E) -> float:
    """Both entities are linked to different countries."""
    qv, rv = type_pair(query, result, registry.country)
    return 1.0 if is_disjoint(qv, rv) else 0.0


def schema_match(query: E, result: E) -> float:
    """Both entities have the same schema."""
    return 1.0 if query.schema == result.schema else 0.0


def property_count_similarity(query: E, result: E) -> float:
    """Similarity in the number of properties."""
    query_props = len(list(query.itervalues()))
    result_props = len(list(result.itervalues()))

    if query_props == 0 and result_props == 0:
        return 1.0

    max_props = max(query_props, result_props)
    min_props = min(query_props, result_props)

    return float(min_props) / float(max_props) if max_props > 0 else 0.0
```
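Every helper in this module is a pure function from a pair of followthemoney entity proxies to a float, so the features can be exercised in isolation. A minimal sketch of doing so, assuming the module path `nomenklatura.matching.svm_v1.misc` implied by the file listing above; the entity payloads are made up for illustration:

```python
from followthemoney import model
from nomenklatura.matching.svm_v1.misc import (
    phone_match, gender_mismatch, schema_match,
)

# Hypothetical entities, for illustration only
a = model.get_proxy({
    "id": "a",
    "schema": "Person",
    "properties": {"phone": ["+4915223433333"], "gender": ["male"]},
})
b = model.get_proxy({
    "id": "b",
    "schema": "Person",
    "properties": {"phone": ["+4915223433333"], "gender": ["female"]},
})

assert phone_match(a, b) == 1.0      # shared phone number
assert gender_mismatch(a, b) == 1.0  # disjoint gender values
assert schema_match(a, b) == 1.0     # both are Person entities
```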
nomenklatura/matching/svm_v1/model.py
@@ -0,0 +1,168 @@
```python
import pickle
import numpy as np
from typing import Any, List, Dict, Tuple, cast, Optional
from functools import lru_cache as cache
from sklearn.pipeline import Pipeline  # type: ignore
from followthemoney.proxy import E

from .names import first_name_match
from .names import family_name_match
from .names import name_levenshtein, name_match
from .names import name_token_overlap, name_numbers
from .names import name_length_similarity
from .misc import phone_match, email_match
from .misc import address_match, address_numbers
from .misc import identifier_match, birth_place
from .misc import org_identifier_match
from .misc import gender_mismatch
from .misc import country_mismatch
from .misc import schema_match, property_count_similarity
from nomenklatura.matching.compare.dates import dob_matches, dob_year_matches
from nomenklatura.matching.compare.dates import dob_year_disjoint
from nomenklatura.matching.types import (
    FeatureDocs,
    FeatureDoc,
    MatchingResult,
    ScoringConfig,
)
from nomenklatura.matching.types import CompareFunction, FtResult
from nomenklatura.matching.types import Encoded, ScoringAlgorithm
from nomenklatura.matching.util import make_github_url
from nomenklatura.util import DATA_PATH


class SVMV1(ScoringAlgorithm):
    """A Support Vector Machine-based matching algorithm with RBF kernel."""

    NAME = "svm-v1"
    MODEL_PATH = DATA_PATH.joinpath(f"{NAME}.pkl")
    FEATURES: List[CompareFunction] = [
        name_match,
        name_token_overlap,
        name_numbers,
        name_levenshtein,
        name_length_similarity,
        phone_match,
        email_match,
        identifier_match,
        dob_matches,
        dob_year_matches,
        FtResult.unwrap(dob_year_disjoint),
        first_name_match,
        family_name_match,
        birth_place,
        gender_mismatch,
        country_mismatch,
        org_identifier_match,
        address_match,
        address_numbers,
        schema_match,
        property_count_similarity,
    ]

    @classmethod
    def save(
        cls,
        pipeline: Pipeline,
        coefficients: Dict[str, float],
        kernel: str = "rbf",
        support_vectors_count: Optional[int] = None,
    ) -> None:
        """Store the SVM model after training."""
        mdl = pickle.dumps(
            {
                "pipeline": pipeline,
                "coefficients": coefficients,
                "kernel": kernel,
                "support_vectors_count": support_vectors_count,
            }
        )
        with open(cls.MODEL_PATH, "wb") as fh:
            fh.write(mdl)
        cls.load.cache_clear()

    @classmethod
    @cache
    def load(cls) -> Tuple[Pipeline, Dict[str, float], str, Optional[int]]:
        """Load a pre-trained SVM model for ad-hoc use."""
        with open(cls.MODEL_PATH, "rb") as fh:
            model_data = pickle.loads(fh.read())

        pipeline = cast(Pipeline, model_data["pipeline"])
        coefficients = cast(Dict[str, float], model_data["coefficients"])
        kernel = cast(str, model_data.get("kernel", "rbf"))
        support_vectors_count = cast(
            Optional[int], model_data.get("support_vectors_count")
        )

        current = [f.__name__ for f in cls.FEATURES]
        if list(coefficients.keys()) != current:
            raise RuntimeError("Model was not trained on identical features!")

        return pipeline, coefficients, kernel, support_vectors_count

    @classmethod
    def get_feature_docs(cls) -> FeatureDocs:
        """Return an explanation of the features and their coefficients."""
        features: FeatureDocs = {}
        _, coefficients, _, _ = cls.load()
        for func in cls.FEATURES:
            name = func.__name__
            features[name] = FeatureDoc(
                description=func.__doc__,
                coefficient=float(coefficients.get(name, 0.0)),
                url=make_github_url(func),
            )
        return features

    @classmethod
    def compare(cls, query: E, result: E, config: ScoringConfig) -> MatchingResult:
        """Use an SVM model to compare two entities."""
        pipeline, _, kernel, _ = cls.load()

        encoded = cls.encode_pair(query, result)
        npfeat = np.array([encoded])

        # Get probability predictions
        if hasattr(pipeline.named_steps["svc"], "predict_proba"):
            pred_proba = pipeline.predict_proba(npfeat)
            score = float(pred_proba[0][1])  # Positive class probability
        else:
            # Fall back to the decision function if predict_proba is unavailable
            decision = pipeline.decision_function(npfeat)
            # Sigmoid transformation to a probability-like score
            score = float(1.0 / (1.0 + np.exp(-decision[0])))

        # Report each raw feature value as an explanation
        explanations: Dict[str, FtResult] = {}
        for feature, value in zip(cls.FEATURES, encoded):
            explanations[feature.__name__] = FtResult(score=float(value), detail=None)

        # Add model-specific information
        explanations["svm_kernel"] = FtResult(score=0.0, detail=kernel)

        return MatchingResult.make(score=score, explanations=explanations)

    @classmethod
    def encode_pair(cls, left: E, right: E) -> Encoded:
        """Encode the comparison between two entities as a set of feature values."""
        return [f(left, right) for f in cls.FEATURES]

    @classmethod
    def get_decision_function(cls, features: np.ndarray) -> np.ndarray:
        """Get the SVM decision function values."""
        pipeline, _, _, _ = cls.load()
        return pipeline.decision_function(features)

    @classmethod
    def get_support_vectors_info(cls) -> Dict[str, Any]:
        """Get information about the support vectors."""
        pipeline, _, kernel, sv_count = cls.load()
        svm = pipeline.named_steps["svc"]

        info = {
            "kernel": kernel,
            "support_vectors_count": sv_count
            or (len(svm.support_vectors_) if hasattr(svm, "support_vectors_") else None),
            "n_support": svm.n_support_.tolist() if hasattr(svm, "n_support_") else None,
            "gamma": svm.gamma if hasattr(svm, "gamma") else None,
            "C": svm.C if hasattr(svm, "C") else None,
        }

        return info
```
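`encode_pair` does not touch the pickled model, so the feature vector can be inspected without training anything; the classifier consumes the values strictly in `FEATURES` order. A sketch under the same assumptions as the earlier example (hypothetical entities, module path taken from the file listing):

```python
from followthemoney import model
from nomenklatura.matching.svm_v1.model import SVMV1

query = model.get_proxy({
    "id": "q",
    "schema": "Person",
    "properties": {"name": ["Jane Doe"], "birthDate": ["1975"]},
})
result = model.get_proxy({
    "id": "r",
    "schema": "Person",
    "properties": {"name": ["Jane M. Doe"], "birthDate": ["1975"]},
})

# One float per feature function, in FEATURES order:
for func, value in zip(SVMV1.FEATURES, SVMV1.encode_pair(query, result)):
    print(f"{func.__name__}: {value:.2f}")
```

Note that when the pipeline was trained with `probability=False`, `compare()` squashes the raw decision value d through a sigmoid, score = 1 / (1 + exp(-d)), so scores stay in (0, 1) either way.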
nomenklatura/matching/svm_v1/names.py
@@ -0,0 +1,81 @@
```python
from typing import Iterable, Set
from followthemoney.proxy import E
from followthemoney.types import registry

from .util import tokenize_pair, compare_levenshtein
from nomenklatura.matching.compare.util import is_disjoint, has_overlap, extract_numbers
from nomenklatura.matching.util import props_pair, type_pair
from nomenklatura.matching.util import max_in_sets
from nomenklatura.matching.compare.names import fingerprint_name


def normalize_names(raws: Iterable[str]) -> Set[str]:
    names = set()
    for raw in raws:
        name = fingerprint_name(raw)
        if name is not None:
            names.add(name[:128])
    return names


def name_levenshtein(left: E, right: E) -> float:
    """Consider the edit distance (as a fraction of name length) between the two most
    similar names linked to both entities."""
    lv, rv = type_pair(left, right, registry.name)
    lvn, rvn = normalize_names(lv), normalize_names(rv)
    return max_in_sets(lvn, rvn, compare_levenshtein)


def first_name_match(left: E, right: E) -> float:
    """Matching first/given name between the two entities."""
    lv, rv = tokenize_pair(props_pair(left, right, ["firstName"]))
    return 1.0 if has_overlap(lv, rv) else 0.0


def family_name_match(left: E, right: E) -> float:
    """Matching family name between the two entities."""
    lv, rv = tokenize_pair(props_pair(left, right, ["lastName"]))
    return 1.0 if has_overlap(lv, rv) else 0.0


def name_match(left: E, right: E) -> float:
    """Check for exact name matches between the two entities."""
    lv, rv = type_pair(left, right, registry.name)
    lvn, rvn = normalize_names(lv), normalize_names(rv)
    common = [len(n) for n in lvn.intersection(rvn)]
    max_common = max(common, default=0)
    if max_common == 0:
        return 0.0
    return float(max_common)


def name_token_overlap(left: E, right: E) -> float:
    """Evaluate the proportion of identical words in each name."""
    lv, rv = tokenize_pair(type_pair(left, right, registry.name))
    common = lv.intersection(rv)
    tokens = min(len(lv), len(rv))
    return float(len(common)) / float(max(2.0, tokens))


def name_numbers(left: E, right: E) -> float:
    """Find if names contain numbers, score if the numbers are different."""
    lv, rv = type_pair(left, right, registry.name)
    return 1.0 if is_disjoint(extract_numbers(lv), extract_numbers(rv)) else 0.0


def name_length_similarity(left: E, right: E) -> float:
    """Similarity in name lengths."""
    lv, rv = type_pair(left, right, registry.name)
    if not lv or not rv:
        return 0.0

    max_left = max(len(name) for name in lv)
    max_right = max(len(name) for name in rv)

    if max_left == 0 and max_right == 0:
        return 1.0

    max_len = max(max_left, max_right)
    min_len = min(max_left, max_right)

    return float(min_len) / float(max_len) if max_len > 0 else 0.0
```
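The `max(2.0, tokens)` denominator in `name_token_overlap` keeps single-token names from reaching a full score on one shared word. Worked numbers (the token sets stand in for `tokenize_pair` output):

```python
# Shorter side has two tokens, both shared:
# common = 2, tokens = 2  ->  2 / max(2.0, 2) = 1.0
left, right = {"vladimir", "putin"}, {"vladimir", "vladimirovich", "putin"}

# Shorter side has one token, shared:
# common = 1, tokens = 1  ->  1 / max(2.0, 1) = 0.5
left, right = {"putin"}, {"vladimir", "putin"}
```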
nomenklatura/matching/svm_v1/train.py
@@ -0,0 +1,186 @@
```python
import logging
import numpy as np
import multiprocessing
from typing import Iterable, List, Tuple
from pprint import pprint
from numpy.typing import NDArray
from sklearn.pipeline import make_pipeline  # type: ignore
from sklearn.preprocessing import StandardScaler  # type: ignore
from sklearn.model_selection import train_test_split, GridSearchCV  # type: ignore
from sklearn.svm import SVC  # type: ignore
from sklearn import metrics  # type: ignore
from concurrent.futures import ProcessPoolExecutor
from followthemoney.util import PathLike

from nomenklatura.judgement import Judgement
from nomenklatura.matching.pairs import read_pairs, JudgedPair
from .model import SVMV1

log = logging.getLogger(__name__)


def pair_convert(pair: JudgedPair) -> Tuple[List[float], int]:
    """Encode a pair of training data into features and target."""
    judgement = 1 if pair.judgement == Judgement.POSITIVE else 0
    features = SVMV1.encode_pair(pair.left, pair.right)
    return features, judgement


def pairs_to_arrays(
    pairs: Iterable[JudgedPair],
) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
    """Parallelize feature computation for the training data."""
    xrows = []
    yrows = []
    workers = multiprocessing.cpu_count()
    log.info("Using %d processes for feature computation...", workers)
    with ProcessPoolExecutor(max_workers=workers) as executor:
        results = executor.map(pair_convert, pairs)
        for idx, (x, y) in enumerate(results):
            if idx > 0 and idx % 10000 == 0:
                log.info("Computing features: %s....", idx)
            xrows.append(x)
            yrows.append(y)

    return np.array(xrows), np.array(yrows)


def train_svm_matcher(
    pairs_file: PathLike,
    kernel: str = "rbf",
    optimize_hyperparameters: bool = True,
    probability: bool = True,
) -> PathLike:
    """Train an SVM matching model."""
    pairs = []
    for pair in read_pairs(pairs_file):
        if pair.judgement == Judgement.UNSURE:
            pair.judgement = Judgement.NEGATIVE
        pairs.append(pair)

    positive = len([p for p in pairs if p.judgement == Judgement.POSITIVE])
    negative = len([p for p in pairs if p.judgement == Judgement.NEGATIVE])
    log.info("Total pairs loaded: %d (%d pos/%d neg)", len(pairs), positive, negative)

    X, y = pairs_to_arrays(pairs)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42
    )

    log.info("Training SVM model with %s kernel...", kernel)

    if optimize_hyperparameters:
        log.info("Optimizing hyperparameters with Grid Search...")

        # Define a parameter grid suited to the chosen kernel
        if kernel == "rbf":
            param_grid = {
                "svc__C": [0.1, 1, 10, 100],
                "svc__gamma": ["scale", "auto", 0.001, 0.01, 0.1, 1],
            }
        elif kernel == "linear":
            param_grid = {"svc__C": [0.1, 1, 10, 100]}
        elif kernel == "poly":
            param_grid = {
                "svc__C": [0.1, 1, 10, 100],
                "svc__degree": [2, 3, 4],
                "svc__gamma": ["scale", "auto", 0.001, 0.01, 0.1],
            }
        else:
            param_grid = {"svc__C": [0.1, 1, 10, 100]}

        # Create pipeline
        base_svm = SVC(kernel=kernel, probability=probability, random_state=42)
        pipeline = make_pipeline(StandardScaler(), base_svm)

        # Grid search
        grid_search = GridSearchCV(
            pipeline,
            param_grid,
            cv=5,
            scoring="roc_auc",
            n_jobs=-1,
            verbose=1,
        )

        grid_search.fit(X_train, y_train)
        pipeline = grid_search.best_estimator_

        log.info("Best parameters: %s", grid_search.best_params_)
        log.info("Best cross-validation score: %.4f", grid_search.best_score_)

    else:
        # Use default parameters
        svm = SVC(kernel=kernel, probability=probability, random_state=42)
        pipeline = make_pipeline(StandardScaler(), svm)
        pipeline.fit(X_train, y_train)

    # Get the fitted SVM object for coefficient extraction
    svm_model = pipeline.named_steps["svc"]

    # Create feature coefficients (for the linear kernel, use actual coefficients)
    coefficients = {}
    if kernel == "linear" and hasattr(svm_model, "coef_"):
        # For linear SVM, we have actual feature coefficients
        coef_values = svm_model.coef_[0]
        for i, feature in enumerate(SVMV1.FEATURES):
            coefficients[feature.__name__] = float(coef_values[i])
    else:
        # Non-linear kernels have no per-feature coefficients; store a
        # placeholder. In practice, a method such as permutation importance
        # or SHAP would be needed to approximate feature importance.
        for feature in SVMV1.FEATURES:
            coefficients[feature.__name__] = 1.0

    # Get support vector count
    support_vectors_count = (
        len(svm_model.support_vectors_)
        if hasattr(svm_model, "support_vectors_")
        else None
    )

    # Save the model
    SVMV1.save(pipeline, coefficients, kernel, support_vectors_count)

    # Evaluation
    log.info("Evaluating SVM model...")

    # Predictions
    y_pred = pipeline.predict(X_test)

    if probability:
        y_pred_proba = pipeline.predict_proba(X_test)[:, 1]
    else:
        # Use the decision function and apply a sigmoid
        decision_scores = pipeline.decision_function(X_test)
        y_pred_proba = 1.0 / (1.0 + np.exp(-decision_scores))

    # Print results
    print("\nSVM Results:")
    print(f"Kernel: {kernel}")
    print(f"Support vectors count: {support_vectors_count}")

    if kernel == "linear":
        print("Feature Coefficients:")
        pprint(coefficients)

    cnf_matrix = metrics.confusion_matrix(y_test, y_pred)
    print("Confusion matrix:\n", cnf_matrix)
    print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
    print("Precision:", metrics.precision_score(y_test, y_pred))
    print("Recall:", metrics.recall_score(y_test, y_pred))
    print("F1-score:", metrics.f1_score(y_test, y_pred))

    auc = metrics.roc_auc_score(y_test, y_pred_proba)
    print("Area under curve:", auc)

    # Additional SVM-specific metrics
    if hasattr(svm_model, "support_"):
        print(f"Number of support vectors: {len(svm_model.support_)}")
        print(f"Support vectors per class: {svm_model.n_support_}")

    return SVMV1.MODEL_PATH


def train_matcher(
    pairs_file: PathLike,
    kernel: str = "rbf",
    optimize: bool = True,
) -> None:
    """Wrapper function to maintain compatibility with the existing training interface."""
    train_svm_matcher(pairs_file, kernel=kernel, optimize_hyperparameters=optimize)
```
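A hypothetical invocation, assuming a judged-pairs file in whatever format `nomenklatura.matching.pairs.read_pairs` consumes (that module is not shown in this diff):

```python
from nomenklatura.matching.svm_v1.train import train_svm_matcher

# Trains on 67% of the pairs, evaluates on the held-out 33%,
# and writes the pickled pipeline to SVMV1.MODEL_PATH.
model_path = train_svm_matcher(
    "pairs.json",                   # hypothetical path
    kernel="linear",                # linear kernel yields real coefficients
    optimize_hyperparameters=False,
)
print(model_path)
```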
nomenklatura/matching/svm_v1/util.py
@@ -0,0 +1,30 @@
```python
from normality.constants import WS
from typing import Iterable, Set, Tuple
from rigour.text.distance import levenshtein

from nomenklatura.matching.compat import clean_name_ascii


def tokenize(texts: Iterable[str]) -> Set[str]:
    tokens: Set[str] = set()
    for text in texts:
        cleaned = clean_name_ascii(text)
        if cleaned is None:
            continue
        for token in cleaned.split(WS):
            token = token.strip()
            if len(token) > 2:
                tokens.add(token)
    return tokens


def tokenize_pair(
    pair: Tuple[Iterable[str], Iterable[str]],
) -> Tuple[Set[str], Set[str]]:
    return tokenize(pair[0]), tokenize(pair[1])


def compare_levenshtein(left: str, right: str) -> float:
    distance = levenshtein(left, right)
    base = max((1, len(left), len(right)))
    return 1.0 - (distance / float(base))
```
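`compare_levenshtein` normalizes the edit distance by the longer string's length, yielding a similarity in [0, 1]. Quick worked examples, assuming rigour's `levenshtein` returns the plain edit distance for these short strings:

```python
from nomenklatura.matching.svm_v1.util import compare_levenshtein

assert compare_levenshtein("smith", "smith") == 1.0  # distance 0, base 5
assert compare_levenshtein("smith", "smyth") == 0.8  # distance 1, base 5
```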