corp-extractor 0.9.0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/METADATA +40 -9
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/RECORD +29 -26
- statement_extractor/cli.py +866 -77
- statement_extractor/database/hub.py +35 -127
- statement_extractor/database/importers/__init__.py +10 -2
- statement_extractor/database/importers/companies_house.py +16 -2
- statement_extractor/database/importers/companies_house_officers.py +431 -0
- statement_extractor/database/importers/gleif.py +23 -0
- statement_extractor/database/importers/sec_edgar.py +17 -0
- statement_extractor/database/importers/sec_form4.py +512 -0
- statement_extractor/database/importers/wikidata.py +151 -43
- statement_extractor/database/importers/wikidata_dump.py +1951 -0
- statement_extractor/database/importers/wikidata_people.py +823 -325
- statement_extractor/database/models.py +30 -6
- statement_extractor/database/store.py +1485 -60
- statement_extractor/document/deduplicator.py +10 -12
- statement_extractor/extractor.py +1 -1
- statement_extractor/models/__init__.py +3 -2
- statement_extractor/models/statement.py +15 -17
- statement_extractor/models.py +1 -1
- statement_extractor/pipeline/context.py +5 -5
- statement_extractor/pipeline/orchestrator.py +12 -12
- statement_extractor/plugins/base.py +17 -17
- statement_extractor/plugins/extractors/gliner2.py +28 -28
- statement_extractor/plugins/qualifiers/embedding_company.py +7 -5
- statement_extractor/plugins/qualifiers/person.py +11 -1
- statement_extractor/plugins/splitters/t5_gemma.py +35 -39
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/WHEEL +0 -0
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/entry_points.txt +0 -0
statement_extractor/database/store.py
@@ -15,6 +15,7 @@ from pathlib import Path
 from typing import Iterator, Optional
 
 import numpy as np
+import pycountry
 import sqlite_vec
 
 from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType
@@ -24,12 +25,45 @@ logger = logging.getLogger(__name__)
 # Default database location
 DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
 
+# Module-level shared connections by path (both databases share the same connection)
+_shared_connections: dict[str, sqlite3.Connection] = {}
+
 # Module-level singleton for OrganizationDatabase to prevent multiple loads
 _database_instances: dict[str, "OrganizationDatabase"] = {}
 
 # Module-level singleton for PersonDatabase
 _person_database_instances: dict[str, "PersonDatabase"] = {}
 
+
+def _get_shared_connection(db_path: Path, embedding_dim: int = 768) -> sqlite3.Connection:
+    """Get or create a shared database connection for the given path."""
+    path_key = str(db_path)
+    if path_key not in _shared_connections:
+        # Ensure directory exists
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        conn = sqlite3.connect(str(db_path))
+        conn.row_factory = sqlite3.Row
+
+        # Load sqlite-vec extension
+        conn.enable_load_extension(True)
+        sqlite_vec.load(conn)
+        conn.enable_load_extension(False)
+
+        _shared_connections[path_key] = conn
+        logger.debug(f"Created shared database connection for {path_key}")
+
+    return _shared_connections[path_key]
+
+
+def close_shared_connection(db_path: Optional[Path] = None) -> None:
+    """Close a shared database connection."""
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key in _shared_connections:
+        _shared_connections[path_key].close()
+        del _shared_connections[path_key]
+        logger.debug(f"Closed shared database connection for {path_key}")
+
 # Comprehensive set of corporate legal suffixes (international)
 COMPANY_SUFFIXES: set[str] = {
     'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
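For orientation, a minimal sketch of how the shared-connection helpers added in this hunk behave. The import path is taken from the RECORD listing above; the database path is illustrative, not the package default:

```python
from pathlib import Path

# Assumed import location based on the file list above (statement_extractor/database/store.py).
from statement_extractor.database.store import _get_shared_connection, close_shared_connection

db_path = Path("/tmp/entities.db")  # illustrative path

# Repeated calls for the same path return the same sqlite3.Connection, so the
# organization and person tables sit behind one handle with sqlite-vec loaded once.
conn_a = _get_shared_connection(db_path)
conn_b = _get_shared_connection(db_path)
assert conn_a is conn_b

close_shared_connection(db_path)  # closes and evicts the pooled connection
```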
@@ -43,6 +77,222 @@ COMPANY_SUFFIXES: set[str] = {
     'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
 }
 
+# Source priority for organization canonicalization (lower = higher priority)
+SOURCE_PRIORITY: dict[str, int] = {
+    "gleif": 1,  # Gold standard LEI - globally unique legal entity identifier
+    "sec_edgar": 2,  # Vetted US filers with CIK + ticker
+    "companies_house": 3,  # Official UK registry
+    "wikipedia": 4,  # Crowdsourced, less authoritative
+}
+
+# Source priority for people canonicalization (lower = higher priority)
+PERSON_SOURCE_PRIORITY: dict[str, int] = {
+    "wikidata": 1,  # Curated, has rich biographical data and Q codes
+    "sec_edgar": 2,  # Vetted US filers (Form 4 officers/directors)
+    "companies_house": 3,  # UK company officers
+}
+
+# Suffix expansions for canonical name matching
+SUFFIX_EXPANSIONS: dict[str, str] = {
+    " ltd": " limited",
+    " corp": " corporation",
+    " inc": " incorporated",
+    " co": " company",
+    " intl": " international",
+    " natl": " national",
+}
+
+
+class UnionFind:
+    """Simple Union-Find (Disjoint Set Union) data structure for canonicalization."""
+
+    def __init__(self, elements: list[int]):
+        """Initialize with list of element IDs."""
+        self.parent: dict[int, int] = {e: e for e in elements}
+        self.rank: dict[int, int] = {e: 0 for e in elements}
+
+    def find(self, x: int) -> int:
+        """Find with path compression."""
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, x: int, y: int) -> None:
+        """Union by rank."""
+        px, py = self.find(x), self.find(y)
+        if px == py:
+            return
+        if self.rank[px] < self.rank[py]:
+            px, py = py, px
+        self.parent[py] = px
+        if self.rank[px] == self.rank[py]:
+            self.rank[px] += 1
+
+    def groups(self) -> dict[int, list[int]]:
+        """Return dict of root -> list of members."""
+        result: dict[int, list[int]] = {}
+        for e in self.parent:
+            root = self.find(e)
+            result.setdefault(root, []).append(e)
+        return result
+
+
+# Common region aliases not handled well by pycountry fuzzy search
+REGION_ALIASES: dict[str, str] = {
+    "uk": "GB",
+    "u.k.": "GB",
+    "england": "GB",
+    "scotland": "GB",
+    "wales": "GB",
+    "northern ireland": "GB",
+    "usa": "US",
+    "u.s.a.": "US",
+    "u.s.": "US",
+    "united states of america": "US",
+    "america": "US",
+}
+
+# Cache for region normalization lookups
+_region_cache: dict[str, str] = {}
+
+
+def _normalize_region(region: str) -> str:
+    """
+    Normalize a region string to ISO 3166-1 alpha-2 country code.
+
+    Handles:
+    - Country codes (2-letter, 3-letter)
+    - Country names (with fuzzy matching)
+    - US state codes (CA, NY) -> US
+    - US state names (California, New York) -> US
+    - Common aliases (UK, USA, England) -> proper codes
+
+    Returns empty string if region cannot be normalized.
+    """
+    if not region:
+        return ""
+
+    # Check cache first
+    cache_key = region.lower().strip()
+    if cache_key in _region_cache:
+        return _region_cache[cache_key]
+
+    result = _normalize_region_uncached(region)
+    _region_cache[cache_key] = result
+    return result
+
+
+def _normalize_region_uncached(region: str) -> str:
+    """Uncached region normalization logic."""
+    region_clean = region.strip()
+
+    # Empty after stripping = empty result
+    if not region_clean:
+        return ""
+
+    region_lower = region_clean.lower()
+    region_upper = region_clean.upper()
+
+    # Check common aliases first
+    if region_lower in REGION_ALIASES:
+        return REGION_ALIASES[region_lower]
+
+    # For 2-letter codes, check country first, then US state
+    # This means ambiguous codes like "CA" (Canada vs California) prefer country
+    # But unambiguous codes like "NY" (not a country) will match as US state
+    if len(region_clean) == 2:
+        # Try as country alpha-2 first
+        country = pycountry.countries.get(alpha_2=region_upper)
+        if country:
+            return country.alpha_2
+
+        # If not a country, try as US state code
+        subdivision = pycountry.subdivisions.get(code=f"US-{region_upper}")
+        if subdivision:
+            return "US"
+
+    # Try alpha-3 lookup
+    if len(region_clean) == 3:
+        country = pycountry.countries.get(alpha_3=region_upper)
+        if country:
+            return country.alpha_2
+
+    # Try as US state name (e.g., "California", "New York")
+    try:
+        subdivisions = list(pycountry.subdivisions.search_fuzzy(region_clean))
+        if subdivisions:
+            # Check if it's a US state
+            if subdivisions[0].code.startswith("US-"):
+                return "US"
+            # Return the parent country code
+            return subdivisions[0].country_code
+    except LookupError:
+        pass
+
+    # Try country fuzzy search
+    try:
+        countries = pycountry.countries.search_fuzzy(region_clean)
+        if countries:
+            return countries[0].alpha_2
+    except LookupError:
+        pass
+
+    # Return empty if we can't normalize
+    return ""
+
+
+def _regions_match(region1: str, region2: str) -> bool:
+    """
+    Check if two regions match after normalization.
+
+    Empty regions match anything (lenient matching for incomplete data).
+    """
+    norm1 = _normalize_region(region1)
+    norm2 = _normalize_region(region2)
+
+    # Empty regions match anything
+    if not norm1 or not norm2:
+        return True
+
+    return norm1 == norm2
+
+
+def _normalize_for_canon(name: str) -> str:
+    """Normalize name for canonical matching (simpler than search normalization)."""
+    # Lowercase
+    result = name.lower()
+    # Remove trailing dots
+    result = result.rstrip(".")
+    # Remove extra whitespace
+    result = " ".join(result.split())
+    return result
+
+
+def _expand_suffix(name: str) -> str:
+    """Expand known suffix abbreviations."""
+    result = name.lower().rstrip(".")
+    for abbrev, full in SUFFIX_EXPANSIONS.items():
+        if result.endswith(abbrev):
+            result = result[:-len(abbrev)] + full
+            break  # Only expand one suffix
+    return result
+
+
+def _names_match_for_canon(name1: str, name2: str) -> bool:
+    """Check if two names match for canonicalization."""
+    n1 = _normalize_for_canon(name1)
+    n2 = _normalize_for_canon(name2)
+
+    # Exact match after normalization
+    if n1 == n2:
+        return True
+
+    # Try with suffix expansion
+    if _expand_suffix(n1) == _expand_suffix(n2):
+        return True
+
+    return False
+
 # Pre-compile the suffix pattern for performance
 _SUFFIX_PATTERN = re.compile(
     r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
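A small illustration of how the helpers added in this hunk compose. The IDs and strings are made up; expected outputs follow from the code above (the exact Union-Find roots depend on union order):

```python
# Group three records that were linked by identifier and name matches.
uf = UnionFind([1, 2, 3, 4])
uf.union(1, 2)   # e.g. same LEI
uf.union(2, 3)   # e.g. same normalized name + region
print(uf.groups())   # {1: [1, 2, 3], 4: [4]}

print(_normalize_region("England"))     # "GB" via REGION_ALIASES
print(_normalize_region("California"))  # "US" via pycountry subdivisions
print(_expand_suffix("Amazon Ltd"))     # "amazon limited"
print(_names_match_for_canon("Amazon Ltd.", "Amazon Limited"))  # True
```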
@@ -256,20 +506,13 @@ class OrganizationDatabase:
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
 
     def _connect(self) -> sqlite3.Connection:
-        """Get or create database connection
+        """Get or create database connection using shared connection pool."""
         if self._conn is not None:
             return self._conn
 
-        self.
-        self._conn = sqlite3.connect(str(self._db_path))
-        self._conn.row_factory = sqlite3.Row
-
-        # Load sqlite-vec extension
-        self._conn.enable_load_extension(True)
-        sqlite_vec.load(self._conn)
-        self._conn.enable_load_extension(False)
+        self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
 
-        # Create tables
+        # Create tables (idempotent)
         self._create_tables()
 
         return self._conn
@@ -289,6 +532,8 @@ class OrganizationDatabase:
                 source_id TEXT NOT NULL,
                 region TEXT NOT NULL DEFAULT '',
                 entity_type TEXT NOT NULL DEFAULT 'unknown',
+                from_date TEXT NOT NULL DEFAULT '',
+                to_date TEXT NOT NULL DEFAULT '',
                 record TEXT NOT NULL,
                 UNIQUE(source, source_id)
             )
@@ -308,6 +553,34 @@ class OrganizationDatabase:
         except sqlite3.OperationalError:
             pass  # Column already exists
 
+        # Add from_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added from_date column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add to_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added to_date column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_id column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN canon_id INTEGER DEFAULT NULL")
+            logger.info("Added canon_id column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_size column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN canon_size INTEGER DEFAULT 1")
+            logger.info("Added canon_size column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
         # Create indexes on main table
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
@@ -316,6 +589,7 @@ class OrganizationDatabase:
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
         conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")
 
         # Create sqlite-vec virtual table for embeddings
         # vec0 is the recommended virtual table type
@@ -329,10 +603,8 @@ class OrganizationDatabase:
         conn.commit()
 
     def close(self) -> None:
-        """
-
-        self._conn.close()
-        self._conn = None
+        """Clear connection reference (shared connection remains open)."""
+        self._conn = None
 
     def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
         """
@@ -353,8 +625,8 @@ class OrganizationDatabase:
 
         cursor = conn.execute("""
             INSERT OR REPLACE INTO organizations
-            (name, name_normalized, source, source_id, region, entity_type, record)
-            VALUES (?, ?, ?, ?, ?, ?, ?)
+            (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
         """, (
             record.name,
             name_normalized,
@@ -362,6 +634,8 @@ class OrganizationDatabase:
             record.source_id,
             record.region,
             record.entity_type.value,
+            record.from_date or "",
+            record.to_date or "",
             record_json,
         ))
 
@@ -369,10 +643,11 @@ class OrganizationDatabase:
         assert row_id is not None
 
         # Insert embedding into vec table
-        # sqlite-vec
+        # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
         embedding_blob = embedding.astype(np.float32).tobytes()
+        conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
         conn.execute("""
-            INSERT
+            INSERT INTO organization_embeddings (org_id, embedding)
             VALUES (?, ?)
         """, (row_id, embedding_blob))
 
@@ -405,8 +680,8 @@ class OrganizationDatabase:
 
         cursor = conn.execute("""
             INSERT OR REPLACE INTO organizations
-            (name, name_normalized, source, source_id, region, entity_type, record)
-            VALUES (?, ?, ?, ?, ?, ?, ?)
+            (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
         """, (
             record.name,
             name_normalized,
@@ -414,16 +689,19 @@ class OrganizationDatabase:
             record.source_id,
             record.region,
             record.entity_type.value,
+            record.from_date or "",
+            record.to_date or "",
             record_json,
         ))
 
         row_id = cursor.lastrowid
         assert row_id is not None
 
-        # Insert embedding
+        # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
         embedding_blob = embedding.astype(np.float32).tobytes()
+        conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
         conn.execute("""
-            INSERT
+            INSERT INTO organization_embeddings (org_id, embedding)
             VALUES (?, ?)
         """, (row_id, embedding_blob))
 
@@ -443,13 +721,15 @@ class OrganizationDatabase:
         source_filter: Optional[str] = None,
         query_text: Optional[str] = None,
         max_text_candidates: int = 5000,
+        rerank_min_candidates: int = 500,
     ) -> list[tuple[CompanyRecord, float]]:
         """
         Search for similar organizations using hybrid text + vector search.
 
-
+        Three-stage approach:
         1. If query_text provided, use SQL LIKE to find candidates containing search terms
         2. Use sqlite-vec for vector similarity ranking on filtered candidates
+        3. Apply prominence-based re-ranking to boost major companies (SEC filers, tickers)
 
         Args:
             query_embedding: Query embedding vector
@@ -457,9 +737,10 @@ class OrganizationDatabase:
             source_filter: Optional filter by source (gleif, sec_edgar, etc.)
             query_text: Optional query text for text-based pre-filtering
             max_text_candidates: Max candidates to keep after text filtering
+            rerank_min_candidates: Minimum candidates to fetch for re-ranking (default 500)
 
         Returns:
-            List of (CompanyRecord,
+            List of (CompanyRecord, adjusted_score) tuples sorted by prominence-adjusted score
         """
         start = time.time()
         self._connect()
@@ -473,6 +754,7 @@ class OrganizationDatabase:
 
         # Stage 1: Text-based pre-filtering (if query_text provided)
         candidate_ids: Optional[set[int]] = None
+        query_normalized_text = ""
         if query_text:
             query_normalized_text = _normalize_name(query_text)
             if query_normalized_text:
@@ -483,24 +765,168 @@ class OrganizationDatabase:
                 )
                 logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
 
-        # Stage 2: Vector search
+        # Stage 2: Vector search - fetch more candidates for re-ranking
         if candidate_ids is not None and len(candidate_ids) == 0:
             # No text matches, return empty
             return []
 
+        # Fetch enough candidates for prominence re-ranking to be effective
+        # Use at least rerank_min_candidates, or all text-filtered candidates if fewer
+        if candidate_ids is not None:
+            fetch_k = min(len(candidate_ids), max(rerank_min_candidates, top_k * 5))
+        else:
+            fetch_k = max(rerank_min_candidates, top_k * 5)
+
         if candidate_ids is not None:
             # Search within text-filtered candidates
             results = self._vector_search_filtered(
-                query_blob, candidate_ids,
+                query_blob, candidate_ids, fetch_k, source_filter
             )
         else:
             # Full vector search
-            results = self._vector_search_full(query_blob,
+            results = self._vector_search_full(query_blob, fetch_k, source_filter)
+
+        # Stage 3: Prominence-based re-ranking
+        if results and query_normalized_text:
+            results = self._apply_prominence_reranking(results, query_normalized_text, top_k)
+        else:
+            # No re-ranking, just trim to top_k
+            results = results[:top_k]
 
         elapsed = time.time() - start
         logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
         return results
 
+    def _calculate_prominence_boost(
+        self,
+        record: CompanyRecord,
+        query_normalized: str,
+        canon_sources: Optional[set[str]] = None,
+    ) -> float:
+        """
+        Calculate prominence boost for re-ranking search results.
+
+        Boosts scores based on signals that indicate a major/prominent company:
+        - Has ticker symbol (publicly traded)
+        - GLEIF source (has LEI)
+        - SEC source (vetted US filers)
+        - Wikidata source (Wikipedia-notable)
+        - Exact normalized name match
+
+        When canon_sources is provided (from a canonical group), boosts are
+        applied for ALL sources in the canon group, not just this record's source.
+
+        Args:
+            record: The company record to evaluate
+            query_normalized: Normalized query text for exact match check
+            canon_sources: Optional set of sources in this record's canonical group
+
+        Returns:
+            Boost value to add to embedding similarity (0.0 to ~0.21)
+        """
+        boost = 0.0
+
+        # Get all sources to consider (canon group or just this record)
+        sources_to_check = canon_sources or {record.source}
+
+        # Has ticker symbol = publicly traded major company
+        # Check if ANY record in canon group has ticker
+        if record.record.get("ticker") or (canon_sources and "sec_edgar" in canon_sources):
+            boost += 0.08
+
+        # Source-based boosts - accumulate for all sources in canon group
+        if "gleif" in sources_to_check:
+            boost += 0.05  # Has LEI = verified legal entity
+        if "sec_edgar" in sources_to_check:
+            boost += 0.03  # SEC filer
+        if "wikipedia" in sources_to_check:
+            boost += 0.02  # Wikipedia notable
+
+        # Exact normalized name match bonus
+        record_normalized = _normalize_name(record.name)
+        if query_normalized == record_normalized:
+            boost += 0.05
+
+        return boost
+
+    def _apply_prominence_reranking(
+        self,
+        results: list[tuple[CompanyRecord, float]],
+        query_normalized: str,
+        top_k: int,
+        similarity_weight: float = 0.3,
+    ) -> list[tuple[CompanyRecord, float]]:
+        """
+        Apply prominence-based re-ranking to search results with canon group awareness.
+
+        When records have been canonicalized, boosts are applied based on ALL sources
+        in the canonical group, not just the matched record's source.
+
+        Args:
+            results: List of (record, similarity) from vector search
+            query_normalized: Normalized query text
+            top_k: Number of results to return after re-ranking
+            similarity_weight: Weight for similarity score (0-1), lower = prominence matters more
+
+        Returns:
+            Re-ranked list of (record, adjusted_score) tuples
+        """
+        conn = self._conn
+        assert conn is not None
+
+        # Build canon_id -> sources mapping for all results that have canon_id
+        canon_sources_map: dict[int, set[str]] = {}
+        canon_ids = [
+            r.record.get("canon_id")
+            for r, _ in results
+            if r.record.get("canon_id") is not None
+        ]
+
+        if canon_ids:
+            # Fetch all sources for each canon_id in one query
+            unique_canon_ids = list(set(canon_ids))
+            placeholders = ",".join("?" * len(unique_canon_ids))
+            rows = conn.execute(f"""
+                SELECT canon_id, source
+                FROM organizations
+                WHERE canon_id IN ({placeholders})
+            """, unique_canon_ids).fetchall()
+
+            for row in rows:
+                canon_id = row["canon_id"]
+                canon_sources_map.setdefault(canon_id, set()).add(row["source"])
+
+        # Calculate boosted scores with canon group awareness
+        # Formula: adjusted = (similarity * weight) + boost
+        # With weight=0.3, a sim=0.65 SEC+ticker (boost=0.11) beats sim=0.75 no-boost
+        boosted_results: list[tuple[CompanyRecord, float, float, float]] = []
+        for record, similarity in results:
+            canon_id = record.record.get("canon_id")
+            # Get all sources in this record's canon group (if any)
+            canon_sources = canon_sources_map.get(canon_id) if canon_id else None
+
+            boost = self._calculate_prominence_boost(record, query_normalized, canon_sources)
+            adjusted_score = (similarity * similarity_weight) + boost
+            boosted_results.append((record, similarity, boost, adjusted_score))
+
+        # Sort by adjusted score (descending)
+        boosted_results.sort(key=lambda x: x[3], reverse=True)
+
+        # Log re-ranking details for top results
+        logger.debug(f"Prominence re-ranking for '{query_normalized}':")
+        for record, sim, boost, adj in boosted_results[:10]:
+            ticker = record.record.get("ticker", "")
+            ticker_str = f" ticker={ticker}" if ticker else ""
+            canon_id = record.record.get("canon_id")
+            canon_str = f" canon={canon_id}" if canon_id else ""
+            logger.debug(
+                f" {record.name}: sim={sim:.3f} + boost={boost:.3f} = {adj:.3f} "
+                f"[{record.source}{ticker_str}{canon_str}]"
+            )
+
+        # Return top_k with adjusted scores
+        return [(r, adj) for r, _, _, adj in boosted_results[:top_k]]
+
     def _text_filter_candidates(
         self,
         query_normalized: str,
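To make the re-ranking formula concrete, a worked example using the boost constants from `_calculate_prominence_boost` above; the two records are hypothetical:

```python
similarity_weight = 0.3

# Hypothetical SEC filer with a ticker: boost = 0.08 (ticker) + 0.03 (sec_edgar) = 0.11
major = (0.65 * similarity_weight) + 0.11   # 0.305

# Hypothetical record with a better raw embedding match but no prominence signals
obscure = (0.75 * similarity_weight) + 0.0  # 0.225

assert major > obscure  # the prominent company wins despite the lower similarity
```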
@@ -651,24 +1077,28 @@ class OrganizationDatabase:
         return results
 
     def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
-        """Get an organization record by ID."""
+        """Get an organization record by ID, including db_id and canon_id in record dict."""
         conn = self._conn
         assert conn is not None
 
         cursor = conn.execute("""
-            SELECT name, source, source_id, region, entity_type, record
+            SELECT id, name, source, source_id, region, entity_type, record, canon_id
             FROM organizations WHERE id = ?
         """, (org_id,))
 
         row = cursor.fetchone()
         if row:
+            record_data = json.loads(row["record"])
+            # Add db_id and canon_id to record dict for canon-aware search
+            record_data["db_id"] = row["id"]
+            record_data["canon_id"] = row["canon_id"]
             return CompanyRecord(
                 name=row["name"],
                 source=row["source"],
                 source_id=row["source_id"],
                 region=row["region"] or "",
                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
-                record=
+                record=record_data,
             )
         return None
 
@@ -694,6 +1124,20 @@ class OrganizationDatabase:
             )
         return None
 
+    def get_id_by_source_id(self, source: str, source_id: str) -> Optional[int]:
+        """Get the internal database ID for an organization by source and source_id."""
+        conn = self._connect()
+
+        cursor = conn.execute("""
+            SELECT id FROM organizations
+            WHERE source = ? AND source_id = ?
+        """, (source, source_id))
+
+        row = cursor.fetchone()
+        if row:
+            return row["id"]
+        return None
+
     def get_stats(self) -> DatabaseStats:
         """Get database statistics."""
         conn = self._connect()
@@ -716,6 +1160,30 @@ class OrganizationDatabase:
             database_size_bytes=db_size,
         )
 
+    def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
+        """
+        Get all source_ids from the organizations table.
+
+        Useful for resume operations to skip already-imported records.
+
+        Args:
+            source: Optional source filter (e.g., "wikipedia" for Wikidata orgs)
+
+        Returns:
+            Set of source_id strings (e.g., Q codes for Wikidata)
+        """
+        conn = self._connect()
+
+        if source:
+            cursor = conn.execute(
+                "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
+                (source,)
+            )
+        else:
+            cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")
+
+        return {row[0] for row in cursor}
+
     def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
         """Iterate over all records, optionally filtered by source."""
         conn = self._connect()
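A sketch of the resume pattern `get_all_source_ids` enables; the no-argument constructor, the `wikidata_records` iterable, and `import_record` are illustrative stand-ins, not part of the package API:

```python
db = OrganizationDatabase()                     # assumes a no-argument constructor / default path
seen = db.get_all_source_ids(source="wikipedia")

for qid, record in wikidata_records:            # hypothetical (QID, record) pairs from a dump
    if qid in seen:
        continue                                # already imported, skip on resume
    import_record(record)                       # hypothetical import helper
```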
@@ -742,6 +1210,245 @@ class OrganizationDatabase:
                 record=json.loads(row["record"]),
             )
 
+    def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
+        """
+        Canonicalize all organizations by linking equivalent records.
+
+        Records are considered equivalent if they match by:
+        1. Same LEI (GLEIF source_id or Wikidata P1278) - globally unique, no region check
+        2. Same ticker symbol - globally unique, no region check
+        3. Same CIK - globally unique, no region check
+        4. Same normalized name AND same normalized region
+        5. Name match with suffix expansion AND same region
+
+        Region normalization uses pycountry to handle:
+        - Country codes/names (GB, United Kingdom, Great Britain -> GB)
+        - US state codes/names (CA, California -> US)
+        - Common aliases (UK -> GB, USA -> US)
+
+        For each group of equivalent records, the highest-priority source
+        (gleif > sec_edgar > companies_house > wikipedia) becomes canonical.
+
+        Args:
+            batch_size: Commit batch size for updates
+
+        Returns:
+            Dict with stats: total_records, groups_found, records_updated
+        """
+        conn = self._connect()
+        logger.info("Starting canonicalization...")
+
+        # Phase 1: Load all organization data and build indexes
+        logger.info("Phase 1: Building indexes...")
+
+        lei_index: dict[str, list[int]] = {}
+        ticker_index: dict[str, list[int]] = {}
+        cik_index: dict[str, list[int]] = {}
+        # Name indexes now keyed by (normalized_name, normalized_region)
+        # Region-less matching only applies for identifier-based matching
+        name_region_index: dict[tuple[str, str], list[int]] = {}
+        expanded_name_region_index: dict[tuple[str, str], list[int]] = {}
+
+        sources: dict[int, str] = {}  # org_id -> source
+        all_org_ids: list[int] = []
+
+        cursor = conn.execute("""
+            SELECT id, source, source_id, name, region, record
+            FROM organizations
+        """)
+
+        count = 0
+        for row in cursor:
+            org_id = row["id"]
+            source = row["source"]
+            name = row["name"]
+            region = row["region"] or ""
+            record = json.loads(row["record"])
+
+            all_org_ids.append(org_id)
+            sources[org_id] = source
+
+            # Index by LEI (GLEIF source_id or Wikidata's P1278)
+            # LEI is globally unique - no region check needed
+            if source == "gleif":
+                lei = row["source_id"]
+            else:
+                lei = record.get("lei")
+            if lei:
+                lei_index.setdefault(lei.upper(), []).append(org_id)
+
+            # Index by ticker - globally unique, no region check
+            ticker = record.get("ticker")
+            if ticker:
+                ticker_index.setdefault(ticker.upper(), []).append(org_id)
+
+            # Index by CIK - globally unique, no region check
+            if source == "sec_edgar":
+                cik = row["source_id"]
+            else:
+                cik = record.get("cik")
+            if cik:
+                cik_index.setdefault(str(cik), []).append(org_id)
+
+            # Index by (normalized_name, normalized_region)
+            # Same name in different regions = different legal entities
+            norm_name = _normalize_for_canon(name)
+            norm_region = _normalize_region(region)
+            if norm_name:
+                key = (norm_name, norm_region)
+                name_region_index.setdefault(key, []).append(org_id)
+
+            # Index by (expanded_name, normalized_region)
+            expanded_name = _expand_suffix(name)
+            if expanded_name and expanded_name != norm_name:
+                key = (expanded_name, norm_region)
+                expanded_name_region_index.setdefault(key, []).append(org_id)
+
+            count += 1
+            if count % 100000 == 0:
+                logger.info(f" Indexed {count} organizations...")
+
+        logger.info(f" Indexed {count} organizations total")
+        logger.info(f" LEI index: {len(lei_index)} unique LEIs")
+        logger.info(f" Ticker index: {len(ticker_index)} unique tickers")
+        logger.info(f" CIK index: {len(cik_index)} unique CIKs")
+        logger.info(f" Name+region index: {len(name_region_index)} unique (name, region) pairs")
+        logger.info(f" Expanded name+region index: {len(expanded_name_region_index)} unique pairs")
+
+        # Phase 2: Build equivalence groups using Union-Find
+        logger.info("Phase 2: Building equivalence groups...")
+
+        uf = UnionFind(all_org_ids)
+
+        # Merge by LEI (globally unique identifier)
+        for _lei, ids in lei_index.items():
+            for i in range(1, len(ids)):
+                uf.union(ids[0], ids[i])
+
+        # Merge by ticker (globally unique identifier)
+        for _ticker, ids in ticker_index.items():
+            for i in range(1, len(ids)):
+                uf.union(ids[0], ids[i])
+
+        # Merge by CIK (globally unique identifier)
+        for _cik, ids in cik_index.items():
+            for i in range(1, len(ids)):
+                uf.union(ids[0], ids[i])
+
+        # Merge by (normalized_name, normalized_region)
+        for _name_region, ids in name_region_index.items():
+            for i in range(1, len(ids)):
+                uf.union(ids[0], ids[i])
+
+        # Merge by (expanded_name, normalized_region)
+        # This connects "Amazon Ltd" with "Amazon Limited" in same region
+        for key, expanded_ids in expanded_name_region_index.items():
+            # Find org_ids with the expanded form as their normalized name in same region
+            if key in name_region_index:
+                # Link first expanded_id to first name_id
+                uf.union(expanded_ids[0], name_region_index[key][0])
+
+        groups = uf.groups()
+        logger.info(f" Found {len(groups)} equivalence groups")
+
+        # Count groups with multiple records
+        multi_record_groups = sum(1 for ids in groups.values() if len(ids) > 1)
+        logger.info(f" Groups with multiple records: {multi_record_groups}")
+
+        # Phase 3: Select canonical record for each group and update database
+        logger.info("Phase 3: Updating database...")
+
+        updated_count = 0
+        batch_updates: list[tuple[int, int, int]] = []  # (org_id, canon_id, canon_size)
+
+        for _root, group_ids in groups.items():
+            if len(group_ids) == 1:
+                # Single record - canonical to itself
+                batch_updates.append((group_ids[0], group_ids[0], 1))
+            else:
+                # Multiple records - find highest priority source
+                best_id = min(
+                    group_ids,
+                    key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid)
+                )
+                group_size = len(group_ids)
+
+                # All records in group point to the best one
+                for oid in group_ids:
+                    # canon_size is only set on the canonical record
+                    size = group_size if oid == best_id else 1
+                    batch_updates.append((oid, best_id, size))
+
+            # Commit batch
+            if len(batch_updates) >= batch_size:
+                self._apply_canon_updates(batch_updates)
+                updated_count += len(batch_updates)
+                logger.info(f" Updated {updated_count} records...")
+                batch_updates = []
+
+        # Final batch
+        if batch_updates:
+            self._apply_canon_updates(batch_updates)
+            updated_count += len(batch_updates)
+
+        conn.commit()
+        logger.info(f"Canonicalization complete: {updated_count} records updated, {multi_record_groups} multi-record groups")
+
+        return {
+            "total_records": count,
+            "groups_found": len(groups),
+            "multi_record_groups": multi_record_groups,
+            "records_updated": updated_count,
+        }
+
+    def _apply_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
+        """Apply batch of canon updates: (org_id, canon_id, canon_size)."""
+        conn = self._conn
+        assert conn is not None
+
+        for org_id, canon_id, canon_size in updates:
+            conn.execute(
+                "UPDATE organizations SET canon_id = ?, canon_size = ? WHERE id = ?",
+                (canon_id, canon_size, org_id)
+            )
+
+        conn.commit()
+
+    def get_canon_stats(self) -> dict[str, int]:
+        """Get statistics about canonicalization status."""
+        conn = self._connect()
+
+        # Total records
+        cursor = conn.execute("SELECT COUNT(*) FROM organizations")
+        total = cursor.fetchone()[0]
+
+        # Records with canon_id set
+        cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_id IS NOT NULL")
+        canonicalized = cursor.fetchone()[0]
+
+        # Number of canonical groups (unique canon_ids)
+        cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM organizations WHERE canon_id IS NOT NULL")
+        groups = cursor.fetchone()[0]
+
+        # Multi-record groups (canon_size > 1)
+        cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_size > 1")
+        multi_record_groups = cursor.fetchone()[0]
+
+        # Records in multi-record groups
+        cursor = conn.execute("""
+            SELECT COUNT(*) FROM organizations o1
+            WHERE EXISTS (SELECT 1 FROM organizations o2 WHERE o2.id = o1.canon_id AND o2.canon_size > 1)
+        """)
+        records_in_multi = cursor.fetchone()[0]
+
+        return {
+            "total_records": total,
+            "canonicalized_records": canonicalized,
+            "canonical_groups": groups,
+            "multi_record_groups": multi_record_groups,
+            "records_in_multi_groups": records_in_multi,
+        }
+
     def migrate_name_normalized(self, batch_size: int = 50000) -> int:
         """
         Populate the name_normalized column for all records.
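A hypothetical end-to-end use of the canonicalization pass added above (constructor defaults assumed; output keys come from the return dicts in the hunk):

```python
db = OrganizationDatabase()                 # assumes the default database path
stats = db.canonicalize(batch_size=10000)
print(stats["groups_found"], stats["multi_record_groups"], stats["records_updated"])

# Afterwards each row's canon_id points at the highest-priority record in its group
# (gleif > sec_edgar > companies_house > wikipedia); get_canon_stats() summarizes it.
print(db.get_canon_stats())
```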
@@ -867,8 +1574,9 @@ class OrganizationDatabase:
         assert conn is not None
 
         for org_id, embedding_blob in batch:
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO organization_embeddings (org_id, embedding)
                 VALUES (?, ?)
             """, (org_id, embedding_blob))
 
@@ -1107,8 +1815,9 @@ class OrganizationDatabase:
 
         for org_id, embedding in zip(org_ids, embeddings):
             embedding_blob = embedding.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO organization_embeddings (org_id, embedding)
                 VALUES (?, ?)
             """, (org_id, embedding_blob))
             count += 1
@@ -1116,6 +1825,68 @@ class OrganizationDatabase:
         conn.commit()
         return count
 
+    def resolve_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> int:
+        """
+        Update organization records that have QIDs instead of labels in region field.
+
+        Args:
+            label_map: Mapping of QID -> label for resolution
+            batch_size: Commit batch size
+
+        Returns:
+            Number of records updated
+        """
+        conn = self._connect()
+
+        # Find records with QIDs in region field (starts with 'Q' followed by digits)
+        region_updates = 0
+        cursor = conn.execute("""
+            SELECT id, region FROM organizations
+            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+        """)
+        rows = cursor.fetchall()
+
+        for row in rows:
+            org_id = row["id"]
+            qid = row["region"]
+            if qid in label_map:
+                conn.execute(
+                    "UPDATE organizations SET region = ? WHERE id = ?",
+                    (label_map[qid], org_id)
+                )
+                region_updates += 1
+
+                if region_updates % batch_size == 0:
+                    conn.commit()
+                    logger.info(f"Updated {region_updates} organization region labels...")
+
+        conn.commit()
+        logger.info(f"Resolved QID labels: {region_updates} organization regions")
+        return region_updates
+
+    def get_unresolved_qids(self) -> set[str]:
+        """
+        Get all QIDs that still need resolution in the organizations table.
+
+        Returns:
+            Set of QIDs (starting with 'Q') found in region field
+        """
+        conn = self._connect()
+        qids: set[str] = set()
+
+        cursor = conn.execute("""
+            SELECT DISTINCT region FROM organizations
+            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["region"])
+
+        return qids
+
 
 def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
     """
@@ -1167,20 +1938,13 @@ class PersonDatabase:
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
 
     def _connect(self) -> sqlite3.Connection:
-        """Get or create database connection
+        """Get or create database connection using shared connection pool."""
         if self._conn is not None:
             return self._conn
 
-        self.
-        self._conn = sqlite3.connect(str(self._db_path))
-        self._conn.row_factory = sqlite3.Row
-
-        # Load sqlite-vec extension
-        self._conn.enable_load_extension(True)
-        sqlite_vec.load(self._conn)
-        self._conn.enable_load_extension(False)
+        self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
 
-        # Create tables
+        # Create tables (idempotent)
         self._create_tables()
 
         return self._conn
@@ -1190,7 +1954,12 @@ class PersonDatabase:
         conn = self._conn
         assert conn is not None
 
+        # Check if we need to migrate from old schema (unique on source+source_id only)
+        self._migrate_people_schema_if_needed(conn)
+
         # Main people records table
+        # Unique constraint on source+source_id+role+org allows multiple records
+        # for the same person with different role/org combinations
         conn.execute("""
             CREATE TABLE IF NOT EXISTS people (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -1202,8 +1971,14 @@ class PersonDatabase:
                 person_type TEXT NOT NULL DEFAULT 'unknown',
                 known_for_role TEXT NOT NULL DEFAULT '',
                 known_for_org TEXT NOT NULL DEFAULT '',
+                known_for_org_id INTEGER DEFAULT NULL,
+                from_date TEXT NOT NULL DEFAULT '',
+                to_date TEXT NOT NULL DEFAULT '',
+                birth_date TEXT NOT NULL DEFAULT '',
+                death_date TEXT NOT NULL DEFAULT '',
                 record TEXT NOT NULL,
-                UNIQUE(source, source_id)
+                UNIQUE(source, source_id, known_for_role, known_for_org),
+                FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
             )
         """)
 
@@ -1211,9 +1986,71 @@ class PersonDatabase:
         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
-        conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id, known_for_role, known_for_org)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
 
+        # Add from_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added from_date column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add to_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added to_date column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add known_for_org_id column if it doesn't exist (migration for existing DBs)
+        # This is a foreign key to the organizations table (nullable)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN known_for_org_id INTEGER DEFAULT NULL")
+            logger.info("Added known_for_org_id column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Create index on known_for_org_id for joins (only if column exists)
+        try:
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org_id ON people(known_for_org_id)")
+        except sqlite3.OperationalError:
+            pass  # Column doesn't exist yet (will be added on next connection)
+
+        # Add birth_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN birth_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added birth_date column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add death_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN death_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added death_date column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_id column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN canon_id INTEGER DEFAULT NULL")
+            logger.info("Added canon_id column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_size column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE people ADD COLUMN canon_size INTEGER DEFAULT 1")
+            logger.info("Added canon_size column to people table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Create index on canon_id for joins
+        try:
+            conn.execute("CREATE INDEX IF NOT EXISTS idx_people_canon_id ON people(canon_id)")
+        except sqlite3.OperationalError:
+            pass  # Column doesn't exist yet
+
         # Create sqlite-vec virtual table for embeddings
         conn.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
@@ -1222,13 +2059,94 @@ class PersonDatabase:
             )
         """)

+        # Create QID labels lookup table for Wikidata QID -> label mappings
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS qid_labels (
+                qid TEXT PRIMARY KEY,
+                label TEXT NOT NULL
+            )
+        """)
+
         conn.commit()

+    def _migrate_people_schema_if_needed(self, conn: sqlite3.Connection) -> None:
+        """Migrate people table from old schema if needed."""
+        # Check if people table exists
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='people'"
+        )
+        if not cursor.fetchone():
+            return  # Table doesn't exist, no migration needed
+
+        # Check the unique constraint - look at index info
+        # Old schema: UNIQUE(source, source_id)
+        # New schema: UNIQUE(source, source_id, known_for_role, known_for_org)
+        cursor = conn.execute("PRAGMA index_list(people)")
+        indexes = cursor.fetchall()
+
+        needs_migration = False
+        for idx in indexes:
+            idx_name = idx[1]
+            if "sqlite_autoindex_people" in idx_name:
+                # Check columns in this unique index
+                cursor = conn.execute(f"PRAGMA index_info('{idx_name}')")
+                cols = [row[2] for row in cursor.fetchall()]
+                # Old schema has only 2 columns in unique constraint
+                if cols == ["source", "source_id"]:
+                    needs_migration = True
+                    logger.info("Detected old people schema, migrating to new unique constraint...")
+                    break
+
+        if not needs_migration:
+            return
+
+        # Migrate: create new table, copy data, drop old, rename new
+        logger.info("Migrating people table to new schema with (source, source_id, role, org) unique constraint...")
+
+        conn.execute("""
+            CREATE TABLE people_new (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                name TEXT NOT NULL,
+                name_normalized TEXT NOT NULL,
+                source TEXT NOT NULL DEFAULT 'wikidata',
+                source_id TEXT NOT NULL,
+                country TEXT NOT NULL DEFAULT '',
+                person_type TEXT NOT NULL DEFAULT 'unknown',
+                known_for_role TEXT NOT NULL DEFAULT '',
+                known_for_org TEXT NOT NULL DEFAULT '',
+                known_for_org_id INTEGER DEFAULT NULL,
+                from_date TEXT NOT NULL DEFAULT '',
+                to_date TEXT NOT NULL DEFAULT '',
+                record TEXT NOT NULL,
+                UNIQUE(source, source_id, known_for_role, known_for_org),
+                FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
+            )
+        """)
+
+        # Copy data (old IDs will change, but embeddings table references them)
+        # Note: old table may not have from_date/to_date columns, so use defaults
+        conn.execute("""
+            INSERT INTO people_new (name, name_normalized, source, source_id, country,
+                                    person_type, known_for_role, known_for_org, record)
+            SELECT name, name_normalized, source, source_id, country,
+                   person_type, known_for_role, known_for_org, record
+            FROM people
+        """)
+
+        # Drop old table and embeddings (IDs changed, embeddings are invalid)
+        conn.execute("DROP TABLE IF EXISTS person_embeddings")
+        conn.execute("DROP TABLE people")
+        conn.execute("ALTER TABLE people_new RENAME TO people")
+
+        # Drop old index if it exists
+        conn.execute("DROP INDEX IF EXISTS idx_people_source_id")
+
+        conn.commit()
+        logger.info("Migration complete. Note: person embeddings were cleared and need to be regenerated.")
+
     def close(self) -> None:
-        """
-
-        self._conn.close()
-        self._conn = None
+        """Clear connection reference (shared connection remains open)."""
+        self._conn = None

     def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
         """
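The migration above recreates the table because SQLite cannot widen a UNIQUE constraint in place. A rough sketch, using only illustrative values, of what the new composite key `(source, source_id, known_for_role, known_for_org)` allows that the old `(source, source_id)` key rejected:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE people (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        source TEXT NOT NULL,
        source_id TEXT NOT NULL,
        known_for_role TEXT NOT NULL DEFAULT '',
        known_for_org TEXT NOT NULL DEFAULT '',
        UNIQUE(source, source_id, known_for_role, known_for_org)
    )
""")
# One source_id may now carry several role/org rows without tripping the key.
rows = [
    ("wikidata", "Q00001", "chief executive officer", "Example Corp"),
    ("wikidata", "Q00001", "board member", "Example Holdings"),
]
conn.executemany(
    "INSERT INTO people (source, source_id, known_for_role, known_for_org) VALUES (?, ?, ?, ?)",
    rows,
)
print(conn.execute("SELECT COUNT(*) FROM people").fetchone()[0])  # -> 2
```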
@@ -1249,8 +2167,10 @@ class PersonDatabase:

         cursor = conn.execute("""
             INSERT OR REPLACE INTO people
-            (name, name_normalized, source, source_id, country, person_type,
-
+            (name, name_normalized, source, source_id, country, person_type,
+             known_for_role, known_for_org, known_for_org_id, from_date, to_date,
+             birth_date, death_date, record)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
         """, (
             record.name,
             name_normalized,
@@ -1260,16 +2180,22 @@ class PersonDatabase:
             record.person_type.value,
             record.known_for_role,
             record.known_for_org,
+            record.known_for_org_id,  # Can be None
+            record.from_date or "",
+            record.to_date or "",
+            record.birth_date or "",
+            record.death_date or "",
             record_json,
         ))

         row_id = cursor.lastrowid
         assert row_id is not None

-        # Insert embedding into vec table
+        # Insert embedding into vec table (delete first since sqlite-vec doesn't support REPLACE)
         embedding_blob = embedding.astype(np.float32).tobytes()
+        conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
         conn.execute("""
-            INSERT
+            INSERT INTO person_embeddings (person_id, embedding)
             VALUES (?, ?)
         """, (row_id, embedding_blob))

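The delete-then-insert around the embedding write reflects the comment in this hunk: the `vec0` virtual table does not support `INSERT OR REPLACE`, so a stale row has to be removed explicitly before the new vector goes in. A hedged sketch of that pattern, assuming a connection on which the sqlite-vec extension is already loaded and the `person_embeddings` table from this module exists:

```python
import sqlite3

import numpy as np

def upsert_person_embedding(conn: sqlite3.Connection, person_id: int, embedding: np.ndarray) -> None:
    """Delete-then-insert, because the vec0 virtual table lacks REPLACE semantics."""
    blob = embedding.astype(np.float32).tobytes()
    conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
    conn.execute(
        "INSERT INTO person_embeddings (person_id, embedding) VALUES (?, ?)",
        (person_id, blob),
    )
```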
@@ -1302,8 +2228,10 @@ class PersonDatabase:

             cursor = conn.execute("""
                 INSERT OR REPLACE INTO people
-                (name, name_normalized, source, source_id, country, person_type,
-
+                (name, name_normalized, source, source_id, country, person_type,
+                 known_for_role, known_for_org, known_for_org_id, from_date, to_date,
+                 birth_date, death_date, record)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
             """, (
                 record.name,
                 name_normalized,
@@ -1313,16 +2241,22 @@ class PersonDatabase:
                 record.person_type.value,
                 record.known_for_role,
                 record.known_for_org,
+                record.known_for_org_id,  # Can be None
+                record.from_date or "",
+                record.to_date or "",
+                record.birth_date or "",
+                record.death_date or "",
                 record_json,
             ))

             row_id = cursor.lastrowid
             assert row_id is not None

-            # Insert embedding
+            # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
             embedding_blob = embedding.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO person_embeddings (person_id, embedding)
                 VALUES (?, ?)
             """, (row_id, embedding_blob))

@@ -1335,6 +2269,88 @@ class PersonDatabase:
         conn.commit()
         return count

+    def update_dates(self, source: str, source_id: str, from_date: Optional[str], to_date: Optional[str]) -> bool:
+        """
+        Update the from_date and to_date for a person record.
+
+        Args:
+            source: Data source (e.g., 'wikidata')
+            source_id: Source identifier (e.g., QID)
+            from_date: Start date in ISO format or None
+            to_date: End date in ISO format or None
+
+        Returns:
+            True if record was updated, False if not found
+        """
+        conn = self._connect()
+
+        cursor = conn.execute("""
+            UPDATE people SET from_date = ?, to_date = ?
+            WHERE source = ? AND source_id = ?
+        """, (from_date or "", to_date or "", source, source_id))
+
+        conn.commit()
+        return cursor.rowcount > 0
+
+    def update_role_org(
+        self,
+        source: str,
+        source_id: str,
+        known_for_role: str,
+        known_for_org: str,
+        known_for_org_id: Optional[int],
+        new_embedding: np.ndarray,
+        from_date: Optional[str] = None,
+        to_date: Optional[str] = None,
+    ) -> bool:
+        """
+        Update the role/org/dates data for a person record and re-embed.
+
+        Args:
+            source: Data source (e.g., 'wikidata')
+            source_id: Source identifier (e.g., QID)
+            known_for_role: Role/position title
+            known_for_org: Organization name
+            known_for_org_id: Organization internal ID (FK) or None
+            new_embedding: New embedding vector based on updated data
+            from_date: Start date in ISO format or None
+            to_date: End date in ISO format or None
+
+        Returns:
+            True if record was updated, False if not found
+        """
+        conn = self._connect()
+
+        # First get the person's internal ID
+        row = conn.execute(
+            "SELECT id FROM people WHERE source = ? AND source_id = ?",
+            (source, source_id)
+        ).fetchone()
+
+        if not row:
+            return False
+
+        person_id = row[0]
+
+        # Update the person record (including dates)
+        conn.execute("""
+            UPDATE people SET
+                known_for_role = ?, known_for_org = ?, known_for_org_id = ?,
+                from_date = COALESCE(?, from_date, ''),
+                to_date = COALESCE(?, to_date, '')
+            WHERE id = ?
+        """, (known_for_role, known_for_org, known_for_org_id, from_date, to_date, person_id))
+
+        # Update the embedding
+        embedding_bytes = new_embedding.astype(np.float32).tobytes()
+        conn.execute("""
+            UPDATE people_vec SET embedding = ?
+            WHERE rowid = ?
+        """, (embedding_bytes, person_id))
+
+        conn.commit()
+        return True
+
     def search(
         self,
         query_embedding: np.ndarray,
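A possible call site for the two new update helpers. `PersonDatabase` is the class this diff modifies; the constructor arguments, the QID, and the embedding value below are placeholders, not taken from the package:

```python
import numpy as np

from statement_extractor.database.store import PersonDatabase  # module shown in this diff

db = PersonDatabase()  # default construction assumed; real call sites may pass a db path
placeholder_vec = np.zeros(768, dtype=np.float32)  # stand-in for a real person embedding

# Re-point a person at a resolved role/org and refresh the stored embedding.
db.update_role_org(
    source="wikidata",
    source_id="Q00001",              # illustrative QID
    known_for_role="chief executive officer",
    known_for_org="Example Corp",
    known_for_org_id=None,
    new_embedding=placeholder_vec,
    from_date="2020-01-01",
)

# Or touch only the tenure dates when nothing else changed.
db.update_dates("wikidata", "Q00001", "2020-01-01", None)
```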
@@ -1516,7 +2532,7 @@ class PersonDatabase:
         assert conn is not None

         cursor = conn.execute("""
-            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
             FROM people WHERE id = ?
         """, (person_id,))

@@ -1530,6 +2546,9 @@ class PersonDatabase:
                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                 known_for_role=row["known_for_role"] or "",
                 known_for_org=row["known_for_org"] or "",
+                known_for_org_id=row["known_for_org_id"],  # Can be None
+                birth_date=row["birth_date"] or "",
+                death_date=row["death_date"] or "",
                 record=json.loads(row["record"]),
             )
         return None
@@ -1539,7 +2558,7 @@ class PersonDatabase:
         conn = self._connect()

         cursor = conn.execute("""
-            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
             FROM people
             WHERE source = ? AND source_id = ?
         """, (source, source_id))
@@ -1554,6 +2573,9 @@ class PersonDatabase:
                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                 known_for_role=row["known_for_role"] or "",
                 known_for_org=row["known_for_org"] or "",
+                known_for_org_id=row["known_for_org_id"],  # Can be None
+                birth_date=row["birth_date"] or "",
+                death_date=row["death_date"] or "",
                 record=json.loads(row["record"]),
             )
         return None
@@ -1580,19 +2602,43 @@ class PersonDatabase:
             "by_source": by_source,
         }

+    def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
+        """
+        Get all source_ids from the people table.
+
+        Useful for resume operations to skip already-imported records.
+
+        Args:
+            source: Optional source filter (e.g., "wikidata")
+
+        Returns:
+            Set of source_id strings (e.g., Q codes for Wikidata)
+        """
+        conn = self._connect()
+
+        if source:
+            cursor = conn.execute(
+                "SELECT DISTINCT source_id FROM people WHERE source = ?",
+                (source,)
+            )
+        else:
+            cursor = conn.execute("SELECT DISTINCT source_id FROM people")
+
+        return {row[0] for row in cursor}
+
     def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
         """Iterate over all person records, optionally filtered by source."""
         conn = self._connect()

         if source:
             cursor = conn.execute("""
-                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
                 FROM people
                 WHERE source = ?
             """, (source,))
         else:
             cursor = conn.execute("""
-                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
                 FROM people
             """)

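`get_all_source_ids` is aimed at resumable imports, per its docstring. A small usage sketch (default construction of `PersonDatabase` is assumed; real importers may pass an explicit database path):

```python
from statement_extractor.database.store import PersonDatabase  # module shown in this diff

db = PersonDatabase()  # default construction assumed
seen = db.get_all_source_ids(source="wikidata")
print(f"{len(seen)} Wikidata people already imported; a resumed dump import can skip these QIDs")
```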
@@ -1605,5 +2651,384 @@ class PersonDatabase:
                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                 known_for_role=row["known_for_role"] or "",
                 known_for_org=row["known_for_org"] or "",
+                known_for_org_id=row["known_for_org_id"],  # Can be None
+                birth_date=row["birth_date"] or "",
+                death_date=row["death_date"] or "",
                 record=json.loads(row["record"]),
             )
+
+    def resolve_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> tuple[int, int]:
+        """
+        Update records that have QIDs instead of labels.
+
+        This is called after dump import to resolve any QIDs that were
+        stored because labels weren't available in the cache at import time.
+
+        If resolving would create a duplicate of an existing record with
+        resolved labels, the QID version is deleted instead.
+
+        Args:
+            label_map: Mapping of QID -> label for resolution
+            batch_size: Commit batch size
+
+        Returns:
+            Tuple of (updates, deletes)
+        """
+        conn = self._connect()
+
+        # Find all records with QIDs in any field (role or org - these are in unique constraint)
+        # Country is not part of unique constraint so can be updated directly
+        cursor = conn.execute("""
+            SELECT id, source, source_id, country, known_for_role, known_for_org
+            FROM people
+            WHERE (country LIKE 'Q%' AND country GLOB 'Q[0-9]*')
+               OR (known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*')
+               OR (known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*')
+        """)
+        rows = cursor.fetchall()
+
+        updates = 0
+        deletes = 0
+
+        for row in rows:
+            person_id = row["id"]
+            source = row["source"]
+            source_id = row["source_id"]
+            country = row["country"]
+            role = row["known_for_role"]
+            org = row["known_for_org"]
+
+            # Resolve QIDs to labels
+            new_country = label_map.get(country, country) if country.startswith("Q") and country[1:].isdigit() else country
+            new_role = label_map.get(role, role) if role.startswith("Q") and role[1:].isdigit() else role
+            new_org = label_map.get(org, org) if org.startswith("Q") and org[1:].isdigit() else org
+
+            # Skip if nothing changed
+            if new_country == country and new_role == role and new_org == org:
+                continue
+
+            # Check if resolved values would duplicate an existing record
+            # (unique constraint is on source, source_id, known_for_role, known_for_org)
+            if new_role != role or new_org != org:
+                cursor2 = conn.execute("""
+                    SELECT id FROM people
+                    WHERE source = ? AND source_id = ? AND known_for_role = ? AND known_for_org = ?
+                      AND id != ?
+                """, (source, source_id, new_role, new_org, person_id))
+                existing = cursor2.fetchone()
+
+                if existing:
+                    # Duplicate would exist - delete the QID version
+                    conn.execute("DELETE FROM people WHERE id = ?", (person_id,))
+                    conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
+                    deletes += 1
+                    logger.debug(f"Deleted duplicate QID record {person_id} (source_id={source_id})")
+                    continue
+
+            # No duplicate - update in place
+            conn.execute("""
+                UPDATE people SET country = ?, known_for_role = ?, known_for_org = ?
+                WHERE id = ?
+            """, (new_country, new_role, new_org, person_id))
+            updates += 1
+
+            if (updates + deletes) % batch_size == 0:
+                conn.commit()
+                logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes...")
+
+        conn.commit()
+        logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes")
+        return updates, deletes
+
+    def get_unresolved_qids(self) -> set[str]:
+        """
+        Get all QIDs that still need resolution in the database.
+
+        Returns:
+            Set of QIDs (starting with 'Q') found in country, role, or org fields
+        """
+        conn = self._connect()
+        qids: set[str] = set()
+
+        # Get QIDs from country field
+        cursor = conn.execute("""
+            SELECT DISTINCT country FROM people
+            WHERE country LIKE 'Q%' AND country GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["country"])
+
+        # Get QIDs from known_for_role field
+        cursor = conn.execute("""
+            SELECT DISTINCT known_for_role FROM people
+            WHERE known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["known_for_role"])
+
+        # Get QIDs from known_for_org field
+        cursor = conn.execute("""
+            SELECT DISTINCT known_for_org FROM people
+            WHERE known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["known_for_org"])
+
+        return qids
+
+    def insert_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> int:
+        """
+        Insert QID -> label mappings into the lookup table.
+
+        Args:
+            label_map: Mapping of QID -> label
+            batch_size: Commit batch size
+
+        Returns:
+            Number of labels inserted/updated
+        """
+        conn = self._connect()
+        count = 0
+
+        for qid, label in label_map.items():
+            conn.execute(
+                "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
+                (qid, label)
+            )
+            count += 1
+
+            if count % batch_size == 0:
+                conn.commit()
+                logger.debug(f"Inserted {count} QID labels...")
+
+        conn.commit()
+        logger.info(f"Inserted {count} QID labels into lookup table")
+        return count
+
+    def get_qid_label(self, qid: str) -> Optional[str]:
+        """
+        Get the label for a QID from the lookup table.
+
+        Args:
+            qid: Wikidata QID (e.g., 'Q30')
+
+        Returns:
+            Label string or None if not found
+        """
+        conn = self._connect()
+        cursor = conn.execute(
+            "SELECT label FROM qid_labels WHERE qid = ?",
+            (qid,)
+        )
+        row = cursor.fetchone()
+        return row["label"] if row else None
+
+    def get_all_qid_labels(self) -> dict[str, str]:
+        """
+        Get all QID -> label mappings from the lookup table.
+
+        Returns:
+            Dict mapping QID -> label
+        """
+        conn = self._connect()
+        cursor = conn.execute("SELECT qid, label FROM qid_labels")
+        return {row["qid"]: row["label"] for row in cursor}
+
+    def get_qid_labels_count(self) -> int:
+        """Get the number of QID labels in the lookup table."""
+        conn = self._connect()
+        cursor = conn.execute("SELECT COUNT(*) FROM qid_labels")
+        return cursor.fetchone()[0]
+
+    def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
+        """
+        Canonicalize person records by linking equivalent entries across sources.
+
+        Uses a multi-phase approach:
+        1. Match by normalized name + same organization (org canonical group)
+        2. Match by normalized name + overlapping date ranges
+
+        Source priority (lower = more authoritative):
+        - wikidata: 1 (curated, has Q codes)
+        - sec_edgar: 2 (US insider filings)
+        - companies_house: 3 (UK officers)
+
+        Args:
+            batch_size: Number of records to process before committing
+
+        Returns:
+            Stats dict with counts for each matching type
+        """
+        conn = self._connect()
+        stats = {
+            "total_records": 0,
+            "matched_by_org": 0,
+            "matched_by_date": 0,
+            "canonical_groups": 0,
+            "records_in_groups": 0,
+        }
+
+        logger.info("Phase 1: Building person index...")
+
+        # Load all people with their normalized names and org info
+        cursor = conn.execute("""
+            SELECT id, name, name_normalized, source, source_id,
+                   known_for_org, known_for_org_id, from_date, to_date
+            FROM people
+        """)
+
+        people: list[dict] = []
+        for row in cursor:
+            people.append({
+                "id": row["id"],
+                "name": row["name"],
+                "name_normalized": row["name_normalized"],
+                "source": row["source"],
+                "source_id": row["source_id"],
+                "known_for_org": row["known_for_org"],
+                "known_for_org_id": row["known_for_org_id"],
+                "from_date": row["from_date"],
+                "to_date": row["to_date"],
+            })
+
+        stats["total_records"] = len(people)
+        logger.info(f"Loaded {len(people)} person records")
+
+        if len(people) == 0:
+            return stats
+
+        # Initialize Union-Find
+        person_ids = [p["id"] for p in people]
+        uf = UnionFind(person_ids)
+
+        # Build indexes for efficient matching
+        # Index by normalized name
+        name_to_people: dict[str, list[dict]] = {}
+        for p in people:
+            name_norm = p["name_normalized"]
+            name_to_people.setdefault(name_norm, []).append(p)
+
+        logger.info("Phase 2: Matching by normalized name + organization...")
+
+        # Match people with same normalized name and same organization
+        for name_norm, same_name in name_to_people.items():
+            if len(same_name) < 2:
+                continue
+
+            # Group by organization (using known_for_org_id if available, else known_for_org)
+            org_groups: dict[str, list[dict]] = {}
+            for p in same_name:
+                org_key = str(p["known_for_org_id"]) if p["known_for_org_id"] else p["known_for_org"]
+                if org_key:  # Only group if they have an org
+                    org_groups.setdefault(org_key, []).append(p)
+
+            # Union people with same name + same org
+            for org_key, org_people in org_groups.items():
+                if len(org_people) >= 2:
+                    first_id = org_people[0]["id"]
+                    for p in org_people[1:]:
+                        uf.union(first_id, p["id"])
+                        stats["matched_by_org"] += 1
+
+        logger.info(f"Phase 2 complete: {stats['matched_by_org']} matches by org")
+
+        logger.info("Phase 3: Matching by normalized name + overlapping dates...")
+
+        # Match people with same normalized name and overlapping date ranges
+        for name_norm, same_name in name_to_people.items():
+            if len(same_name) < 2:
+                continue
+
+            # Skip if already all unified
+            roots = set(uf.find(p["id"]) for p in same_name)
+            if len(roots) == 1:
+                continue
+
+            # Check for overlapping date ranges
+            for i, p1 in enumerate(same_name):
+                for p2 in same_name[i+1:]:
+                    # Skip if already in same group
+                    if uf.find(p1["id"]) == uf.find(p2["id"]):
+                        continue
+
+                    # Check date overlap (if both have dates)
+                    if p1["from_date"] and p2["from_date"]:
+                        # Simple overlap check: if either from_date is before other's to_date
+                        p1_from = p1["from_date"]
+                        p1_to = p1["to_date"] or "9999-12-31"
+                        p2_from = p2["from_date"]
+                        p2_to = p2["to_date"] or "9999-12-31"
+
+                        # Overlap if: p1_from <= p2_to AND p2_from <= p1_to
+                        if p1_from <= p2_to and p2_from <= p1_to:
+                            uf.union(p1["id"], p2["id"])
+                            stats["matched_by_date"] += 1
+
+        logger.info(f"Phase 3 complete: {stats['matched_by_date']} matches by date")
+
+        logger.info("Phase 4: Applying canonical updates...")
+
+        # Get all groups and select canonical record for each
+        groups = uf.groups()
+
+        # Build id -> source mapping
+        id_to_source = {p["id"]: p["source"] for p in people}
+
+        batch_updates: list[tuple[int, int, int]] = []  # (person_id, canon_id, canon_size)
+
+        for _root, group_ids in groups.items():
+            group_size = len(group_ids)
+
+            if group_size == 1:
+                # Single record is its own canonical
+                person_id = group_ids[0]
+                batch_updates.append((person_id, person_id, 1))
+            else:
+                # Multiple records - pick highest priority source as canonical
+                # Sort by source priority, then by id (for stability)
+                sorted_ids = sorted(
+                    group_ids,
+                    key=lambda pid: (PERSON_SOURCE_PRIORITY.get(id_to_source[pid], 99), pid)
+                )
+                canon_id = sorted_ids[0]
+                stats["canonical_groups"] += 1
+                stats["records_in_groups"] += group_size
+
+                for person_id in group_ids:
+                    batch_updates.append((person_id, canon_id, group_size if person_id == canon_id else 1))
+
+            # Commit in batches
+            if len(batch_updates) >= batch_size:
+                self._apply_person_canon_updates(batch_updates)
+                conn.commit()
+                logger.info(f"Applied {len(batch_updates)} canon updates...")
+                batch_updates = []
+
+        # Final batch
+        if batch_updates:
+            self._apply_person_canon_updates(batch_updates)
+            conn.commit()
+
+        logger.info(f"Canonicalization complete: {stats['canonical_groups']} groups, "
+                    f"{stats['records_in_groups']} records in multi-record groups")
+
+        return stats
+
+    def _apply_person_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
+        """Apply batch of canon updates: (person_id, canon_id, canon_size)."""
+        conn = self._conn
+        assert conn is not None
+
+        for person_id, canon_id, canon_size in updates:
+            conn.execute(
+                "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
+                (canon_id, canon_size, person_id)
+            )
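`canonicalize` relies on a `UnionFind` helper and a `PERSON_SOURCE_PRIORITY` map that live elsewhere in `store.py` and are not part of this hunk. A minimal disjoint-set sketch with the interface the code above assumes (`find`, `union`, `groups`), plus a priority map mirroring the docstring; this is an illustration, not the package's implementation:

```python
class UnionFind:
    """Disjoint-set over a fixed list of ids; interface matches the calls above."""

    def __init__(self, ids: list[int]) -> None:
        self.parent = {i: i for i in ids}

    def find(self, x: int) -> int:
        # Walk to the root, then compress the path behind us.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra

    def groups(self) -> dict[int, list[int]]:
        out: dict[int, list[int]] = {}
        for i in self.parent:
            out.setdefault(self.find(i), []).append(i)
        return out


# Assumed to mirror the docstring: lower number wins the canonical slot.
PERSON_SOURCE_PRIORITY = {"wikidata": 1, "sec_edgar": 2, "companies_house": 3}
```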