corp-extractor 0.9.0-py3-none-any.whl → 0.9.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/METADATA +72 -11
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/RECORD +34 -27
- statement_extractor/cli.py +1317 -101
- statement_extractor/database/embeddings.py +45 -0
- statement_extractor/database/hub.py +86 -136
- statement_extractor/database/importers/__init__.py +10 -2
- statement_extractor/database/importers/companies_house.py +16 -2
- statement_extractor/database/importers/companies_house_officers.py +431 -0
- statement_extractor/database/importers/gleif.py +23 -0
- statement_extractor/database/importers/import_utils.py +264 -0
- statement_extractor/database/importers/sec_edgar.py +17 -0
- statement_extractor/database/importers/sec_form4.py +512 -0
- statement_extractor/database/importers/wikidata.py +151 -43
- statement_extractor/database/importers/wikidata_dump.py +2282 -0
- statement_extractor/database/importers/wikidata_people.py +867 -325
- statement_extractor/database/migrate_v2.py +852 -0
- statement_extractor/database/models.py +155 -7
- statement_extractor/database/schema_v2.py +409 -0
- statement_extractor/database/seed_data.py +359 -0
- statement_extractor/database/store.py +3449 -233
- statement_extractor/document/deduplicator.py +10 -12
- statement_extractor/extractor.py +1 -1
- statement_extractor/models/__init__.py +3 -2
- statement_extractor/models/statement.py +15 -17
- statement_extractor/models.py +1 -1
- statement_extractor/pipeline/context.py +5 -5
- statement_extractor/pipeline/orchestrator.py +12 -12
- statement_extractor/plugins/base.py +17 -17
- statement_extractor/plugins/extractors/gliner2.py +28 -28
- statement_extractor/plugins/qualifiers/embedding_company.py +7 -5
- statement_extractor/plugins/qualifiers/person.py +120 -53
- statement_extractor/plugins/splitters/t5_gemma.py +35 -39
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/WHEEL +0 -0
- {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/entry_points.txt +0 -0
statement_extractor/database/store.py

@@ -12,17 +12,44 @@ import re
 import sqlite3
 import time
 from pathlib import Path
-from typing import Iterator, Optional
+from typing import Any, Iterator, Optional
 
 import numpy as np
+import pycountry
 import sqlite_vec
 
-from .models import
+from .models import (
+    CompanyRecord,
+    DatabaseStats,
+    EntityType,
+    LocationRecord,
+    PersonRecord,
+    PersonType,
+    RoleRecord,
+    SimplifiedLocationType,
+)
+from .seed_data import (
+    LOCATION_TYPE_NAME_TO_ID,
+    LOCATION_TYPE_QID_TO_ID,
+    LOCATION_TYPE_TO_SIMPLIFIED,
+    ORG_TYPE_ID_TO_NAME,
+    ORG_TYPE_NAME_TO_ID,
+    PEOPLE_TYPE_ID_TO_NAME,
+    PEOPLE_TYPE_NAME_TO_ID,
+    SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME,
+    SOURCE_ID_TO_NAME,
+    SOURCE_NAME_TO_ID,
+    seed_all_enums,
+    seed_pycountry_locations,
+)
 
 logger = logging.getLogger(__name__)
 
 # Default database location
-DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
+DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities-v2.db"
+
+# Module-level shared connections by path (both databases share the same connection)
+_shared_connections: dict[str, sqlite3.Connection] = {}
 
 # Module-level singleton for OrganizationDatabase to prevent multiple loads
 _database_instances: dict[str, "OrganizationDatabase"] = {}
@@ -30,6 +57,36 @@ _database_instances: dict[str, "OrganizationDatabase"] = {}
 # Module-level singleton for PersonDatabase
 _person_database_instances: dict[str, "PersonDatabase"] = {}
 
+
+def _get_shared_connection(db_path: Path, embedding_dim: int = 768) -> sqlite3.Connection:
+    """Get or create a shared database connection for the given path."""
+    path_key = str(db_path)
+    if path_key not in _shared_connections:
+        # Ensure directory exists
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        conn = sqlite3.connect(str(db_path))
+        conn.row_factory = sqlite3.Row
+
+        # Load sqlite-vec extension
+        conn.enable_load_extension(True)
+        sqlite_vec.load(conn)
+        conn.enable_load_extension(False)
+
+        _shared_connections[path_key] = conn
+        logger.debug(f"Created shared database connection for {path_key}")
+
+    return _shared_connections[path_key]
+
+
+def close_shared_connection(db_path: Optional[Path] = None) -> None:
+    """Close a shared database connection."""
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key in _shared_connections:
+        _shared_connections[path_key].close()
+        del _shared_connections[path_key]
+        logger.debug(f"Closed shared database connection for {path_key}")
+
 # Comprehensive set of corporate legal suffixes (international)
 COMPANY_SUFFIXES: set[str] = {
     'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
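The hunk above replaces per-instance connections with a module-level cache keyed by database path, so the organization and person stores reuse one `sqlite3` handle and pay for loading the sqlite-vec extension only once. A minimal standalone sketch of the same caching pattern (the `demo` path and `get_conn` name are illustrative, not the package's API):

```python
import sqlite3
from pathlib import Path

import sqlite_vec  # same extension-loading calls as in the hunk above

_shared: dict[str, sqlite3.Connection] = {}

def get_conn(db_path: Path) -> sqlite3.Connection:
    """Return a cached connection for db_path, creating it on first use."""
    key = str(db_path)
    if key not in _shared:
        db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(key)
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)  # extension loaded once per path
        conn.enable_load_extension(False)
        _shared[key] = conn
    return _shared[key]

# Both callers below share one handle and one loaded extension.
a = get_conn(Path.home() / ".cache" / "demo" / "entities.db")
b = get_conn(Path.home() / ".cache" / "demo" / "entities.db")
assert a is b
```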
@@ -43,6 +100,222 @@ COMPANY_SUFFIXES: set[str] = {
     'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
 }
 
+# Source priority for organization canonicalization (lower = higher priority)
+SOURCE_PRIORITY: dict[str, int] = {
+    "gleif": 1,  # Gold standard LEI - globally unique legal entity identifier
+    "sec_edgar": 2,  # Vetted US filers with CIK + ticker
+    "companies_house": 3,  # Official UK registry
+    "wikipedia": 4,  # Crowdsourced, less authoritative
+}
+
+# Source priority for people canonicalization (lower = higher priority)
+PERSON_SOURCE_PRIORITY: dict[str, int] = {
+    "wikidata": 1,  # Curated, has rich biographical data and Q codes
+    "sec_edgar": 2,  # Vetted US filers (Form 4 officers/directors)
+    "companies_house": 3,  # UK company officers
+}
+
+# Suffix expansions for canonical name matching
+SUFFIX_EXPANSIONS: dict[str, str] = {
+    " ltd": " limited",
+    " corp": " corporation",
+    " inc": " incorporated",
+    " co": " company",
+    " intl": " international",
+    " natl": " national",
+}
+
+
+class UnionFind:
+    """Simple Union-Find (Disjoint Set Union) data structure for canonicalization."""
+
+    def __init__(self, elements: list[int]):
+        """Initialize with list of element IDs."""
+        self.parent: dict[int, int] = {e: e for e in elements}
+        self.rank: dict[int, int] = {e: 0 for e in elements}
+
+    def find(self, x: int) -> int:
+        """Find with path compression."""
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, x: int, y: int) -> None:
+        """Union by rank."""
+        px, py = self.find(x), self.find(y)
+        if px == py:
+            return
+        if self.rank[px] < self.rank[py]:
+            px, py = py, px
+        self.parent[py] = px
+        if self.rank[px] == self.rank[py]:
+            self.rank[px] += 1
+
+    def groups(self) -> dict[int, list[int]]:
+        """Return dict of root -> list of members."""
+        result: dict[int, list[int]] = {}
+        for e in self.parent:
+            root = self.find(e)
+            result.setdefault(root, []).append(e)
+        return result
+
+
+# Common region aliases not handled well by pycountry fuzzy search
+REGION_ALIASES: dict[str, str] = {
+    "uk": "GB",
+    "u.k.": "GB",
+    "england": "GB",
+    "scotland": "GB",
+    "wales": "GB",
+    "northern ireland": "GB",
+    "usa": "US",
+    "u.s.a.": "US",
+    "u.s.": "US",
+    "united states of america": "US",
+    "america": "US",
+}
+
+# Cache for region normalization lookups
+_region_cache: dict[str, str] = {}
+
+
+def _normalize_region(region: str) -> str:
+    """
+    Normalize a region string to ISO 3166-1 alpha-2 country code.
+
+    Handles:
+    - Country codes (2-letter, 3-letter)
+    - Country names (with fuzzy matching)
+    - US state codes (CA, NY) -> US
+    - US state names (California, New York) -> US
+    - Common aliases (UK, USA, England) -> proper codes
+
+    Returns empty string if region cannot be normalized.
+    """
+    if not region:
+        return ""
+
+    # Check cache first
+    cache_key = region.lower().strip()
+    if cache_key in _region_cache:
+        return _region_cache[cache_key]
+
+    result = _normalize_region_uncached(region)
+    _region_cache[cache_key] = result
+    return result
+
+
+def _normalize_region_uncached(region: str) -> str:
+    """Uncached region normalization logic."""
+    region_clean = region.strip()
+
+    # Empty after stripping = empty result
+    if not region_clean:
+        return ""
+
+    region_lower = region_clean.lower()
+    region_upper = region_clean.upper()
+
+    # Check common aliases first
+    if region_lower in REGION_ALIASES:
+        return REGION_ALIASES[region_lower]
+
+    # For 2-letter codes, check country first, then US state
+    # This means ambiguous codes like "CA" (Canada vs California) prefer country
+    # But unambiguous codes like "NY" (not a country) will match as US state
+    if len(region_clean) == 2:
+        # Try as country alpha-2 first
+        country = pycountry.countries.get(alpha_2=region_upper)
+        if country:
+            return country.alpha_2
+
+        # If not a country, try as US state code
+        subdivision = pycountry.subdivisions.get(code=f"US-{region_upper}")
+        if subdivision:
+            return "US"
+
+    # Try alpha-3 lookup
+    if len(region_clean) == 3:
+        country = pycountry.countries.get(alpha_3=region_upper)
+        if country:
+            return country.alpha_2
+
+    # Try as US state name (e.g., "California", "New York")
+    try:
+        subdivisions = list(pycountry.subdivisions.search_fuzzy(region_clean))
+        if subdivisions:
+            # Check if it's a US state
+            if subdivisions[0].code.startswith("US-"):
+                return "US"
+            # Return the parent country code
+            return subdivisions[0].country_code
+    except LookupError:
+        pass
+
+    # Try country fuzzy search
+    try:
+        countries = pycountry.countries.search_fuzzy(region_clean)
+        if countries:
+            return countries[0].alpha_2
+    except LookupError:
+        pass
+
+    # Return empty if we can't normalize
+    return ""
+
+
+def _regions_match(region1: str, region2: str) -> bool:
+    """
+    Check if two regions match after normalization.
+
+    Empty regions match anything (lenient matching for incomplete data).
+    """
+    norm1 = _normalize_region(region1)
+    norm2 = _normalize_region(region2)
+
+    # Empty regions match anything
+    if not norm1 or not norm2:
+        return True
+
+    return norm1 == norm2
+
+
+def _normalize_for_canon(name: str) -> str:
+    """Normalize name for canonical matching (simpler than search normalization)."""
+    # Lowercase
+    result = name.lower()
+    # Remove trailing dots
+    result = result.rstrip(".")
+    # Remove extra whitespace
+    result = " ".join(result.split())
+    return result
+
+
+def _expand_suffix(name: str) -> str:
+    """Expand known suffix abbreviations."""
+    result = name.lower().rstrip(".")
+    for abbrev, full in SUFFIX_EXPANSIONS.items():
+        if result.endswith(abbrev):
+            result = result[:-len(abbrev)] + full
+            break  # Only expand one suffix
+    return result
+
+
+def _names_match_for_canon(name1: str, name2: str) -> bool:
+    """Check if two names match for canonicalization."""
+    n1 = _normalize_for_canon(name1)
+    n2 = _normalize_for_canon(name2)
+
+    # Exact match after normalization
+    if n1 == n2:
+        return True
+
+    # Try with suffix expansion
+    if _expand_suffix(n1) == _expand_suffix(n2):
+        return True
+
+    return False
+
 # Pre-compile the suffix pattern for performance
 _SUFFIX_PATTERN = re.compile(
     r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
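The helpers added in this hunk compose in a predictable way: suffix expansion makes "Acme Ltd" and "Acme Limited" compare equal, and Union-Find then merges any record IDs that share a (name, region) key. A toy sketch of that interaction, with a simplified Union-Find and made-up IDs, names, and regions (not data from the package):

```python
SUFFIX_EXPANSIONS = {" ltd": " limited", " corp": " corporation", " inc": " incorporated"}

def expand_suffix(name: str) -> str:
    n = name.lower().rstrip(".")
    for abbrev, full in SUFFIX_EXPANSIONS.items():
        if n.endswith(abbrev):
            return n[:-len(abbrev)] + full
    return n

class UnionFind:
    """Minimal DSU with path halving; the hunk's version also unions by rank."""
    def __init__(self, elements):
        self.parent = {e: e for e in elements}
    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x
    def union(self, x, y):
        self.parent[self.find(y)] = self.find(x)

# Toy records: (id, name, region) — same (expanded name, region) key idea as above.
records = [(1, "Acme Ltd", "GB"), (2, "Acme Limited", "GB"), (3, "Acme Limited", "US")]
uf = UnionFind([r[0] for r in records])
by_key: dict[tuple[str, str], int] = {}
for rid, name, region in records:
    key = (expand_suffix(name), region)
    if key in by_key:
        uf.union(by_key[key], rid)  # same expanded name + same region -> one group
    else:
        by_key[key] = rid

groups: dict[int, list[int]] = {}
for rid, _, _ in records:
    groups.setdefault(uf.find(rid), []).append(rid)
print(groups)  # {1: [1, 2], 3: [3]} — the US record stays in its own group
```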
@@ -250,30 +523,40 @@ class OrganizationDatabase:
         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
         self._embedding_dim = embedding_dim
         self._conn: Optional[sqlite3.Connection] = None
+        self._is_v2: Optional[bool] = None  # Detected on first connect
 
     def _ensure_dir(self) -> None:
         """Ensure database directory exists."""
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
 
     def _connect(self) -> sqlite3.Connection:
-        """Get or create database connection
+        """Get or create database connection using shared connection pool."""
         if self._conn is not None:
             return self._conn
 
-        self.
-        self._conn = sqlite3.connect(str(self._db_path))
-        self._conn.row_factory = sqlite3.Row
+        self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
 
-        #
-
-
-
+        # Detect schema version BEFORE creating tables
+        # v2 has entity_type_id (FK) instead of entity_type (TEXT)
+        if self._is_v2 is None:
+            cursor = self._conn.execute("PRAGMA table_info(organizations)")
+            columns = {row["name"] for row in cursor}
+            self._is_v2 = "entity_type_id" in columns
+            if self._is_v2:
+                logger.debug("Detected v2 schema for organizations")
 
-        # Create tables
-
+        # Create tables (idempotent) - only for v1 schema or fresh databases
+        # v2 databases already have their schema from migration
+        if not self._is_v2:
+            self._create_tables()
 
         return self._conn
 
+    @property
+    def _org_table(self) -> str:
+        """Return table/view name for organization queries needing text fields."""
+        return "organizations_view" if self._is_v2 else "organizations"
+
     def _create_tables(self) -> None:
         """Create database tables including sqlite-vec virtual table."""
         conn = self._conn
@@ -289,6 +572,8 @@ class OrganizationDatabase:
                 source_id TEXT NOT NULL,
                 region TEXT NOT NULL DEFAULT '',
                 entity_type TEXT NOT NULL DEFAULT 'unknown',
+                from_date TEXT NOT NULL DEFAULT '',
+                to_date TEXT NOT NULL DEFAULT '',
                 record TEXT NOT NULL,
                 UNIQUE(source, source_id)
             )
@@ -308,6 +593,34 @@ class OrganizationDatabase:
         except sqlite3.OperationalError:
             pass  # Column already exists
 
+        # Add from_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added from_date column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add to_date column if it doesn't exist (migration for existing DBs)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
+            logger.info("Added to_date column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_id column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN canon_id INTEGER DEFAULT NULL")
+            logger.info("Added canon_id column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
+        # Add canon_size column if it doesn't exist (migration for canonicalization)
+        try:
+            conn.execute("ALTER TABLE organizations ADD COLUMN canon_size INTEGER DEFAULT 1")
+            logger.info("Added canon_size column to organizations table")
+        except sqlite3.OperationalError:
+            pass  # Column already exists
+
         # Create indexes on main table
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
@@ -316,8 +629,9 @@ class OrganizationDatabase:
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
         conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")
 
-        # Create sqlite-vec virtual table for embeddings
+        # Create sqlite-vec virtual table for embeddings (float32)
         # vec0 is the recommended virtual table type
         conn.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
@@ -326,21 +640,34 @@ class OrganizationDatabase:
             )
         """)
 
+        # Create sqlite-vec virtual table for scalar embeddings (int8)
+        # Provides 75% storage reduction with ~92% recall at top-100
+        conn.execute(f"""
+            CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
+                org_id INTEGER PRIMARY KEY,
+                embedding int8[{self._embedding_dim}]
+            )
+        """)
+
         conn.commit()
 
     def close(self) -> None:
-        """
-
-        self._conn.close()
-        self._conn = None
+        """Clear connection reference (shared connection remains open)."""
+        self._conn = None
 
-    def insert(
+    def insert(
+        self,
+        record: CompanyRecord,
+        embedding: np.ndarray,
+        scalar_embedding: Optional[np.ndarray] = None,
+    ) -> int:
         """
         Insert an organization record with its embedding.
 
         Args:
             record: Organization record to insert
-            embedding: Embedding vector for the organization name
+            embedding: Embedding vector for the organization name (float32)
+            scalar_embedding: Optional int8 scalar embedding for compact storage
 
         Returns:
             Row ID of inserted record
@@ -353,8 +680,8 @@ class OrganizationDatabase:
 
         cursor = conn.execute("""
             INSERT OR REPLACE INTO organizations
-            (name, name_normalized, source, source_id, region, entity_type, record)
-            VALUES (?, ?, ?, ?, ?, ?, ?)
+            (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
         """, (
             record.name,
             name_normalized,
@@ -362,20 +689,32 @@ class OrganizationDatabase:
             record.source_id,
             record.region,
             record.entity_type.value,
+            record.from_date or "",
+            record.to_date or "",
             record_json,
         ))
 
         row_id = cursor.lastrowid
         assert row_id is not None
 
-        # Insert embedding into vec table
-        # sqlite-vec
+        # Insert embedding into vec table (float32)
+        # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
         embedding_blob = embedding.astype(np.float32).tobytes()
+        conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
         conn.execute("""
-            INSERT
+            INSERT INTO organization_embeddings (org_id, embedding)
             VALUES (?, ?)
         """, (row_id, embedding_blob))
 
+        # Insert scalar embedding if provided (int8)
+        if scalar_embedding is not None:
+            scalar_blob = scalar_embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+            conn.execute("""
+                INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (row_id, scalar_blob))
+
         conn.commit()
         return row_id
 
@@ -384,14 +723,16 @@ class OrganizationDatabase:
         records: list[CompanyRecord],
         embeddings: np.ndarray,
         batch_size: int = 1000,
+        scalar_embeddings: Optional[np.ndarray] = None,
     ) -> int:
         """
         Insert multiple organization records with embeddings.
 
         Args:
             records: List of organization records
-            embeddings: Matrix of embeddings (N x dim)
+            embeddings: Matrix of embeddings (N x dim) - float32
             batch_size: Commit batch size
+            scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)
 
         Returns:
             Number of records inserted
@@ -399,34 +740,75 @@ class OrganizationDatabase:
         conn = self._connect()
         count = 0
 
-        for record, embedding in zip(records, embeddings):
+        for i, (record, embedding) in enumerate(zip(records, embeddings)):
            record_json = json.dumps(record.record)
            name_normalized = _normalize_name(record.name)
 
-
-
-
-
-
-
-
-            record.
-
-
-
-
-
+            if self._is_v2:
+                # v2 schema: use FK IDs instead of TEXT columns
+                source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
+                entity_type_id = ORG_TYPE_NAME_TO_ID.get(record.entity_type.value, 17)  # 17 = unknown
+
+                # Resolve region to location_id if provided
+                region_id = None
+                if record.region:
+                    # Use locations database to resolve region
+                    locations_db = get_locations_database(db_path=self._db_path)
+                    region_id = locations_db.resolve_region_text(record.region)
+
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO organizations
+                    (name, name_normalized, source_id, source_identifier, region_id, entity_type_id, from_date, to_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    source_type_id,
+                    record.source_id,
+                    region_id,
+                    entity_type_id,
+                    record.from_date or "",
+                    record.to_date or "",
+                    record_json,
+                ))
+            else:
+                # v1 schema: use TEXT columns
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO organizations
+                    (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    record.source,
+                    record.source_id,
+                    record.region,
+                    record.entity_type.value,
+                    record.from_date or "",
+                    record.to_date or "",
+                    record_json,
+                ))
 
             row_id = cursor.lastrowid
             assert row_id is not None
 
-            # Insert embedding
+            # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
             embedding_blob = embedding.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO organization_embeddings (org_id, embedding)
                 VALUES (?, ?)
             """, (row_id, embedding_blob))
 
+            # Insert scalar embedding if provided (int8)
+            if scalar_embeddings is not None:
+                scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
+                conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+                conn.execute("""
+                    INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                    VALUES (?, vec_int8(?))
+                """, (row_id, scalar_blob))
+
             count += 1
 
             if count % batch_size == 0:
@@ -443,13 +825,15 @@ class OrganizationDatabase:
         source_filter: Optional[str] = None,
         query_text: Optional[str] = None,
         max_text_candidates: int = 5000,
+        rerank_min_candidates: int = 500,
     ) -> list[tuple[CompanyRecord, float]]:
         """
         Search for similar organizations using hybrid text + vector search.
 
-
+        Three-stage approach:
         1. If query_text provided, use SQL LIKE to find candidates containing search terms
         2. Use sqlite-vec for vector similarity ranking on filtered candidates
+        3. Apply prominence-based re-ranking to boost major companies (SEC filers, tickers)
 
         Args:
             query_embedding: Query embedding vector
@@ -457,9 +841,10 @@ class OrganizationDatabase:
             source_filter: Optional filter by source (gleif, sec_edgar, etc.)
             query_text: Optional query text for text-based pre-filtering
             max_text_candidates: Max candidates to keep after text filtering
+            rerank_min_candidates: Minimum candidates to fetch for re-ranking (default 500)
 
         Returns:
-            List of (CompanyRecord,
+            List of (CompanyRecord, adjusted_score) tuples sorted by prominence-adjusted score
         """
         start = time.time()
         self._connect()
@@ -469,10 +854,17 @@ class OrganizationDatabase:
         if query_norm == 0:
             return []
         query_normalized = query_embedding / query_norm
-
+
+        # Use int8 quantized query if scalar table is available (75% storage savings)
+        if self._has_scalar_table():
+            query_int8 = self._quantize_query(query_normalized)
+            query_blob = query_int8.tobytes()
+        else:
+            query_blob = query_normalized.astype(np.float32).tobytes()
 
         # Stage 1: Text-based pre-filtering (if query_text provided)
         candidate_ids: Optional[set[int]] = None
+        query_normalized_text = ""
         if query_text:
             query_normalized_text = _normalize_name(query_text)
             if query_normalized_text:
@@ -483,24 +875,168 @@ class OrganizationDatabase:
                 )
                 logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
 
-        # Stage 2: Vector search
+        # Stage 2: Vector search - fetch more candidates for re-ranking
         if candidate_ids is not None and len(candidate_ids) == 0:
             # No text matches, return empty
             return []
 
+        # Fetch enough candidates for prominence re-ranking to be effective
+        # Use at least rerank_min_candidates, or all text-filtered candidates if fewer
+        if candidate_ids is not None:
+            fetch_k = min(len(candidate_ids), max(rerank_min_candidates, top_k * 5))
+        else:
+            fetch_k = max(rerank_min_candidates, top_k * 5)
+
         if candidate_ids is not None:
             # Search within text-filtered candidates
             results = self._vector_search_filtered(
-                query_blob, candidate_ids,
+                query_blob, candidate_ids, fetch_k, source_filter
             )
         else:
             # Full vector search
-            results = self._vector_search_full(query_blob,
+            results = self._vector_search_full(query_blob, fetch_k, source_filter)
+
+        # Stage 3: Prominence-based re-ranking
+        if results and query_normalized_text:
+            results = self._apply_prominence_reranking(results, query_normalized_text, top_k)
+        else:
+            # No re-ranking, just trim to top_k
+            results = results[:top_k]
 
         elapsed = time.time() - start
         logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
         return results
 
+    def _calculate_prominence_boost(
+        self,
+        record: CompanyRecord,
+        query_normalized: str,
+        canon_sources: Optional[set[str]] = None,
+    ) -> float:
+        """
+        Calculate prominence boost for re-ranking search results.
+
+        Boosts scores based on signals that indicate a major/prominent company:
+        - Has ticker symbol (publicly traded)
+        - GLEIF source (has LEI)
+        - SEC source (vetted US filers)
+        - Wikidata source (Wikipedia-notable)
+        - Exact normalized name match
+
+        When canon_sources is provided (from a canonical group), boosts are
+        applied for ALL sources in the canon group, not just this record's source.
+
+        Args:
+            record: The company record to evaluate
+            query_normalized: Normalized query text for exact match check
+            canon_sources: Optional set of sources in this record's canonical group
+
+        Returns:
+            Boost value to add to embedding similarity (0.0 to ~0.21)
+        """
+        boost = 0.0
+
+        # Get all sources to consider (canon group or just this record)
+        sources_to_check = canon_sources or {record.source}
+
+        # Has ticker symbol = publicly traded major company
+        # Check if ANY record in canon group has ticker
+        if record.record.get("ticker") or (canon_sources and "sec_edgar" in canon_sources):
+            boost += 0.08
+
+        # Source-based boosts - accumulate for all sources in canon group
+        if "gleif" in sources_to_check:
+            boost += 0.05  # Has LEI = verified legal entity
+        if "sec_edgar" in sources_to_check:
+            boost += 0.03  # SEC filer
+        if "wikipedia" in sources_to_check:
+            boost += 0.02  # Wikipedia notable
+
+        # Exact normalized name match bonus
+        record_normalized = _normalize_name(record.name)
+        if query_normalized == record_normalized:
+            boost += 0.05
+
+        return boost
+
+    def _apply_prominence_reranking(
+        self,
+        results: list[tuple[CompanyRecord, float]],
+        query_normalized: str,
+        top_k: int,
+        similarity_weight: float = 0.3,
+    ) -> list[tuple[CompanyRecord, float]]:
+        """
+        Apply prominence-based re-ranking to search results with canon group awareness.
+
+        When records have been canonicalized, boosts are applied based on ALL sources
+        in the canonical group, not just the matched record's source.
+
+        Args:
+            results: List of (record, similarity) from vector search
+            query_normalized: Normalized query text
+            top_k: Number of results to return after re-ranking
+            similarity_weight: Weight for similarity score (0-1), lower = prominence matters more
+
+        Returns:
+            Re-ranked list of (record, adjusted_score) tuples
+        """
+        conn = self._conn
+        assert conn is not None
+
+        # Build canon_id -> sources mapping for all results that have canon_id
+        canon_sources_map: dict[int, set[str]] = {}
+        canon_ids = [
+            r.record.get("canon_id")
+            for r, _ in results
+            if r.record.get("canon_id") is not None
+        ]
+
+        if canon_ids:
+            # Fetch all sources for each canon_id in one query
+            unique_canon_ids = list(set(canon_ids))
+            placeholders = ",".join("?" * len(unique_canon_ids))
+            rows = conn.execute(f"""
+                SELECT canon_id, source
+                FROM organizations
+                WHERE canon_id IN ({placeholders})
+            """, unique_canon_ids).fetchall()
+
+            for row in rows:
+                canon_id = row["canon_id"]
+                canon_sources_map.setdefault(canon_id, set()).add(row["source"])
+
+        # Calculate boosted scores with canon group awareness
+        # Formula: adjusted = (similarity * weight) + boost
+        # With weight=0.3, a sim=0.65 SEC+ticker (boost=0.11) beats sim=0.75 no-boost
+        boosted_results: list[tuple[CompanyRecord, float, float, float]] = []
+        for record, similarity in results:
+            canon_id = record.record.get("canon_id")
+            # Get all sources in this record's canon group (if any)
+            canon_sources = canon_sources_map.get(canon_id) if canon_id else None
+
+            boost = self._calculate_prominence_boost(record, query_normalized, canon_sources)
+            adjusted_score = (similarity * similarity_weight) + boost
+            boosted_results.append((record, similarity, boost, adjusted_score))
+
+        # Sort by adjusted score (descending)
+        boosted_results.sort(key=lambda x: x[3], reverse=True)
+
+        # Log re-ranking details for top results
+        logger.debug(f"Prominence re-ranking for '{query_normalized}':")
+        for record, sim, boost, adj in boosted_results[:10]:
+            ticker = record.record.get("ticker", "")
+            ticker_str = f" ticker={ticker}" if ticker else ""
+            canon_id = record.record.get("canon_id")
+            canon_str = f" canon={canon_id}" if canon_id else ""
+            logger.debug(
+                f" {record.name}: sim={sim:.3f} + boost={boost:.3f} = {adj:.3f} "
+                f"[{record.source}{ticker_str}{canon_str}]"
+            )
+
+        # Return top_k with adjusted scores
+        return [(r, adj) for r, _, _, adj in boosted_results[:top_k]]
+
     def _text_filter_candidates(
         self,
         query_normalized: str,
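The re-ranking arithmetic in this hunk is deliberately simple: `adjusted = similarity * 0.3 + boost`, with boosts of 0.08 (ticker), 0.05 (gleif), 0.03 (sec_edgar), 0.02 (wikipedia) and 0.05 for an exact normalized-name match. The worked numbers below (hypothetical similarities) show why a slightly weaker embedding match from an SEC filer with a ticker outranks a stronger match with no prominence signals, exactly as the inline comment claims:

```python
SIMILARITY_WEIGHT = 0.3

def adjusted(similarity: float, boost: float) -> float:
    # Same formula as _apply_prominence_reranking above.
    return similarity * SIMILARITY_WEIGHT + boost

# sim=0.65, SEC filer with ticker: boost = 0.08 (ticker) + 0.03 (sec_edgar) = 0.11
sec_filer = adjusted(0.65, 0.11)   # 0.305
# sim=0.75, no prominence signals: boost = 0.0
plain_match = adjusted(0.75, 0.0)  # 0.225
assert sec_filer > plain_match
```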
@@ -554,6 +1090,19 @@ class OrganizationDatabase:
         cursor = conn.execute(query, params)
         return set(row["id"] for row in cursor)
 
+    def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
+        """Quantize query embedding to int8 for scalar search."""
+        return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
+
+    def _has_scalar_table(self) -> bool:
+        """Check if scalar embedding table exists."""
+        conn = self._conn
+        assert conn is not None
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='organization_embeddings_scalar'"
+        )
+        return cursor.fetchone() is not None
+
     def _vector_search_filtered(
         self,
         query_blob: bytes,
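`_quantize_query` maps a unit-normalized float32 embedding onto the int8 range by scaling by 127 and clipping, which is what makes the `organization_embeddings_scalar` table one quarter the size of the float32 table (768 bytes versus 3072 bytes per 768-dimension vector). A quick standalone check of that mapping and the storage figure, using a toy vector rather than package data:

```python
import numpy as np

def quantize(embedding: np.ndarray) -> np.ndarray:
    """Same scheme as _quantize_query: scale a unit vector by 127, clip to int8."""
    return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)

rng = np.random.default_rng(0)
vec = rng.normal(size=768).astype(np.float32)
vec /= np.linalg.norm(vec)           # unit-normalize, as the search path does

q = quantize(vec)
print(vec.nbytes)                    # 3072 bytes as float32
print(q.nbytes)                      # 768 bytes as int8 -> 75% smaller
```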
@@ -561,7 +1110,7 @@ class OrganizationDatabase:
         top_k: int,
         source_filter: Optional[str],
     ) -> list[tuple[CompanyRecord, float]]:
-        """Vector search within a filtered set of candidates."""
+        """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None
 
@@ -571,18 +1120,29 @@ class OrganizationDatabase:
         # Build IN clause for candidate IDs
         placeholders = ",".join("?" * len(candidate_ids))
 
-        #
-
-
-
-
-
-
-
-
-
-
-
+        # Use scalar embedding table if available (75% storage reduction)
+        if self._has_scalar_table():
+            # Query uses int8 embeddings with vec_int8() wrapper
+            query = f"""
+                SELECT
+                    e.org_id,
+                    vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                FROM organization_embeddings_scalar e
+                WHERE e.org_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """
+        else:
+            # Fall back to float32 embeddings
+            query = f"""
+                SELECT
+                    e.org_id,
+                    vec_distance_cosine(e.embedding, ?) as distance
+                FROM organization_embeddings e
+                WHERE e.org_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """
 
         cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
 
@@ -609,33 +1169,58 @@ class OrganizationDatabase:
         top_k: int,
         source_filter: Optional[str],
     ) -> list[tuple[CompanyRecord, float]]:
-        """Full vector search without text pre-filtering."""
+        """Full vector search without text pre-filtering using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None
 
+        # Use scalar embedding table if available (75% storage reduction)
+        use_scalar = self._has_scalar_table()
+
         # KNN search with sqlite-vec
         if source_filter:
             # Need to join with organizations table for source filter
-
-
-
-
-
-
-
-
-
-
+            if use_scalar:
+                query = """
+                    SELECT
+                        e.org_id,
+                        vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                    FROM organization_embeddings_scalar e
+                    JOIN organizations c ON e.org_id = c.id
+                    WHERE c.source = ?
+                    ORDER BY distance
+                    LIMIT ?
+                """
+            else:
+                query = """
+                    SELECT
+                        e.org_id,
+                        vec_distance_cosine(e.embedding, ?) as distance
+                    FROM organization_embeddings e
+                    JOIN organizations c ON e.org_id = c.id
+                    WHERE c.source = ?
+                    ORDER BY distance
+                    LIMIT ?
+                """
             cursor = conn.execute(query, (query_blob, source_filter, top_k))
         else:
-
-
-
-
-
-
-
-
+            if use_scalar:
+                query = """
+                    SELECT
+                        org_id,
+                        vec_distance_cosine(embedding, vec_int8(?)) as distance
+                    FROM organization_embeddings_scalar
+                    ORDER BY distance
+                    LIMIT ?
+                """
+            else:
+                query = """
+                    SELECT
+                        org_id,
+                        vec_distance_cosine(embedding, ?) as distance
+                    FROM organization_embeddings
+                    ORDER BY distance
+                    LIMIT ?
+                """
            cursor = conn.execute(query, (query_blob, top_k))
 
         results = []
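Both branches above issue the same shape of query: a cosine distance computed by `vec_distance_cosine`, ordered ascending, with `vec_int8(?)` wrapping the parameter when the int8 table is used. A minimal end-to-end sketch against a throwaway in-memory database (assumes the sqlite-vec Python package is installed and that the interpreter allows loadable extensions, as the package itself requires; the table and vector names here are illustrative):

```python
import sqlite3
import numpy as np
import sqlite_vec

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)

conn.execute(
    "CREATE VIRTUAL TABLE demo_embeddings USING vec0(org_id INTEGER PRIMARY KEY, embedding float[4])"
)

vectors = {1: [1.0, 0.0, 0.0, 0.0], 2: [0.0, 1.0, 0.0, 0.0], 3: [0.9, 0.1, 0.0, 0.0]}
for org_id, v in vectors.items():
    conn.execute(
        "INSERT INTO demo_embeddings (org_id, embedding) VALUES (?, ?)",
        (org_id, np.asarray(v, dtype=np.float32).tobytes()),
    )

query = np.asarray([1.0, 0.0, 0.0, 0.0], dtype=np.float32).tobytes()
rows = conn.execute(
    """
    SELECT org_id, vec_distance_cosine(embedding, ?) AS distance
    FROM demo_embeddings
    ORDER BY distance
    LIMIT 2
    """,
    (query,),
).fetchall()
print(rows)  # org 1 first (distance ~0), then org 3
```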
@@ -651,24 +1236,38 @@ class OrganizationDatabase:
         return results
 
     def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
-        """Get an organization record by ID."""
+        """Get an organization record by ID, including db_id and canon_id in record dict."""
         conn = self._conn
         assert conn is not None
 
-
-
-
-
+        if self._is_v2:
+            # v2 schema: use view for text fields, but need record from base table
+            cursor = conn.execute("""
+                SELECT v.id, v.name, v.source, v.source_identifier, v.region, v.entity_type, v.canon_id, o.record
+                FROM organizations_view v
+                JOIN organizations o ON v.id = o.id
+                WHERE v.id = ?
+            """, (org_id,))
+        else:
+            cursor = conn.execute("""
+                SELECT id, name, source, source_id, region, entity_type, record, canon_id
+                FROM organizations WHERE id = ?
+            """, (org_id,))
 
         row = cursor.fetchone()
         if row:
+            record_data = json.loads(row["record"])
+            # Add db_id and canon_id to record dict for canon-aware search
+            record_data["db_id"] = row["id"]
+            record_data["canon_id"] = row["canon_id"]
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return CompanyRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row[
+                source_id=row[source_id_field],
                 region=row["region"] or "",
                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
-                record=
+                record=record_data,
             )
         return None
 
@@ -676,24 +1275,56 @@ class OrganizationDatabase:
         """Get an organization record by source and source_id."""
         conn = self._connect()
 
-
-
-
-
-
+        if self._is_v2:
+            # v2 schema: join view with base table for record
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                FROM organizations_view v
+                JOIN organizations o ON v.id = o.id
+                WHERE o.source_id = ? AND o.source_identifier = ?
+            """, (source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, region, entity_type, record
+                FROM organizations
+                WHERE source = ? AND source_id = ?
+            """, (source, source_id))
 
         row = cursor.fetchone()
         if row:
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return CompanyRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row[
+                source_id=row[source_id_field],
                 region=row["region"] or "",
                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                 record=json.loads(row["record"]),
             )
         return None
 
+    def get_id_by_source_id(self, source: str, source_id: str) -> Optional[int]:
+        """Get the internal database ID for an organization by source and source_id."""
+        conn = self._connect()
+
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                SELECT id FROM organizations
+                WHERE source_id = ? AND source_identifier = ?
+            """, (source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                SELECT id FROM organizations
+                WHERE source = ? AND source_id = ?
+            """, (source, source_id))
+
+        row = cursor.fetchone()
+        if row:
+            return row["id"]
+        return None
+
     def get_stats(self) -> DatabaseStats:
         """Get database statistics."""
         conn = self._connect()
@@ -702,8 +1333,18 @@ class OrganizationDatabase:
         cursor = conn.execute("SELECT COUNT(*) FROM organizations")
         total = cursor.fetchone()[0]
 
-        # Count by source
-
+        # Count by source - handle both v1 and v2 schema
+        if self._is_v2:
+            # v2 schema - join with source_types
+            cursor = conn.execute("""
+                SELECT st.name as source, COUNT(*) as cnt
+                FROM organizations o
+                JOIN source_types st ON o.source_id = st.id
+                GROUP BY o.source_id
+            """)
+        else:
+            # v1 schema
+            cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
         by_source = {row["source"]: row["cnt"] for row in cursor}
 
         # Database file size
@@ -716,44 +1357,350 @@ class OrganizationDatabase:
|
|
|
716
1357
|
database_size_bytes=db_size,
|
|
717
1358
|
)
|
|
718
1359
|
|
|
1360
|
+
def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
|
|
1361
|
+
"""
|
|
1362
|
+
Get all source_ids from the organizations table.
|
|
1363
|
+
|
|
1364
|
+
Useful for resume operations to skip already-imported records.
|
|
1365
|
+
|
|
1366
|
+
Args:
|
|
1367
|
+
source: Optional source filter (e.g., "wikidata" for Wikidata orgs)
|
|
1368
|
+
|
|
1369
|
+
Returns:
|
|
1370
|
+
Set of source_id strings (e.g., Q codes for Wikidata)
|
|
1371
|
+
"""
|
|
1372
|
+
conn = self._connect()
|
|
1373
|
+
|
|
1374
|
+
if self._is_v2:
|
|
1375
|
+
id_col = "source_identifier"
|
|
1376
|
+
if source:
|
|
1377
|
+
source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
|
|
1378
|
+
cursor = conn.execute(
|
|
1379
|
+
f"SELECT DISTINCT {id_col} FROM organizations WHERE source_id = ?",
|
|
1380
|
+
(source_type_id,)
|
|
1381
|
+
)
|
|
1382
|
+
else:
|
|
1383
|
+
cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM organizations")
|
|
1384
|
+
else:
|
|
1385
|
+
if source:
|
|
1386
|
+
cursor = conn.execute(
|
|
1387
|
+
"SELECT DISTINCT source_id FROM organizations WHERE source = ?",
|
|
1388
|
+
(source,)
|
|
1389
|
+
)
|
|
1390
|
+
else:
|
|
1391
|
+
cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")
|
|
1392
|
+
|
|
1393
|
+
return {row[0] for row in cursor}
|
|
1394
|
+
|
|
719
1395
|
def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
|
|
720
1396
|
"""Iterate over all records, optionally filtered by source."""
|
|
721
1397
|
conn = self._connect()
|
|
722
1398
|
|
|
723
|
-
if
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
1399
|
+
if self._is_v2:
|
|
1400
|
+
if source:
|
|
1401
|
+
source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
|
|
1402
|
+
cursor = conn.execute("""
|
|
1403
|
+
SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
|
|
1404
|
+
FROM organizations_view v
|
|
1405
|
+
JOIN organizations o ON v.id = o.id
|
|
1406
|
+
WHERE o.source_id = ?
|
|
1407
|
+
""", (source_type_id,))
|
|
1408
|
+
else:
|
|
1409
|
+
cursor = conn.execute("""
|
|
1410
|
+
SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
|
|
1411
|
+
FROM organizations_view v
|
|
1412
|
+
JOIN organizations o ON v.id = o.id
|
|
1413
|
+
""")
|
|
1414
|
+
for row in cursor:
|
|
1415
|
+
yield CompanyRecord(
|
|
1416
|
+
name=row["name"],
|
|
1417
|
+
source=row["source"],
|
|
1418
|
+
source_id=row["source_identifier"],
|
|
1419
|
+
region=row["region"] or "",
|
|
1420
|
+
entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
|
|
1421
|
+
record=json.loads(row["record"]),
|
|
1422
|
+
)
|
|
729
1423
|
else:
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
1424
|
+
if source:
|
|
1425
|
+
cursor = conn.execute("""
|
|
1426
|
+
SELECT name, source, source_id, region, entity_type, record
|
|
1427
|
+
FROM organizations
|
|
1428
|
+
WHERE source = ?
|
|
1429
|
+
""", (source,))
|
|
1430
|
+
else:
|
|
1431
|
+
cursor = conn.execute("""
|
|
1432
|
+
SELECT name, source, source_id, region, entity_type, record
|
|
1433
|
+
FROM organizations
|
|
1434
|
+
""")
|
|
1435
|
+
for row in cursor:
|
|
1436
|
+
yield CompanyRecord(
|
|
1437
|
+
name=row["name"],
|
|
1438
|
+
source=row["source"],
|
|
1439
|
+
source_id=row["source_id"],
|
|
1440
|
+
region=row["region"] or "",
|
|
1441
|
+
entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
|
|
1442
|
+
record=json.loads(row["record"]),
|
|
1443
|
+
)
|
|
744
1444
|
|
|
745
|
-
def
|
|
1445
|
+
def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
|
|
746
1446
|
"""
|
|
747
|
-
|
|
1447
|
+
Canonicalize all organizations by linking equivalent records.
|
|
748
1448
|
|
|
749
|
-
|
|
750
|
-
|
|
1449
|
+
Records are considered equivalent if they match by:
|
|
1450
|
+
1. Same LEI (GLEIF source_id or Wikidata P1278) - globally unique, no region check
|
|
1451
|
+
2. Same ticker symbol - globally unique, no region check
|
|
1452
|
+
3. Same CIK - globally unique, no region check
|
|
1453
|
+
4. Same normalized name AND same normalized region
|
|
1454
|
+
5. Name match with suffix expansion AND same region
|
|
1455
|
+
|
|
1456
|
+
Region normalization uses pycountry to handle:
|
|
1457
|
+
- Country codes/names (GB, United Kingdom, Great Britain -> GB)
|
|
1458
|
+
- US state codes/names (CA, California -> US)
|
|
1459
|
+
- Common aliases (UK -> GB, USA -> US)
|
|
1460
|
+
|
|
1461
|
+
For each group of equivalent records, the highest-priority source
|
|
1462
|
+
(gleif > sec_edgar > companies_house > wikipedia) becomes canonical.
|
|
751
1463
|
|
|
752
1464
|
Args:
|
|
753
|
-
batch_size:
|
|
1465
|
+
batch_size: Commit batch size for updates
|
|
754
1466
|
|
|
755
1467
|
Returns:
|
|
756
|
-
|
|
1468
|
+
Dict with stats: total_records, groups_found, records_updated
|
|
1469
|
+
"""
|
|
1470
|
+
conn = self._connect()
|
|
1471
|
+
logger.info("Starting canonicalization...")
|
|
1472
|
+
|
|
1473
|
+
# Phase 1: Load all organization data and build indexes
|
|
1474
|
+
logger.info("Phase 1: Building indexes...")
|
|
1475
|
+
|
|
1476
|
+
lei_index: dict[str, list[int]] = {}
|
|
1477
|
+
ticker_index: dict[str, list[int]] = {}
|
|
1478
|
+
cik_index: dict[str, list[int]] = {}
|
|
1479
|
+
# Name indexes now keyed by (normalized_name, normalized_region)
|
|
1480
|
+
# Region-less matching only applies for identifier-based matching
|
|
1481
|
+
name_region_index: dict[tuple[str, str], list[int]] = {}
|
|
1482
|
+
expanded_name_region_index: dict[tuple[str, str], list[int]] = {}
|
|
1483
|
+
|
|
1484
|
+
sources: dict[int, str] = {} # org_id -> source
|
|
1485
|
+
all_org_ids: list[int] = []
|
|
1486
|
+
|
|
1487
|
+
if self._is_v2:
|
|
1488
|
+
cursor = conn.execute("""
|
|
1489
|
+
SELECT o.id, s.name as source, o.source_identifier as source_id, o.name, l.name as region, o.record
|
|
1490
|
+
FROM organizations o
|
|
1491
|
+
JOIN source_types s ON o.source_id = s.id
|
|
1492
|
+
LEFT JOIN locations l ON o.region_id = l.id
|
|
1493
|
+
""")
|
|
1494
|
+
else:
|
|
1495
|
+
cursor = conn.execute("""
|
|
1496
|
+
SELECT id, source, source_id, name, region, record
|
|
1497
|
+
FROM organizations
|
|
1498
|
+
""")
|
|
1499
|
+
|
|
1500
|
+
count = 0
|
|
1501
|
+
for row in cursor:
|
|
1502
|
+
org_id = row["id"]
|
|
1503
|
+
source = row["source"]
|
|
1504
|
+
name = row["name"]
|
|
1505
|
+
region = row["region"] or ""
|
|
1506
|
+
record = json.loads(row["record"])
|
|
1507
|
+
|
|
1508
|
+
all_org_ids.append(org_id)
|
|
1509
|
+
sources[org_id] = source
|
|
1510
|
+
|
|
1511
|
+
# Index by LEI (GLEIF source_id or Wikidata's P1278)
|
|
1512
|
+
# LEI is globally unique - no region check needed
|
|
1513
|
+
if source == "gleif":
|
|
1514
|
+
lei = row["source_id"]
|
|
1515
|
+
else:
|
|
1516
|
+
lei = record.get("lei")
|
|
1517
|
+
if lei:
|
|
1518
|
+
lei_index.setdefault(lei.upper(), []).append(org_id)
|
|
1519
|
+
|
|
1520
|
+
# Index by ticker - globally unique, no region check
|
|
1521
|
+
ticker = record.get("ticker")
|
|
1522
|
+
if ticker:
|
|
1523
|
+
ticker_index.setdefault(ticker.upper(), []).append(org_id)
|
|
1524
|
+
|
|
1525
|
+
# Index by CIK - globally unique, no region check
|
|
1526
|
+
if source == "sec_edgar":
|
|
1527
|
+
cik = row["source_id"]
|
|
1528
|
+
else:
|
|
1529
|
+
cik = record.get("cik")
|
|
1530
|
+
if cik:
|
|
1531
|
+
cik_index.setdefault(str(cik), []).append(org_id)
|
|
1532
|
+
|
|
1533
|
+
# Index by (normalized_name, normalized_region)
|
|
1534
|
+
# Same name in different regions = different legal entities
|
|
1535
|
+
norm_name = _normalize_for_canon(name)
|
|
1536
|
+
norm_region = _normalize_region(region)
|
|
1537
|
+
if norm_name:
|
|
1538
|
+
key = (norm_name, norm_region)
|
|
1539
|
+
name_region_index.setdefault(key, []).append(org_id)
|
|
1540
|
+
|
|
1541
|
+
# Index by (expanded_name, normalized_region)
|
|
1542
|
+
expanded_name = _expand_suffix(name)
|
|
1543
|
+
if expanded_name and expanded_name != norm_name:
|
|
1544
|
+
key = (expanded_name, norm_region)
|
|
1545
|
+
expanded_name_region_index.setdefault(key, []).append(org_id)
|
|
1546
|
+
|
|
1547
|
+
count += 1
|
|
1548
|
+
if count % 100000 == 0:
|
|
1549
|
+
logger.info(f" Indexed {count} organizations...")
|
|
1550
|
+
|
|
1551
|
+
logger.info(f" Indexed {count} organizations total")
|
|
1552
|
+
logger.info(f" LEI index: {len(lei_index)} unique LEIs")
|
|
1553
|
+
logger.info(f" Ticker index: {len(ticker_index)} unique tickers")
|
|
1554
|
+
logger.info(f" CIK index: {len(cik_index)} unique CIKs")
|
|
1555
|
+
logger.info(f" Name+region index: {len(name_region_index)} unique (name, region) pairs")
|
|
1556
|
+
logger.info(f" Expanded name+region index: {len(expanded_name_region_index)} unique pairs")
|
|
1557
|
+
|
|
1558
|
+
# Phase 2: Build equivalence groups using Union-Find
|
|
1559
|
+
logger.info("Phase 2: Building equivalence groups...")
|
|
1560
|
+
|
|
1561
|
+
uf = UnionFind(all_org_ids)
|
|
1562
|
+
|
|
1563
|
+
# Merge by LEI (globally unique identifier)
|
|
1564
|
+
for _lei, ids in lei_index.items():
|
|
1565
|
+
for i in range(1, len(ids)):
|
|
1566
|
+
uf.union(ids[0], ids[i])
|
|
1567
|
+
|
|
1568
|
+
# Merge by ticker (globally unique identifier)
|
|
1569
|
+
for _ticker, ids in ticker_index.items():
|
|
1570
|
+
for i in range(1, len(ids)):
|
|
1571
|
+
uf.union(ids[0], ids[i])
|
|
1572
|
+
|
|
1573
|
+
# Merge by CIK (globally unique identifier)
|
|
1574
|
+
for _cik, ids in cik_index.items():
|
|
1575
|
+
for i in range(1, len(ids)):
|
|
1576
|
+
uf.union(ids[0], ids[i])
|
|
1577
|
+
|
|
1578
|
+
# Merge by (normalized_name, normalized_region)
|
|
1579
|
+
for _name_region, ids in name_region_index.items():
|
|
1580
|
+
for i in range(1, len(ids)):
|
|
1581
|
+
uf.union(ids[0], ids[i])
|
|
1582
|
+
|
|
1583
|
+
# Merge by (expanded_name, normalized_region)
|
|
1584
|
+
# This connects "Amazon Ltd" with "Amazon Limited" in same region
|
|
1585
|
+
for key, expanded_ids in expanded_name_region_index.items():
|
|
1586
|
+
# Find org_ids with the expanded form as their normalized name in same region
|
|
1587
|
+
if key in name_region_index:
|
|
1588
|
+
# Link first expanded_id to first name_id
|
|
1589
|
+
uf.union(expanded_ids[0], name_region_index[key][0])
|
|
1590
|
+
|
|
1591
|
+
groups = uf.groups()
|
|
1592
|
+
logger.info(f" Found {len(groups)} equivalence groups")
|
|
1593
|
+
|
|
1594
|
+
# Count groups with multiple records
|
|
1595
|
+
multi_record_groups = sum(1 for ids in groups.values() if len(ids) > 1)
|
|
1596
|
+
logger.info(f" Groups with multiple records: {multi_record_groups}")
|
|
1597
|
+
|
|
1598
|
+
# Phase 3: Select canonical record for each group and update database
|
|
1599
|
+
logger.info("Phase 3: Updating database...")
|
|
1600
|
+
|
|
1601
|
+
updated_count = 0
|
|
1602
|
+
batch_updates: list[tuple[int, int, int]] = [] # (org_id, canon_id, canon_size)
|
|
1603
|
+
|
|
1604
|
+
for _root, group_ids in groups.items():
|
|
1605
|
+
if len(group_ids) == 1:
|
|
1606
|
+
# Single record - canonical to itself
|
|
1607
|
+
batch_updates.append((group_ids[0], group_ids[0], 1))
|
|
1608
|
+
else:
|
|
1609
|
+
# Multiple records - find highest priority source
|
|
1610
|
+
best_id = min(
|
|
1611
|
+
group_ids,
|
|
1612
|
+
key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid)
|
|
1613
|
+
)
|
|
1614
|
+
group_size = len(group_ids)
|
|
1615
|
+
|
|
1616
|
+
# All records in group point to the best one
|
|
1617
|
+
for oid in group_ids:
|
|
1618
|
+
# canon_size is only set on the canonical record
|
|
1619
|
+
size = group_size if oid == best_id else 1
|
|
1620
|
+
batch_updates.append((oid, best_id, size))
|
|
1621
|
+
|
|
1622
|
+
# Commit batch
|
|
1623
|
+
if len(batch_updates) >= batch_size:
|
|
1624
|
+
self._apply_canon_updates(batch_updates)
|
|
1625
|
+
updated_count += len(batch_updates)
|
|
1626
|
+
logger.info(f" Updated {updated_count} records...")
|
|
1627
|
+
batch_updates = []
|
|
1628
|
+
|
|
1629
|
+
# Final batch
|
|
1630
|
+
if batch_updates:
|
|
1631
|
+
self._apply_canon_updates(batch_updates)
|
|
1632
|
+
updated_count += len(batch_updates)
|
|
1633
|
+
|
|
1634
|
+
conn.commit()
|
|
1635
|
+
logger.info(f"Canonicalization complete: {updated_count} records updated, {multi_record_groups} multi-record groups")
|
|
1636
|
+
|
|
1637
|
+
return {
|
|
1638
|
+
"total_records": count,
|
|
1639
|
+
"groups_found": len(groups),
|
|
1640
|
+
"multi_record_groups": multi_record_groups,
|
|
1641
|
+
"records_updated": updated_count,
|
|
1642
|
+
}
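The canonical record of each multi-record group is chosen with `min()` over the key `(SOURCE_PRIORITY[source], org_id)`. `SOURCE_PRIORITY` itself is defined elsewhere in store.py; the sketch below uses an assumed ranking (lower number = preferred) purely to illustrate the selection and tie-breaking behaviour.

```python
# Assumed priority table for illustration only; the real SOURCE_PRIORITY
# lives elsewhere in store.py and may rank sources differently.
SOURCE_PRIORITY = {"gleif": 0, "sec_edgar": 1, "companies_house": 2, "wikidata": 3}

def pick_canonical(group_ids: list[int], sources: dict[int, str]) -> int:
    """Return the org_id whose source has the best (lowest) priority.

    Ties fall back to the smallest org_id, matching the (priority, oid)
    key used in the diff above.
    """
    return min(group_ids, key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid))

# Example: a GLEIF record wins over Wikidata records even with a larger ID.
sources = {10: "wikidata", 42: "gleif", 7: "wikidata"}
assert pick_canonical([10, 42, 7], sources) == 42
```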
|
|
1643
|
+
|
|
1644
|
+
def _apply_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
|
|
1645
|
+
"""Apply batch of canon updates: (org_id, canon_id, canon_size)."""
|
|
1646
|
+
conn = self._conn
|
|
1647
|
+
assert conn is not None
|
|
1648
|
+
|
|
1649
|
+
for org_id, canon_id, canon_size in updates:
|
|
1650
|
+
conn.execute(
|
|
1651
|
+
"UPDATE organizations SET canon_id = ?, canon_size = ? WHERE id = ?",
|
|
1652
|
+
(canon_id, canon_size, org_id)
|
|
1653
|
+
)
|
|
1654
|
+
|
|
1655
|
+
conn.commit()
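A side note on the per-row loop in `_apply_canon_updates`: with SQLite the same batch can also be flushed in a single `executemany` call, which tends to be faster for large batches. This is an optional, functionally equivalent variant, not what the package ships:

```python
import sqlite3

def apply_canon_updates(conn: sqlite3.Connection, updates: list[tuple[int, int, int]]) -> None:
    """Apply (org_id, canon_id, canon_size) updates in one executemany call.

    Parameters are reordered to match the SET ... WHERE clause used above.
    """
    conn.executemany(
        "UPDATE organizations SET canon_id = ?, canon_size = ? WHERE id = ?",
        [(canon_id, canon_size, org_id) for org_id, canon_id, canon_size in updates],
    )
    conn.commit()
```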
|
|
1656
|
+
|
|
1657
|
+
def get_canon_stats(self) -> dict[str, int]:
|
|
1658
|
+
"""Get statistics about canonicalization status."""
|
|
1659
|
+
conn = self._connect()
|
|
1660
|
+
|
|
1661
|
+
# Total records
|
|
1662
|
+
cursor = conn.execute("SELECT COUNT(*) FROM organizations")
|
|
1663
|
+
total = cursor.fetchone()[0]
|
|
1664
|
+
|
|
1665
|
+
# Records with canon_id set
|
|
1666
|
+
cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_id IS NOT NULL")
|
|
1667
|
+
canonicalized = cursor.fetchone()[0]
|
|
1668
|
+
|
|
1669
|
+
# Number of canonical groups (unique canon_ids)
|
|
1670
|
+
cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM organizations WHERE canon_id IS NOT NULL")
|
|
1671
|
+
groups = cursor.fetchone()[0]
|
|
1672
|
+
|
|
1673
|
+
# Multi-record groups (canon_size > 1)
|
|
1674
|
+
cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_size > 1")
|
|
1675
|
+
multi_record_groups = cursor.fetchone()[0]
|
|
1676
|
+
|
|
1677
|
+
# Records in multi-record groups
|
|
1678
|
+
cursor = conn.execute("""
|
|
1679
|
+
SELECT COUNT(*) FROM organizations o1
|
|
1680
|
+
WHERE EXISTS (SELECT 1 FROM organizations o2 WHERE o2.id = o1.canon_id AND o2.canon_size > 1)
|
|
1681
|
+
""")
|
|
1682
|
+
records_in_multi = cursor.fetchone()[0]
|
|
1683
|
+
|
|
1684
|
+
return {
|
|
1685
|
+
"total_records": total,
|
|
1686
|
+
"canonicalized_records": canonicalized,
|
|
1687
|
+
"canonical_groups": groups,
|
|
1688
|
+
"multi_record_groups": multi_record_groups,
|
|
1689
|
+
"records_in_multi_groups": records_in_multi,
|
|
1690
|
+
}
|
|
1691
|
+
|
|
1692
|
+
def migrate_name_normalized(self, batch_size: int = 50000) -> int:
|
|
1693
|
+
"""
|
|
1694
|
+
Populate the name_normalized column for all records.
|
|
1695
|
+
|
|
1696
|
+
This is a one-time migration for databases that don't have
|
|
1697
|
+
normalized names populated.
|
|
1698
|
+
|
|
1699
|
+
Args:
|
|
1700
|
+
batch_size: Number of records to process per batch
|
|
1701
|
+
|
|
1702
|
+
Returns:
|
|
1703
|
+
Number of records updated
|
|
757
1704
|
"""
|
|
758
1705
|
conn = self._connect()
|
|
759
1706
|
|
|
@@ -867,8 +1814,9 @@ class OrganizationDatabase:
         assert conn is not None

         for org_id, embedding_blob in batch:
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO organization_embeddings (org_id, embedding)
                 VALUES (?, ?)
             """, (org_id, embedding_blob))

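The `DELETE` added before each embedding `INSERT` is the upsert idiom for sqlite-vec `vec0` virtual tables, which (as a later comment in this diff notes) do not support `INSERT OR REPLACE`. A self-contained sketch of the pattern, with an assumed in-memory database and a 4-dimension table for brevity:

```python
import sqlite3

import numpy as np
import sqlite_vec

def upsert_embedding(conn: sqlite3.Connection, org_id: int, embedding: np.ndarray) -> None:
    """Delete-then-insert 'upsert' into a vec0 virtual table."""
    blob = embedding.astype(np.float32).tobytes()
    conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
    conn.execute(
        "INSERT INTO organization_embeddings (org_id, embedding) VALUES (?, ?)",
        (org_id, blob),
    )

# Assumed setup for illustration only.
conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
conn.execute(
    "CREATE VIRTUAL TABLE organization_embeddings USING vec0("
    "org_id INTEGER PRIMARY KEY, embedding float[4])"
)
upsert_embedding(conn, 1, np.array([0.1, 0.2, 0.3, 0.4]))
upsert_embedding(conn, 1, np.array([0.4, 0.3, 0.2, 0.1]))  # replaces the first row
```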
@@ -878,17 +1826,32 @@ class OrganizationDatabase:
         """Delete all records from a specific source."""
         conn = self._connect()

-
-
-
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            # First get IDs to delete from vec table
+            cursor = conn.execute("SELECT id FROM organizations WHERE source_id = ?", (source_type_id,))
+            ids_to_delete = [row["id"] for row in cursor]
+
+            # Delete from vec table
+            if ids_to_delete:
+                placeholders = ",".join("?" * len(ids_to_delete))
+                conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+
+            # Delete from main table
+            cursor = conn.execute("DELETE FROM organizations WHERE source_id = ?", (source_type_id,))
+        else:
+            # First get IDs to delete from vec table
+            cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
+            ids_to_delete = [row["id"] for row in cursor]
+
+            # Delete from vec table
+            if ids_to_delete:
+                placeholders = ",".join("?" * len(ids_to_delete))
+                conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)

-
-
-            placeholders = ",".join("?" * len(ids_to_delete))
-            conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+            # Delete from main table
+            cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))

-        # Delete from main table
-        cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
         deleted = cursor.rowcount

         conn.commit()
@@ -1107,8 +2070,9 @@ class OrganizationDatabase:

         for org_id, embedding in zip(org_ids, embeddings):
             embedding_blob = embedding.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
             conn.execute("""
-                INSERT
+                INSERT INTO organization_embeddings (org_id, embedding)
                 VALUES (?, ?)
             """, (org_id, embedding_blob))
             count += 1
@@ -1116,6 +2080,287 @@ class OrganizationDatabase:
|
|
|
1116
2080
|
conn.commit()
|
|
1117
2081
|
return count
|
|
1118
2082
|
|
|
2083
|
+
def ensure_scalar_table_exists(self) -> None:
|
|
2084
|
+
"""Create scalar embedding table if it doesn't exist."""
|
|
2085
|
+
conn = self._connect()
|
|
2086
|
+
conn.execute(f"""
|
|
2087
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
|
|
2088
|
+
org_id INTEGER PRIMARY KEY,
|
|
2089
|
+
embedding int8[{self._embedding_dim}]
|
|
2090
|
+
)
|
|
2091
|
+
""")
|
|
2092
|
+
conn.commit()
|
|
2093
|
+
logger.info("Ensured organization_embeddings_scalar table exists")
|
|
2094
|
+
|
|
2095
|
+
def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
|
|
2096
|
+
"""
|
|
2097
|
+
Yield batches of org IDs that have float32 but missing scalar embeddings.
|
|
2098
|
+
|
|
2099
|
+
Args:
|
|
2100
|
+
batch_size: Number of IDs per batch
|
|
2101
|
+
|
|
2102
|
+
Yields:
|
|
2103
|
+
Lists of org_ids needing scalar embeddings
|
|
2104
|
+
"""
|
|
2105
|
+
conn = self._connect()
|
|
2106
|
+
|
|
2107
|
+
# Ensure scalar table exists before querying
|
|
2108
|
+
self.ensure_scalar_table_exists()
|
|
2109
|
+
|
|
2110
|
+
last_id = 0
|
|
2111
|
+
while True:
|
|
2112
|
+
cursor = conn.execute("""
|
|
2113
|
+
SELECT e.org_id FROM organization_embeddings e
|
|
2114
|
+
LEFT JOIN organization_embeddings_scalar s ON e.org_id = s.org_id
|
|
2115
|
+
WHERE s.org_id IS NULL AND e.org_id > ?
|
|
2116
|
+
ORDER BY e.org_id
|
|
2117
|
+
LIMIT ?
|
|
2118
|
+
""", (last_id, batch_size))
|
|
2119
|
+
|
|
2120
|
+
rows = cursor.fetchall()
|
|
2121
|
+
if not rows:
|
|
2122
|
+
break
|
|
2123
|
+
|
|
2124
|
+
ids = [row["org_id"] for row in rows]
|
|
2125
|
+
yield ids
|
|
2126
|
+
last_id = ids[-1]
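The backfill iterators page with `WHERE id > ? ORDER BY id LIMIT ?` rather than `OFFSET`, so each batch costs the same no matter how deep into the table it starts (keyset pagination). A generic version of the pattern, with illustrative table and column parameters that are not part of the package API:

```python
import sqlite3
from typing import Iterator

def iter_ids(conn: sqlite3.Connection, table: str, id_col: str = "id",
             batch_size: int = 1000) -> Iterator[list[int]]:
    """Yield batches of IDs using keyset pagination (no OFFSET scans).

    `table` and `id_col` are illustrative parameters, not package API.
    """
    last_id = 0
    while True:
        rows = conn.execute(
            f"SELECT {id_col} FROM {table} WHERE {id_col} > ? ORDER BY {id_col} LIMIT ?",
            (last_id, batch_size),
        ).fetchall()
        if not rows:
            break
        ids = [row[0] for row in rows]
        yield ids
        last_id = ids[-1]
```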
|
|
2127
|
+
|
|
2128
|
+
def get_embeddings_by_ids(self, org_ids: list[int]) -> dict[int, np.ndarray]:
|
|
2129
|
+
"""
|
|
2130
|
+
Fetch float32 embeddings for given org IDs.
|
|
2131
|
+
|
|
2132
|
+
Args:
|
|
2133
|
+
org_ids: List of organization IDs
|
|
2134
|
+
|
|
2135
|
+
Returns:
|
|
2136
|
+
Dict mapping org_id to float32 embedding array
|
|
2137
|
+
"""
|
|
2138
|
+
conn = self._connect()
|
|
2139
|
+
|
|
2140
|
+
if not org_ids:
|
|
2141
|
+
return {}
|
|
2142
|
+
|
|
2143
|
+
placeholders = ",".join("?" * len(org_ids))
|
|
2144
|
+
cursor = conn.execute(f"""
|
|
2145
|
+
SELECT org_id, embedding FROM organization_embeddings
|
|
2146
|
+
WHERE org_id IN ({placeholders})
|
|
2147
|
+
""", org_ids)
|
|
2148
|
+
|
|
2149
|
+
result = {}
|
|
2150
|
+
for row in cursor:
|
|
2151
|
+
embedding_blob = row["embedding"]
|
|
2152
|
+
embedding = np.frombuffer(embedding_blob, dtype=np.float32)
|
|
2153
|
+
result[row["org_id"]] = embedding
|
|
2154
|
+
return result
|
|
2155
|
+
|
|
2156
|
+
def insert_scalar_embeddings_batch(self, org_ids: list[int], embeddings: np.ndarray) -> int:
|
|
2157
|
+
"""
|
|
2158
|
+
Insert scalar (int8) embeddings for existing orgs.
|
|
2159
|
+
|
|
2160
|
+
Args:
|
|
2161
|
+
org_ids: List of organization IDs
|
|
2162
|
+
embeddings: Matrix of int8 embeddings (N x dim)
|
|
2163
|
+
|
|
2164
|
+
Returns:
|
|
2165
|
+
Number of embeddings inserted
|
|
2166
|
+
"""
|
|
2167
|
+
conn = self._connect()
|
|
2168
|
+
count = 0
|
|
2169
|
+
|
|
2170
|
+
for org_id, embedding in zip(org_ids, embeddings):
|
|
2171
|
+
scalar_blob = embedding.astype(np.int8).tobytes()
|
|
2172
|
+
conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
|
|
2173
|
+
conn.execute("""
|
|
2174
|
+
INSERT INTO organization_embeddings_scalar (org_id, embedding)
|
|
2175
|
+
VALUES (?, vec_int8(?))
|
|
2176
|
+
""", (org_id, scalar_blob))
|
|
2177
|
+
count += 1
|
|
2178
|
+
|
|
2179
|
+
conn.commit()
|
|
2180
|
+
return count
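The int8 vectors fed to `vec_int8(?)` are produced by symmetric quantization of unit-normalized float32 embeddings; the same `clip(round(x * 127))` rule appears in `_quantize_query` further down this diff. A quick numpy check that the quantization roughly preserves cosine similarity, and where the "75% storage reduction" figure in the comments comes from (4 bytes -> 1 byte per dimension):

```python
import numpy as np

def quantize_int8(v: np.ndarray) -> np.ndarray:
    """Symmetric int8 quantization of a unit-normalized embedding."""
    return np.clip(np.round(v * 127), -127, 127).astype(np.int8)

def cosine(x: np.ndarray, y: np.ndarray) -> float:
    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))

rng = np.random.default_rng(0)
a = rng.normal(size=768).astype(np.float32)
b = (a + 0.1 * rng.normal(size=768)).astype(np.float32)
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

qa = quantize_int8(a).astype(np.float32)
qb = quantize_int8(b).astype(np.float32)
# The quantized cosine closely tracks the float32 one (typically within ~0.01 here).
print(cosine(a, b), cosine(qa, qb))
```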
|
|
2181
|
+
|
|
2182
|
+
def get_scalar_embedding_count(self) -> int:
|
|
2183
|
+
"""Get count of scalar embeddings."""
|
|
2184
|
+
conn = self._connect()
|
|
2185
|
+
if not self._has_scalar_table():
|
|
2186
|
+
return 0
|
|
2187
|
+
cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings_scalar")
|
|
2188
|
+
return cursor.fetchone()[0]
|
|
2189
|
+
|
|
2190
|
+
def get_float32_embedding_count(self) -> int:
|
|
2191
|
+
"""Get count of float32 embeddings."""
|
|
2192
|
+
conn = self._connect()
|
|
2193
|
+
cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
|
|
2194
|
+
return cursor.fetchone()[0]
|
|
2195
|
+
|
|
2196
|
+
def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
|
|
2197
|
+
"""
|
|
2198
|
+
Yield batches of (org_id, name) tuples for records missing both float32 and scalar embeddings.
|
|
2199
|
+
|
|
2200
|
+
Args:
|
|
2201
|
+
batch_size: Number of IDs per batch
|
|
2202
|
+
|
|
2203
|
+
Yields:
|
|
2204
|
+
Lists of (org_id, name) tuples needing embeddings generated from scratch
|
|
2205
|
+
"""
|
|
2206
|
+
conn = self._connect()
|
|
2207
|
+
|
|
2208
|
+
# Ensure scalar table exists
|
|
2209
|
+
self.ensure_scalar_table_exists()
|
|
2210
|
+
|
|
2211
|
+
last_id = 0
|
|
2212
|
+
while True:
|
|
2213
|
+
cursor = conn.execute("""
|
|
2214
|
+
SELECT o.id, o.name FROM organizations o
|
|
2215
|
+
LEFT JOIN organization_embeddings e ON o.id = e.org_id
|
|
2216
|
+
WHERE e.org_id IS NULL AND o.id > ?
|
|
2217
|
+
ORDER BY o.id
|
|
2218
|
+
LIMIT ?
|
|
2219
|
+
""", (last_id, batch_size))
|
|
2220
|
+
|
|
2221
|
+
rows = cursor.fetchall()
|
|
2222
|
+
if not rows:
|
|
2223
|
+
break
|
|
2224
|
+
|
|
2225
|
+
results = [(row["id"], row["name"]) for row in rows]
|
|
2226
|
+
yield results
|
|
2227
|
+
last_id = results[-1][0]
|
|
2228
|
+
|
|
2229
|
+
def insert_both_embeddings_batch(
|
|
2230
|
+
self,
|
|
2231
|
+
org_ids: list[int],
|
|
2232
|
+
fp32_embeddings: np.ndarray,
|
|
2233
|
+
int8_embeddings: np.ndarray,
|
|
2234
|
+
) -> int:
|
|
2235
|
+
"""
|
|
2236
|
+
Insert both float32 and int8 embeddings for existing orgs.
|
|
2237
|
+
|
|
2238
|
+
Args:
|
|
2239
|
+
org_ids: List of organization IDs
|
|
2240
|
+
fp32_embeddings: Matrix of float32 embeddings (N x dim)
|
|
2241
|
+
int8_embeddings: Matrix of int8 embeddings (N x dim)
|
|
2242
|
+
|
|
2243
|
+
Returns:
|
|
2244
|
+
Number of embeddings inserted
|
|
2245
|
+
"""
|
|
2246
|
+
conn = self._connect()
|
|
2247
|
+
count = 0
|
|
2248
|
+
|
|
2249
|
+
for org_id, fp32, int8 in zip(org_ids, fp32_embeddings, int8_embeddings):
|
|
2250
|
+
# Insert float32
|
|
2251
|
+
fp32_blob = fp32.astype(np.float32).tobytes()
|
|
2252
|
+
conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
|
|
2253
|
+
conn.execute("""
|
|
2254
|
+
INSERT INTO organization_embeddings (org_id, embedding)
|
|
2255
|
+
VALUES (?, ?)
|
|
2256
|
+
""", (org_id, fp32_blob))
|
|
2257
|
+
|
|
2258
|
+
# Insert int8
|
|
2259
|
+
int8_blob = int8.astype(np.int8).tobytes()
|
|
2260
|
+
conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
|
|
2261
|
+
conn.execute("""
|
|
2262
|
+
INSERT INTO organization_embeddings_scalar (org_id, embedding)
|
|
2263
|
+
VALUES (?, vec_int8(?))
|
|
2264
|
+
""", (org_id, int8_blob))
|
|
2265
|
+
|
|
2266
|
+
count += 1
|
|
2267
|
+
|
|
2268
|
+
conn.commit()
|
|
2269
|
+
return count
|
|
2270
|
+
|
|
2271
|
+
def resolve_qid_labels(
|
|
2272
|
+
self,
|
|
2273
|
+
label_map: dict[str, str],
|
|
2274
|
+
batch_size: int = 1000,
|
|
2275
|
+
) -> tuple[int, int]:
|
|
2276
|
+
"""
|
|
2277
|
+
Update organization records that have QIDs instead of labels in region field.
|
|
2278
|
+
|
|
2279
|
+
If resolving would create a duplicate of an existing record with
|
|
2280
|
+
resolved labels, the QID version is deleted instead.
|
|
2281
|
+
|
|
2282
|
+
Args:
|
|
2283
|
+
label_map: Mapping of QID -> label for resolution
|
|
2284
|
+
batch_size: Commit batch size
|
|
2285
|
+
|
|
2286
|
+
Returns:
|
|
2287
|
+
Tuple of (records updated, duplicates deleted)
|
|
2288
|
+
"""
|
|
2289
|
+
conn = self._connect()
|
|
2290
|
+
|
|
2291
|
+
# Find records with QIDs in region field (starts with 'Q' followed by digits)
|
|
2292
|
+
region_updates = 0
|
|
2293
|
+
cursor = conn.execute("""
|
|
2294
|
+
SELECT id, region FROM organizations
|
|
2295
|
+
WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
|
|
2296
|
+
""")
|
|
2297
|
+
rows = cursor.fetchall()
|
|
2298
|
+
|
|
2299
|
+
duplicates_deleted = 0
|
|
2300
|
+
for row in rows:
|
|
2301
|
+
org_id = row["id"]
|
|
2302
|
+
qid = row["region"]
|
|
2303
|
+
if qid in label_map:
|
|
2304
|
+
resolved_region = label_map[qid]
|
|
2305
|
+
# Check if this update would create a duplicate
|
|
2306
|
+
# Get the name and source of the current record
|
|
2307
|
+
org_cursor = conn.execute(
|
|
2308
|
+
"SELECT name, source FROM organizations WHERE id = ?",
|
|
2309
|
+
(org_id,)
|
|
2310
|
+
)
|
|
2311
|
+
org_row = org_cursor.fetchone()
|
|
2312
|
+
if org_row is None:
|
|
2313
|
+
continue
|
|
2314
|
+
|
|
2315
|
+
org_name = org_row["name"]
|
|
2316
|
+
org_source = org_row["source"]
|
|
2317
|
+
|
|
2318
|
+
# Check if a record with the resolved region already exists
|
|
2319
|
+
existing_cursor = conn.execute(
|
|
2320
|
+
"SELECT id FROM organizations WHERE name = ? AND region = ? AND source = ? AND id != ?",
|
|
2321
|
+
(org_name, resolved_region, org_source, org_id)
|
|
2322
|
+
)
|
|
2323
|
+
existing = existing_cursor.fetchone()
|
|
2324
|
+
|
|
2325
|
+
if existing is not None:
|
|
2326
|
+
# Duplicate would be created - delete the QID-based record
|
|
2327
|
+
conn.execute("DELETE FROM organizations WHERE id = ?", (org_id,))
|
|
2328
|
+
duplicates_deleted += 1
|
|
2329
|
+
else:
|
|
2330
|
+
# Safe to update
|
|
2331
|
+
conn.execute(
|
|
2332
|
+
"UPDATE organizations SET region = ? WHERE id = ?",
|
|
2333
|
+
(resolved_region, org_id)
|
|
2334
|
+
)
|
|
2335
|
+
region_updates += 1
|
|
2336
|
+
|
|
2337
|
+
if (region_updates + duplicates_deleted) % batch_size == 0:
|
|
2338
|
+
conn.commit()
|
|
2339
|
+
logger.info(f"Resolved QID labels: {region_updates} updates, {duplicates_deleted} deletes...")
|
|
2340
|
+
|
|
2341
|
+
conn.commit()
|
|
2342
|
+
logger.info(f"Resolved QID labels: {region_updates} organization regions, {duplicates_deleted} duplicates deleted")
|
|
2343
|
+
return region_updates, duplicates_deleted
|
|
2344
|
+
|
|
2345
|
+
def get_unresolved_qids(self) -> set[str]:
|
|
2346
|
+
"""
|
|
2347
|
+
Get all QIDs that still need resolution in the organizations table.
|
|
2348
|
+
|
|
2349
|
+
Returns:
|
|
2350
|
+
Set of QIDs (starting with 'Q') found in region field
|
|
2351
|
+
"""
|
|
2352
|
+
conn = self._connect()
|
|
2353
|
+
qids: set[str] = set()
|
|
2354
|
+
|
|
2355
|
+
cursor = conn.execute("""
|
|
2356
|
+
SELECT DISTINCT region FROM organizations
|
|
2357
|
+
WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
|
|
2358
|
+
""")
|
|
2359
|
+
for row in cursor:
|
|
2360
|
+
qids.add(row["region"])
|
|
2361
|
+
|
|
2362
|
+
return qids
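Together, `get_unresolved_qids` and `resolve_qid_labels` support a two-step clean-up: collect the QIDs still sitting in `organizations.region`, resolve them to labels from whatever mapping is available (for example the qid_labels table populated elsewhere in this diff, or a Wikidata dump), then write the labels back while dropping would-be duplicates. A hypothetical driver with the label lookup stubbed out:

```python
def backfill_region_labels(db) -> None:
    """Resolve Wikidata QIDs left in organizations.region to readable labels.

    `db` is an OrganizationDatabase; `lookup_labels` below is a stand-in for
    whatever QID -> label source is available, not part of the package API.
    """
    qids = db.get_unresolved_qids()
    if not qids:
        return

    def lookup_labels(wanted: set[str]) -> dict[str, str]:
        # Illustrative stub - e.g. read from qid_labels or a Wikidata dump.
        return {"Q30": "United States of America", "Q145": "United Kingdom"}

    label_map = lookup_labels(qids)
    updated, deleted = db.resolve_qid_labels(label_map, batch_size=1000)
    print(f"resolved {updated} regions, removed {deleted} duplicates")
```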
|
|
2363
|
+
|
|
1119
2364
|
|
|
1120
2365
|
def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
|
|
1121
2366
|
"""
|
|
@@ -1161,36 +2406,51 @@ class PersonDatabase:
|
|
|
1161
2406
|
self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
|
1162
2407
|
self._embedding_dim = embedding_dim
|
|
1163
2408
|
self._conn: Optional[sqlite3.Connection] = None
|
|
2409
|
+
self._is_v2: Optional[bool] = None # Detected on first connect
|
|
1164
2410
|
|
|
1165
2411
|
def _ensure_dir(self) -> None:
|
|
1166
2412
|
"""Ensure database directory exists."""
|
|
1167
2413
|
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1168
2414
|
|
|
1169
2415
|
def _connect(self) -> sqlite3.Connection:
|
|
1170
|
-
"""Get or create database connection
|
|
2416
|
+
"""Get or create database connection using shared connection pool."""
|
|
1171
2417
|
if self._conn is not None:
|
|
1172
2418
|
return self._conn
|
|
1173
2419
|
|
|
1174
|
-
self.
|
|
1175
|
-
self._conn = sqlite3.connect(str(self._db_path))
|
|
1176
|
-
self._conn.row_factory = sqlite3.Row
|
|
2420
|
+
self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
|
|
1177
2421
|
|
|
1178
|
-
#
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
2422
|
+
# Detect schema version BEFORE creating tables
|
|
2423
|
+
# v2 has person_type_id (FK) instead of person_type (TEXT)
|
|
2424
|
+
if self._is_v2 is None:
|
|
2425
|
+
cursor = self._conn.execute("PRAGMA table_info(people)")
|
|
2426
|
+
columns = {row["name"] for row in cursor}
|
|
2427
|
+
self._is_v2 = "person_type_id" in columns
|
|
2428
|
+
if self._is_v2:
|
|
2429
|
+
logger.debug("Detected v2 schema for people")
|
|
1182
2430
|
|
|
1183
|
-
# Create tables
|
|
1184
|
-
|
|
2431
|
+
# Create tables (idempotent) - only for v1 schema or fresh databases
|
|
2432
|
+
# v2 databases already have their schema from migration
|
|
2433
|
+
if not self._is_v2:
|
|
2434
|
+
self._create_tables()
|
|
1185
2435
|
|
|
1186
2436
|
return self._conn
|
|
1187
2437
|
|
|
2438
|
+
@property
|
|
2439
|
+
def _people_table(self) -> str:
|
|
2440
|
+
"""Return table/view name for people queries needing text fields."""
|
|
2441
|
+
return "people_view" if self._is_v2 else "people"
|
|
2442
|
+
|
|
1188
2443
|
def _create_tables(self) -> None:
|
|
1189
2444
|
"""Create database tables including sqlite-vec virtual table."""
|
|
1190
2445
|
conn = self._conn
|
|
1191
2446
|
assert conn is not None
|
|
1192
2447
|
|
|
2448
|
+
# Check if we need to migrate from old schema (unique on source+source_id only)
|
|
2449
|
+
self._migrate_people_schema_if_needed(conn)
|
|
2450
|
+
|
|
1193
2451
|
# Main people records table
|
|
2452
|
+
# Unique constraint on source+source_id+role+org allows multiple records
|
|
2453
|
+
# for the same person with different role/org combinations
|
|
1194
2454
|
conn.execute("""
|
|
1195
2455
|
CREATE TABLE IF NOT EXISTS people (
|
|
1196
2456
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
@@ -1202,8 +2462,14 @@ class PersonDatabase:
|
|
|
1202
2462
|
person_type TEXT NOT NULL DEFAULT 'unknown',
|
|
1203
2463
|
known_for_role TEXT NOT NULL DEFAULT '',
|
|
1204
2464
|
known_for_org TEXT NOT NULL DEFAULT '',
|
|
2465
|
+
known_for_org_id INTEGER DEFAULT NULL,
|
|
2466
|
+
from_date TEXT NOT NULL DEFAULT '',
|
|
2467
|
+
to_date TEXT NOT NULL DEFAULT '',
|
|
2468
|
+
birth_date TEXT NOT NULL DEFAULT '',
|
|
2469
|
+
death_date TEXT NOT NULL DEFAULT '',
|
|
1205
2470
|
record TEXT NOT NULL,
|
|
1206
|
-
UNIQUE(source, source_id)
|
|
2471
|
+
UNIQUE(source, source_id, known_for_role, known_for_org),
|
|
2472
|
+
FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
|
|
1207
2473
|
)
|
|
1208
2474
|
""")
|
|
1209
2475
|
|
|
@@ -1211,10 +2477,72 @@ class PersonDatabase:
|
|
|
1211
2477
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
|
|
1212
2478
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
|
|
1213
2479
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
|
|
1214
|
-
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
|
|
2480
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id, known_for_role, known_for_org)")
|
|
1215
2481
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
|
|
1216
2482
|
|
|
1217
|
-
#
|
|
2483
|
+
# Add from_date column if it doesn't exist (migration for existing DBs)
|
|
2484
|
+
try:
|
|
2485
|
+
conn.execute("ALTER TABLE people ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
|
|
2486
|
+
logger.info("Added from_date column to people table")
|
|
2487
|
+
except sqlite3.OperationalError:
|
|
2488
|
+
pass # Column already exists
|
|
2489
|
+
|
|
2490
|
+
# Add to_date column if it doesn't exist (migration for existing DBs)
|
|
2491
|
+
try:
|
|
2492
|
+
conn.execute("ALTER TABLE people ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
|
|
2493
|
+
logger.info("Added to_date column to people table")
|
|
2494
|
+
except sqlite3.OperationalError:
|
|
2495
|
+
pass # Column already exists
|
|
2496
|
+
|
|
2497
|
+
# Add known_for_org_id column if it doesn't exist (migration for existing DBs)
|
|
2498
|
+
# This is a foreign key to the organizations table (nullable)
|
|
2499
|
+
try:
|
|
2500
|
+
conn.execute("ALTER TABLE people ADD COLUMN known_for_org_id INTEGER DEFAULT NULL")
|
|
2501
|
+
logger.info("Added known_for_org_id column to people table")
|
|
2502
|
+
except sqlite3.OperationalError:
|
|
2503
|
+
pass # Column already exists
|
|
2504
|
+
|
|
2505
|
+
# Create index on known_for_org_id for joins (only if column exists)
|
|
2506
|
+
try:
|
|
2507
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org_id ON people(known_for_org_id)")
|
|
2508
|
+
except sqlite3.OperationalError:
|
|
2509
|
+
pass # Column doesn't exist yet (will be added on next connection)
|
|
2510
|
+
|
|
2511
|
+
# Add birth_date column if it doesn't exist (migration for existing DBs)
|
|
2512
|
+
try:
|
|
2513
|
+
conn.execute("ALTER TABLE people ADD COLUMN birth_date TEXT NOT NULL DEFAULT ''")
|
|
2514
|
+
logger.info("Added birth_date column to people table")
|
|
2515
|
+
except sqlite3.OperationalError:
|
|
2516
|
+
pass # Column already exists
|
|
2517
|
+
|
|
2518
|
+
# Add death_date column if it doesn't exist (migration for existing DBs)
|
|
2519
|
+
try:
|
|
2520
|
+
conn.execute("ALTER TABLE people ADD COLUMN death_date TEXT NOT NULL DEFAULT ''")
|
|
2521
|
+
logger.info("Added death_date column to people table")
|
|
2522
|
+
except sqlite3.OperationalError:
|
|
2523
|
+
pass # Column already exists
|
|
2524
|
+
|
|
2525
|
+
# Add canon_id column if it doesn't exist (migration for canonicalization)
|
|
2526
|
+
try:
|
|
2527
|
+
conn.execute("ALTER TABLE people ADD COLUMN canon_id INTEGER DEFAULT NULL")
|
|
2528
|
+
logger.info("Added canon_id column to people table")
|
|
2529
|
+
except sqlite3.OperationalError:
|
|
2530
|
+
pass # Column already exists
|
|
2531
|
+
|
|
2532
|
+
# Add canon_size column if it doesn't exist (migration for canonicalization)
|
|
2533
|
+
try:
|
|
2534
|
+
conn.execute("ALTER TABLE people ADD COLUMN canon_size INTEGER DEFAULT 1")
|
|
2535
|
+
logger.info("Added canon_size column to people table")
|
|
2536
|
+
except sqlite3.OperationalError:
|
|
2537
|
+
pass # Column already exists
|
|
2538
|
+
|
|
2539
|
+
# Create index on canon_id for joins
|
|
2540
|
+
try:
|
|
2541
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_canon_id ON people(canon_id)")
|
|
2542
|
+
except sqlite3.OperationalError:
|
|
2543
|
+
pass # Column doesn't exist yet
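The repeated try/`ALTER TABLE`/except blocks are the usual SQLite idiom for idempotent column additions, since SQLite has no `ADD COLUMN IF NOT EXISTS`. The same pattern can be factored into a small helper; this is an illustrative refactor, not code from the package:

```python
import logging
import sqlite3

logger = logging.getLogger(__name__)

def add_column_if_missing(conn: sqlite3.Connection, table: str, column_def: str) -> bool:
    """ALTER TABLE ... ADD COLUMN, swallowing the 'duplicate column' error.

    Returns True if the column was added, False if it already existed.
    """
    try:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column_def}")
        logger.info("Added column %s to %s", column_def.split()[0], table)
        return True
    except sqlite3.OperationalError:
        return False  # column already exists

# Equivalent to the individual blocks above, e.g.:
# add_column_if_missing(conn, "people", "from_date TEXT NOT NULL DEFAULT ''")
# add_column_if_missing(conn, "people", "canon_id INTEGER DEFAULT NULL")
```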
|
|
2544
|
+
|
|
2545
|
+
# Create sqlite-vec virtual table for embeddings (float32)
|
|
1218
2546
|
conn.execute(f"""
|
|
1219
2547
|
CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
|
|
1220
2548
|
person_id INTEGER PRIMARY KEY,
|
|
@@ -1222,21 +2550,117 @@ class PersonDatabase:
|
|
|
1222
2550
|
)
|
|
1223
2551
|
""")
|
|
1224
2552
|
|
|
1225
|
-
|
|
1226
|
-
|
|
2553
|
+
# Create sqlite-vec virtual table for scalar embeddings (int8)
|
|
2554
|
+
# Provides 75% storage reduction with ~92% recall at top-100
|
|
2555
|
+
conn.execute(f"""
|
|
2556
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
|
|
2557
|
+
person_id INTEGER PRIMARY KEY,
|
|
2558
|
+
embedding int8[{self._embedding_dim}]
|
|
2559
|
+
)
|
|
2560
|
+
""")
|
|
2561
|
+
|
|
2562
|
+
# Create QID labels lookup table for Wikidata QID -> label mappings
|
|
2563
|
+
conn.execute("""
|
|
2564
|
+
CREATE TABLE IF NOT EXISTS qid_labels (
|
|
2565
|
+
qid TEXT PRIMARY KEY,
|
|
2566
|
+
label TEXT NOT NULL
|
|
2567
|
+
)
|
|
2568
|
+
""")
|
|
2569
|
+
|
|
2570
|
+
conn.commit()
|
|
2571
|
+
|
|
2572
|
+
def _migrate_people_schema_if_needed(self, conn: sqlite3.Connection) -> None:
|
|
2573
|
+
"""Migrate people table from old schema if needed."""
|
|
2574
|
+
# Check if people table exists
|
|
2575
|
+
cursor = conn.execute(
|
|
2576
|
+
"SELECT name FROM sqlite_master WHERE type='table' AND name='people'"
|
|
2577
|
+
)
|
|
2578
|
+
if not cursor.fetchone():
|
|
2579
|
+
return # Table doesn't exist, no migration needed
|
|
2580
|
+
|
|
2581
|
+
# Check the unique constraint - look at index info
|
|
2582
|
+
# Old schema: UNIQUE(source, source_id)
|
|
2583
|
+
# New schema: UNIQUE(source, source_id, known_for_role, known_for_org)
|
|
2584
|
+
cursor = conn.execute("PRAGMA index_list(people)")
|
|
2585
|
+
indexes = cursor.fetchall()
|
|
2586
|
+
|
|
2587
|
+
needs_migration = False
|
|
2588
|
+
for idx in indexes:
|
|
2589
|
+
idx_name = idx[1]
|
|
2590
|
+
if "sqlite_autoindex_people" in idx_name:
|
|
2591
|
+
# Check columns in this unique index
|
|
2592
|
+
cursor = conn.execute(f"PRAGMA index_info('{idx_name}')")
|
|
2593
|
+
cols = [row[2] for row in cursor.fetchall()]
|
|
2594
|
+
# Old schema has only 2 columns in unique constraint
|
|
2595
|
+
if cols == ["source", "source_id"]:
|
|
2596
|
+
needs_migration = True
|
|
2597
|
+
logger.info("Detected old people schema, migrating to new unique constraint...")
|
|
2598
|
+
break
|
|
2599
|
+
|
|
2600
|
+
if not needs_migration:
|
|
2601
|
+
return
|
|
2602
|
+
|
|
2603
|
+
# Migrate: create new table, copy data, drop old, rename new
|
|
2604
|
+
logger.info("Migrating people table to new schema with (source, source_id, role, org) unique constraint...")
|
|
2605
|
+
|
|
2606
|
+
conn.execute("""
|
|
2607
|
+
CREATE TABLE people_new (
|
|
2608
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
2609
|
+
name TEXT NOT NULL,
|
|
2610
|
+
name_normalized TEXT NOT NULL,
|
|
2611
|
+
source TEXT NOT NULL DEFAULT 'wikidata',
|
|
2612
|
+
source_id TEXT NOT NULL,
|
|
2613
|
+
country TEXT NOT NULL DEFAULT '',
|
|
2614
|
+
person_type TEXT NOT NULL DEFAULT 'unknown',
|
|
2615
|
+
known_for_role TEXT NOT NULL DEFAULT '',
|
|
2616
|
+
known_for_org TEXT NOT NULL DEFAULT '',
|
|
2617
|
+
known_for_org_id INTEGER DEFAULT NULL,
|
|
2618
|
+
from_date TEXT NOT NULL DEFAULT '',
|
|
2619
|
+
to_date TEXT NOT NULL DEFAULT '',
|
|
2620
|
+
record TEXT NOT NULL,
|
|
2621
|
+
UNIQUE(source, source_id, known_for_role, known_for_org),
|
|
2622
|
+
FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
|
|
2623
|
+
)
|
|
2624
|
+
""")
|
|
2625
|
+
|
|
2626
|
+
# Copy data (old IDs will change, but embeddings table references them)
|
|
2627
|
+
# Note: old table may not have from_date/to_date columns, so use defaults
|
|
2628
|
+
conn.execute("""
|
|
2629
|
+
INSERT INTO people_new (name, name_normalized, source, source_id, country,
|
|
2630
|
+
person_type, known_for_role, known_for_org, record)
|
|
2631
|
+
SELECT name, name_normalized, source, source_id, country,
|
|
2632
|
+
person_type, known_for_role, known_for_org, record
|
|
2633
|
+
FROM people
|
|
2634
|
+
""")
|
|
2635
|
+
|
|
2636
|
+
# Drop old table and embeddings (IDs changed, embeddings are invalid)
|
|
2637
|
+
conn.execute("DROP TABLE IF EXISTS person_embeddings")
|
|
2638
|
+
conn.execute("DROP TABLE people")
|
|
2639
|
+
conn.execute("ALTER TABLE people_new RENAME TO people")
|
|
2640
|
+
|
|
2641
|
+
# Drop old index if it exists
|
|
2642
|
+
conn.execute("DROP INDEX IF EXISTS idx_people_source_id")
|
|
2643
|
+
|
|
2644
|
+
conn.commit()
|
|
2645
|
+
logger.info("Migration complete. Note: person embeddings were cleared and need to be regenerated.")
|
|
2646
|
+
|
|
1227
2647
|
def close(self) -> None:
|
|
1228
|
-
"""
|
|
1229
|
-
|
|
1230
|
-
self._conn.close()
|
|
1231
|
-
self._conn = None
|
|
2648
|
+
"""Clear connection reference (shared connection remains open)."""
|
|
2649
|
+
self._conn = None
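`_get_shared_connection` is referenced in `_connect` but defined elsewhere in store.py. A minimal sketch of what such a per-path connection cache could look like, offered only as an assumption to explain why `close()` now just drops the reference instead of closing the handle:

```python
import sqlite3
from pathlib import Path

import sqlite_vec

_SHARED_CONNECTIONS: dict[str, sqlite3.Connection] = {}

def _get_shared_connection(db_path: Path, embedding_dim: int) -> sqlite3.Connection:
    """Return one sqlite-vec-enabled connection per database file (sketch only).

    The real helper may also set pragmas or use embedding_dim; this is a guess
    at its shape, not the package's implementation.
    """
    key = str(Path(db_path).resolve())
    conn = _SHARED_CONNECTIONS.get(key)
    if conn is None:
        conn = sqlite3.connect(key)
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
        _SHARED_CONNECTIONS[key] = conn
    return conn
```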
|
|
1232
2650
|
|
|
1233
|
-
def insert(
|
|
2651
|
+
def insert(
|
|
2652
|
+
self,
|
|
2653
|
+
record: PersonRecord,
|
|
2654
|
+
embedding: np.ndarray,
|
|
2655
|
+
scalar_embedding: Optional[np.ndarray] = None,
|
|
2656
|
+
) -> int:
|
|
1234
2657
|
"""
|
|
1235
2658
|
Insert a person record with its embedding.
|
|
1236
2659
|
|
|
1237
2660
|
Args:
|
|
1238
2661
|
record: Person record to insert
|
|
1239
|
-
embedding: Embedding vector for the person name
|
|
2662
|
+
embedding: Embedding vector for the person name (float32)
|
|
2663
|
+
scalar_embedding: Optional int8 scalar embedding for compact storage
|
|
1240
2664
|
|
|
1241
2665
|
Returns:
|
|
1242
2666
|
Row ID of inserted record
|
|
@@ -1249,8 +2673,10 @@ class PersonDatabase:
|
|
|
1249
2673
|
|
|
1250
2674
|
cursor = conn.execute("""
|
|
1251
2675
|
INSERT OR REPLACE INTO people
|
|
1252
|
-
(name, name_normalized, source, source_id, country, person_type,
|
|
1253
|
-
|
|
2676
|
+
(name, name_normalized, source, source_id, country, person_type,
|
|
2677
|
+
known_for_role, known_for_org, known_for_org_id, from_date, to_date,
|
|
2678
|
+
birth_date, death_date, record)
|
|
2679
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
1254
2680
|
""", (
|
|
1255
2681
|
record.name,
|
|
1256
2682
|
name_normalized,
|
|
@@ -1260,19 +2686,34 @@ class PersonDatabase:
|
|
|
1260
2686
|
record.person_type.value,
|
|
1261
2687
|
record.known_for_role,
|
|
1262
2688
|
record.known_for_org,
|
|
2689
|
+
record.known_for_org_id, # Can be None
|
|
2690
|
+
record.from_date or "",
|
|
2691
|
+
record.to_date or "",
|
|
2692
|
+
record.birth_date or "",
|
|
2693
|
+
record.death_date or "",
|
|
1263
2694
|
record_json,
|
|
1264
2695
|
))
|
|
1265
2696
|
|
|
1266
2697
|
row_id = cursor.lastrowid
|
|
1267
2698
|
assert row_id is not None
|
|
1268
2699
|
|
|
1269
|
-
# Insert embedding into vec table
|
|
2700
|
+
# Insert embedding into vec table (float32)
|
|
1270
2701
|
embedding_blob = embedding.astype(np.float32).tobytes()
|
|
2702
|
+
conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
|
|
1271
2703
|
conn.execute("""
|
|
1272
|
-
INSERT
|
|
2704
|
+
INSERT INTO person_embeddings (person_id, embedding)
|
|
1273
2705
|
VALUES (?, ?)
|
|
1274
2706
|
""", (row_id, embedding_blob))
|
|
1275
2707
|
|
|
2708
|
+
# Insert scalar embedding if provided (int8)
|
|
2709
|
+
if scalar_embedding is not None:
|
|
2710
|
+
scalar_blob = scalar_embedding.astype(np.int8).tobytes()
|
|
2711
|
+
conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
|
|
2712
|
+
conn.execute("""
|
|
2713
|
+
INSERT INTO person_embeddings_scalar (person_id, embedding)
|
|
2714
|
+
VALUES (?, vec_int8(?))
|
|
2715
|
+
""", (row_id, scalar_blob))
|
|
2716
|
+
|
|
1276
2717
|
conn.commit()
|
|
1277
2718
|
return row_id
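A hedged usage sketch of the extended `insert` signature, passing both the float32 embedding and an optional int8 copy. The normalization and quantization steps are assumptions that mirror `_quantize_query`; the imports are commented because they require the installed package:

```python
import numpy as np
# from statement_extractor.database.models import PersonRecord
# from statement_extractor.database.store import get_person_database

def insert_person_example(db, record, embedding: np.ndarray) -> int:
    """Insert one person with both embedding precisions.

    `db` is a PersonDatabase and `record` a PersonRecord; the int8 copy is
    derived with the same clip/round rule used elsewhere in this module.
    """
    unit = embedding / np.linalg.norm(embedding)
    scalar = np.clip(np.round(unit * 127), -127, 127).astype(np.int8)
    return db.insert(record, embedding=unit, scalar_embedding=scalar)
```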
|
|
1278
2719
|
|
|
@@ -1281,14 +2722,16 @@ class PersonDatabase:
|
|
|
1281
2722
|
records: list[PersonRecord],
|
|
1282
2723
|
embeddings: np.ndarray,
|
|
1283
2724
|
batch_size: int = 1000,
|
|
2725
|
+
scalar_embeddings: Optional[np.ndarray] = None,
|
|
1284
2726
|
) -> int:
|
|
1285
2727
|
"""
|
|
1286
2728
|
Insert multiple person records with embeddings.
|
|
1287
2729
|
|
|
1288
2730
|
Args:
|
|
1289
2731
|
records: List of person records
|
|
1290
|
-
embeddings: Matrix of embeddings (N x dim)
|
|
2732
|
+
embeddings: Matrix of embeddings (N x dim) - float32
|
|
1291
2733
|
batch_size: Commit batch size
|
|
2734
|
+
scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)
|
|
1292
2735
|
|
|
1293
2736
|
Returns:
|
|
1294
2737
|
Number of records inserted
|
|
@@ -1296,36 +2739,87 @@ class PersonDatabase:
|
|
|
1296
2739
|
conn = self._connect()
|
|
1297
2740
|
count = 0
|
|
1298
2741
|
|
|
1299
|
-
for record, embedding in zip(records, embeddings):
|
|
2742
|
+
for i, (record, embedding) in enumerate(zip(records, embeddings)):
|
|
1300
2743
|
record_json = json.dumps(record.record)
|
|
1301
2744
|
name_normalized = _normalize_person_name(record.name)
|
|
1302
2745
|
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
record.
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
2746
|
+
if self._is_v2:
|
|
2747
|
+
# v2 schema: use FK IDs instead of TEXT columns
|
|
2748
|
+
source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
|
|
2749
|
+
person_type_id = PEOPLE_TYPE_NAME_TO_ID.get(record.person_type.value, 15) # 15 = unknown
|
|
2750
|
+
|
|
2751
|
+
# Resolve country to location_id if provided
|
|
2752
|
+
country_id = None
|
|
2753
|
+
if record.country:
|
|
2754
|
+
locations_db = get_locations_database(db_path=self._db_path)
|
|
2755
|
+
country_id = locations_db.resolve_region_text(record.country)
|
|
2756
|
+
|
|
2757
|
+
cursor = conn.execute("""
|
|
2758
|
+
INSERT OR REPLACE INTO people
|
|
2759
|
+
(name, name_normalized, source_id, source_identifier, country_id, person_type_id,
|
|
2760
|
+
known_for_org, known_for_org_id, from_date, to_date,
|
|
2761
|
+
birth_date, death_date, record)
|
|
2762
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
2763
|
+
""", (
|
|
2764
|
+
record.name,
|
|
2765
|
+
name_normalized,
|
|
2766
|
+
source_type_id,
|
|
2767
|
+
record.source_id,
|
|
2768
|
+
country_id,
|
|
2769
|
+
person_type_id,
|
|
2770
|
+
record.known_for_org,
|
|
2771
|
+
record.known_for_org_id, # Can be None
|
|
2772
|
+
record.from_date or "",
|
|
2773
|
+
record.to_date or "",
|
|
2774
|
+
record.birth_date or "",
|
|
2775
|
+
record.death_date or "",
|
|
2776
|
+
record_json,
|
|
2777
|
+
))
|
|
2778
|
+
else:
|
|
2779
|
+
# v1 schema: use TEXT columns
|
|
2780
|
+
cursor = conn.execute("""
|
|
2781
|
+
INSERT OR REPLACE INTO people
|
|
2782
|
+
(name, name_normalized, source, source_id, country, person_type,
|
|
2783
|
+
known_for_role, known_for_org, known_for_org_id, from_date, to_date,
|
|
2784
|
+
birth_date, death_date, record)
|
|
2785
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
2786
|
+
""", (
|
|
2787
|
+
record.name,
|
|
2788
|
+
name_normalized,
|
|
2789
|
+
record.source,
|
|
2790
|
+
record.source_id,
|
|
2791
|
+
record.country,
|
|
2792
|
+
record.person_type.value,
|
|
2793
|
+
record.known_for_role,
|
|
2794
|
+
record.known_for_org,
|
|
2795
|
+
record.known_for_org_id, # Can be None
|
|
2796
|
+
record.from_date or "",
|
|
2797
|
+
record.to_date or "",
|
|
2798
|
+
record.birth_date or "",
|
|
2799
|
+
record.death_date or "",
|
|
2800
|
+
record_json,
|
|
2801
|
+
))
|
|
1318
2802
|
|
|
1319
2803
|
row_id = cursor.lastrowid
|
|
1320
2804
|
assert row_id is not None
|
|
1321
2805
|
|
|
1322
|
-
# Insert embedding
|
|
2806
|
+
# Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
|
|
1323
2807
|
embedding_blob = embedding.astype(np.float32).tobytes()
|
|
2808
|
+
conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
|
|
1324
2809
|
conn.execute("""
|
|
1325
|
-
INSERT
|
|
2810
|
+
INSERT INTO person_embeddings (person_id, embedding)
|
|
1326
2811
|
VALUES (?, ?)
|
|
1327
2812
|
""", (row_id, embedding_blob))
|
|
1328
2813
|
|
|
2814
|
+
# Insert scalar embedding if provided (int8)
|
|
2815
|
+
if scalar_embeddings is not None:
|
|
2816
|
+
scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
|
|
2817
|
+
conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
|
|
2818
|
+
conn.execute("""
|
|
2819
|
+
INSERT INTO person_embeddings_scalar (person_id, embedding)
|
|
2820
|
+
VALUES (?, vec_int8(?))
|
|
2821
|
+
""", (row_id, scalar_blob))
|
|
2822
|
+
|
|
1329
2823
|
count += 1
|
|
1330
2824
|
|
|
1331
2825
|
if count % batch_size == 0:
|
|
@@ -1335,6 +2829,102 @@ class PersonDatabase:
|
|
|
1335
2829
|
conn.commit()
|
|
1336
2830
|
return count
|
|
1337
2831
|
|
|
2832
|
+
def update_dates(self, source: str, source_id: str, from_date: Optional[str], to_date: Optional[str]) -> bool:
|
|
2833
|
+
"""
|
|
2834
|
+
Update the from_date and to_date for a person record.
|
|
2835
|
+
|
|
2836
|
+
Args:
|
|
2837
|
+
source: Data source (e.g., 'wikidata')
|
|
2838
|
+
source_id: Source identifier (e.g., QID)
|
|
2839
|
+
from_date: Start date in ISO format or None
|
|
2840
|
+
to_date: End date in ISO format or None
|
|
2841
|
+
|
|
2842
|
+
Returns:
|
|
2843
|
+
True if record was updated, False if not found
|
|
2844
|
+
"""
|
|
2845
|
+
conn = self._connect()
|
|
2846
|
+
|
|
2847
|
+
if self._is_v2:
|
|
2848
|
+
source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
|
|
2849
|
+
cursor = conn.execute("""
|
|
2850
|
+
UPDATE people SET from_date = ?, to_date = ?
|
|
2851
|
+
WHERE source_id = ? AND source_identifier = ?
|
|
2852
|
+
""", (from_date or "", to_date or "", source_type_id, source_id))
|
|
2853
|
+
else:
|
|
2854
|
+
cursor = conn.execute("""
|
|
2855
|
+
UPDATE people SET from_date = ?, to_date = ?
|
|
2856
|
+
WHERE source = ? AND source_id = ?
|
|
2857
|
+
""", (from_date or "", to_date or "", source, source_id))
|
|
2858
|
+
|
|
2859
|
+
conn.commit()
|
|
2860
|
+
return cursor.rowcount > 0
|
|
2861
|
+
|
|
2862
|
+
def update_role_org(
|
|
2863
|
+
self,
|
|
2864
|
+
source: str,
|
|
2865
|
+
source_id: str,
|
|
2866
|
+
known_for_role: str,
|
|
2867
|
+
known_for_org: str,
|
|
2868
|
+
known_for_org_id: Optional[int],
|
|
2869
|
+
new_embedding: np.ndarray,
|
|
2870
|
+
from_date: Optional[str] = None,
|
|
2871
|
+
to_date: Optional[str] = None,
|
|
2872
|
+
) -> bool:
|
|
2873
|
+
"""
|
|
2874
|
+
Update the role/org/dates data for a person record and re-embed.
|
|
2875
|
+
|
|
2876
|
+
Args:
|
|
2877
|
+
source: Data source (e.g., 'wikidata')
|
|
2878
|
+
source_id: Source identifier (e.g., QID)
|
|
2879
|
+
known_for_role: Role/position title
|
|
2880
|
+
known_for_org: Organization name
|
|
2881
|
+
known_for_org_id: Organization internal ID (FK) or None
|
|
2882
|
+
new_embedding: New embedding vector based on updated data
|
|
2883
|
+
from_date: Start date in ISO format or None
|
|
2884
|
+
to_date: End date in ISO format or None
|
|
2885
|
+
|
|
2886
|
+
Returns:
|
|
2887
|
+
True if record was updated, False if not found
|
|
2888
|
+
"""
|
|
2889
|
+
conn = self._connect()
|
|
2890
|
+
|
|
2891
|
+
# First get the person's internal ID
|
|
2892
|
+
if self._is_v2:
|
|
2893
|
+
source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
|
|
2894
|
+
row = conn.execute(
|
|
2895
|
+
"SELECT id FROM people WHERE source_id = ? AND source_identifier = ?",
|
|
2896
|
+
(source_type_id, source_id)
|
|
2897
|
+
).fetchone()
|
|
2898
|
+
else:
|
|
2899
|
+
row = conn.execute(
|
|
2900
|
+
"SELECT id FROM people WHERE source = ? AND source_id = ?",
|
|
2901
|
+
(source, source_id)
|
|
2902
|
+
).fetchone()
|
|
2903
|
+
|
|
2904
|
+
if not row:
|
|
2905
|
+
return False
|
|
2906
|
+
|
|
2907
|
+
person_id = row[0]
|
|
2908
|
+
|
|
2909
|
+
# Update the person record (including dates)
|
|
2910
|
+
conn.execute("""
|
|
2911
|
+
UPDATE people SET
|
|
2912
|
+
known_for_role = ?, known_for_org = ?, known_for_org_id = ?,
|
|
2913
|
+
from_date = COALESCE(?, from_date, ''),
|
|
2914
|
+
to_date = COALESCE(?, to_date, '')
|
|
2915
|
+
WHERE id = ?
|
|
2916
|
+
""", (known_for_role, known_for_org, known_for_org_id, from_date, to_date, person_id))
|
|
2917
|
+
|
|
2918
|
+
# Update the embedding
|
|
2919
|
+
embedding_bytes = new_embedding.astype(np.float32).tobytes()
|
|
2920
|
+
conn.execute("""
|
|
2921
|
+
UPDATE people_vec SET embedding = ?
|
|
2922
|
+
WHERE rowid = ?
|
|
2923
|
+
""", (embedding_bytes, person_id))
|
|
2924
|
+
|
|
2925
|
+
conn.commit()
|
|
2926
|
+
return True
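The `COALESCE(?, from_date, '')` / `COALESCE(?, to_date, '')` parameters in the update above let callers pass `None` to mean "keep the stored date" while still allowing explicit overwrites. A tiny standalone demonstration of the same SQL pattern on a hypothetical table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, from_date TEXT NOT NULL DEFAULT '')")
conn.execute("INSERT INTO t (id, from_date) VALUES (1, '2020-01-01')")

def set_from_date(new_value):
    conn.execute("UPDATE t SET from_date = COALESCE(?, from_date, '') WHERE id = 1", (new_value,))
    return conn.execute("SELECT from_date FROM t WHERE id = 1").fetchone()[0]

assert set_from_date(None) == "2020-01-01"          # None keeps the existing value
assert set_from_date("2021-06-30") == "2021-06-30"  # explicit value overwrites
```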
|
|
2927
|
+
|
|
1338
2928
|
def search(
|
|
1339
2929
|
self,
|
|
1340
2930
|
query_embedding: np.ndarray,
|
|
@@ -1366,7 +2956,13 @@ class PersonDatabase:
|
|
|
1366
2956
|
if query_norm == 0:
|
|
1367
2957
|
return []
|
|
1368
2958
|
query_normalized = query_embedding / query_norm
|
|
1369
|
-
|
|
2959
|
+
|
|
2960
|
+
# Use int8 quantized query if scalar table is available (75% storage savings)
|
|
2961
|
+
if self._has_scalar_table():
|
|
2962
|
+
query_int8 = self._quantize_query(query_normalized)
|
|
2963
|
+
query_blob = query_int8.tobytes()
|
|
2964
|
+
else:
|
|
2965
|
+
query_blob = query_normalized.astype(np.float32).tobytes()
|
|
1370
2966
|
|
|
1371
2967
|
# Stage 1: Text-based pre-filtering (if query_text provided)
|
|
1372
2968
|
candidate_ids: Optional[set[int]] = None
|
|
@@ -1437,13 +3033,26 @@ class PersonDatabase:
|
|
|
1437
3033
|
cursor = conn.execute(query, params)
|
|
1438
3034
|
return set(row["id"] for row in cursor)
|
|
1439
3035
|
|
|
3036
|
+
def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
|
|
3037
|
+
"""Quantize query embedding to int8 for scalar search."""
|
|
3038
|
+
return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
|
|
3039
|
+
|
|
3040
|
+
def _has_scalar_table(self) -> bool:
|
|
3041
|
+
"""Check if scalar embedding table exists."""
|
|
3042
|
+
conn = self._conn
|
|
3043
|
+
assert conn is not None
|
|
3044
|
+
cursor = conn.execute(
|
|
3045
|
+
"SELECT name FROM sqlite_master WHERE type='table' AND name='person_embeddings_scalar'"
|
|
3046
|
+
)
|
|
3047
|
+
return cursor.fetchone() is not None
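With the scalar table present, the search path quantizes the query once and ranks rows by `vec_distance_cosine` over the int8 column. A trimmed, standalone version of the unfiltered query (assuming sqlite-vec is already loaded on the connection and the query embedding is unit-normalized):

```python
import sqlite3

import numpy as np

def knn_person_ids(conn: sqlite3.Connection, query_embedding: np.ndarray,
                   top_k: int = 10) -> list[tuple[int, float]]:
    """Rank people by cosine distance against the int8 scalar table.

    Mirrors the scalar branch of _vector_search_full shown below.
    """
    q = np.clip(np.round(query_embedding * 127), -127, 127).astype(np.int8)
    rows = conn.execute(
        """
        SELECT person_id, vec_distance_cosine(embedding, vec_int8(?)) AS distance
        FROM person_embeddings_scalar
        ORDER BY distance
        LIMIT ?
        """,
        (q.tobytes(), top_k),
    ).fetchall()
    return [(row[0], row[1]) for row in rows]
```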
|
|
3048
|
+
|
|
1440
3049
|
def _vector_search_filtered(
|
|
1441
3050
|
self,
|
|
1442
3051
|
query_blob: bytes,
|
|
1443
3052
|
candidate_ids: set[int],
|
|
1444
3053
|
top_k: int,
|
|
1445
3054
|
) -> list[tuple[PersonRecord, float]]:
|
|
1446
|
-
"""Vector search within a filtered set of candidates."""
|
|
3055
|
+
"""Vector search within a filtered set of candidates using scalar (int8) embeddings."""
|
|
1447
3056
|
conn = self._conn
|
|
1448
3057
|
assert conn is not None
|
|
1449
3058
|
|
|
@@ -1453,15 +3062,27 @@ class PersonDatabase:
|
|
|
1453
3062
|
# Build IN clause for candidate IDs
|
|
1454
3063
|
placeholders = ",".join("?" * len(candidate_ids))
|
|
1455
3064
|
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
3065
|
+
# Use scalar embedding table if available (75% storage reduction)
|
|
3066
|
+
if self._has_scalar_table():
|
|
3067
|
+
query = f"""
|
|
3068
|
+
SELECT
|
|
3069
|
+
e.person_id,
|
|
3070
|
+
vec_distance_cosine(e.embedding, vec_int8(?)) as distance
|
|
3071
|
+
FROM person_embeddings_scalar e
|
|
3072
|
+
WHERE e.person_id IN ({placeholders})
|
|
3073
|
+
ORDER BY distance
|
|
3074
|
+
LIMIT ?
|
|
3075
|
+
"""
|
|
3076
|
+
else:
|
|
3077
|
+
query = f"""
|
|
3078
|
+
SELECT
|
|
3079
|
+
e.person_id,
|
|
3080
|
+
vec_distance_cosine(e.embedding, ?) as distance
|
|
3081
|
+
FROM person_embeddings e
|
|
3082
|
+
WHERE e.person_id IN ({placeholders})
|
|
3083
|
+
ORDER BY distance
|
|
3084
|
+
LIMIT ?
|
|
3085
|
+
"""
|
|
1465
3086
|
|
|
1466
3087
|
cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
|
|
1467
3088
|
|
|
@@ -1484,18 +3105,29 @@ class PersonDatabase:
|
|
|
1484
3105
|
query_blob: bytes,
|
|
1485
3106
|
top_k: int,
|
|
1486
3107
|
) -> list[tuple[PersonRecord, float]]:
|
|
1487
|
-
"""Full vector search without text pre-filtering."""
|
|
3108
|
+
"""Full vector search without text pre-filtering using scalar (int8) embeddings."""
|
|
1488
3109
|
conn = self._conn
|
|
1489
3110
|
assert conn is not None
|
|
1490
3111
|
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
3112
|
+
# Use scalar embedding table if available (75% storage reduction)
|
|
3113
|
+
if self._has_scalar_table():
|
|
3114
|
+
query = """
|
|
3115
|
+
SELECT
|
|
3116
|
+
person_id,
|
|
3117
|
+
vec_distance_cosine(embedding, vec_int8(?)) as distance
|
|
3118
|
+
FROM person_embeddings_scalar
|
|
3119
|
+
ORDER BY distance
|
|
3120
|
+
LIMIT ?
|
|
3121
|
+
"""
|
|
3122
|
+
else:
|
|
3123
|
+
query = """
|
|
3124
|
+
SELECT
|
|
3125
|
+
person_id,
|
|
3126
|
+
vec_distance_cosine(embedding, ?) as distance
|
|
3127
|
+
FROM person_embeddings
|
|
3128
|
+
ORDER BY distance
|
|
3129
|
+
LIMIT ?
|
|
3130
|
+
"""
|
|
1499
3131
|
cursor = conn.execute(query, (query_blob, top_k))
|
|
1500
3132
|
|
|
1501
3133
|
results = []
|
|
@@ -1515,21 +3147,36 @@ class PersonDatabase:
|
|
|
1515
3147
|
conn = self._conn
|
|
1516
3148
|
assert conn is not None
|
|
1517
3149
|
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
3150
|
+
if self._is_v2:
|
|
3151
|
+
# v2 schema: join view with base table for record
|
|
3152
|
+
cursor = conn.execute("""
|
|
3153
|
+
SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
|
|
3154
|
+
v.known_for_role, v.known_for_org, v.known_for_org_id,
|
|
3155
|
+
v.birth_date, v.death_date, p.record
|
|
3156
|
+
FROM people_view v
|
|
3157
|
+
JOIN people p ON v.id = p.id
|
|
3158
|
+
WHERE v.id = ?
|
|
3159
|
+
""", (person_id,))
|
|
3160
|
+
else:
|
|
3161
|
+
cursor = conn.execute("""
|
|
3162
|
+
SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
|
|
3163
|
+
FROM people WHERE id = ?
|
|
3164
|
+
""", (person_id,))
|
|
1522
3165
|
|
|
1523
3166
|
row = cursor.fetchone()
|
|
1524
3167
|
if row:
|
|
3168
|
+
source_id_field = "source_identifier" if self._is_v2 else "source_id"
|
|
1525
3169
|
return PersonRecord(
|
|
1526
3170
|
name=row["name"],
|
|
1527
3171
|
source=row["source"],
|
|
1528
|
-
source_id=row[
|
|
3172
|
+
source_id=row[source_id_field],
|
|
1529
3173
|
country=row["country"] or "",
|
|
1530
3174
|
person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
|
|
1531
3175
|
known_for_role=row["known_for_role"] or "",
|
|
1532
3176
|
known_for_org=row["known_for_org"] or "",
|
|
3177
|
+
known_for_org_id=row["known_for_org_id"], # Can be None
|
|
3178
|
+
birth_date=row["birth_date"] or "",
|
|
3179
|
+
death_date=row["death_date"] or "",
|
|
1533
3180
|
record=json.loads(row["record"]),
|
|
1534
3181
|
)
|
|
1535
3182
|
return None
|
|
@@ -1538,22 +3185,37 @@ class PersonDatabase:
|
|
|
1538
3185
|
"""Get a person record by source and source_id."""
|
|
1539
3186
|
conn = self._connect()
|
|
1540
3187
|
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
3188
|
+
if self._is_v2:
|
|
3189
|
+
source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
|
|
3190
|
+
cursor = conn.execute("""
|
|
3191
|
+
SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
|
|
3192
|
+
v.known_for_role, v.known_for_org, v.known_for_org_id,
|
|
3193
|
+
v.birth_date, v.death_date, p.record
|
|
3194
|
+
FROM people_view v
|
|
3195
|
+
JOIN people p ON v.id = p.id
|
|
3196
|
+
WHERE p.source_id = ? AND p.source_identifier = ?
|
|
3197
|
+
""", (source_type_id, source_id))
|
|
3198
|
+
else:
|
|
3199
|
+
cursor = conn.execute("""
|
|
3200
|
+
SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
|
|
3201
|
+
FROM people
|
|
3202
|
+
WHERE source = ? AND source_id = ?
|
|
3203
|
+
""", (source, source_id))
|
|
1546
3204
|
|
|
1547
3205
|
row = cursor.fetchone()
|
|
1548
3206
|
if row:
|
|
3207
|
+
source_id_field = "source_identifier" if self._is_v2 else "source_id"
|
|
1549
3208
|
return PersonRecord(
|
|
1550
3209
|
name=row["name"],
|
|
1551
3210
|
source=row["source"],
|
|
1552
|
-
source_id=row[
|
|
3211
|
+
source_id=row[source_id_field],
|
|
1553
3212
|
country=row["country"] or "",
|
|
1554
3213
|
person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
|
|
1555
3214
|
known_for_role=row["known_for_role"] or "",
|
|
1556
3215
|
known_for_org=row["known_for_org"] or "",
|
|
3216
|
+
known_for_org_id=row["known_for_org_id"], # Can be None
|
|
3217
|
+
birth_date=row["birth_date"] or "",
|
|
3218
|
+
death_date=row["death_date"] or "",
|
|
1557
3219
|
record=json.loads(row["record"]),
|
|
1558
3220
|
)
|
|
1559
3221
|
return None
|
|
@@ -1566,12 +3228,32 @@ class PersonDatabase:
|
|
|
1566
3228
|
cursor = conn.execute("SELECT COUNT(*) FROM people")
|
|
1567
3229
|
total = cursor.fetchone()[0]
|
|
1568
3230
|
|
|
1569
|
-
# Count by person_type
|
|
1570
|
-
|
|
3231
|
+
# Count by person_type - handle both v1 and v2 schema
|
|
3232
|
+
if self._is_v2:
|
|
3233
|
+
# v2 schema - join with people_types
|
|
3234
|
+
cursor = conn.execute("""
|
|
3235
|
+
SELECT pt.name as person_type, COUNT(*) as cnt
|
|
3236
|
+
FROM people p
|
|
3237
|
+
JOIN people_types pt ON p.person_type_id = pt.id
|
|
3238
|
+
GROUP BY p.person_type_id
|
|
3239
|
+
""")
|
|
3240
|
+
else:
|
|
3241
|
+
# v1 schema
|
|
3242
|
+
cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
|
|
1571
3243
|
by_type = {row["person_type"]: row["cnt"] for row in cursor}
|
|
1572
3244
|
|
|
1573
|
-
# Count by source
|
|
1574
|
-
|
|
3245
|
+
# Count by source - handle both v1 and v2 schema
|
|
3246
|
+
if self._is_v2:
|
|
3247
|
+
# v2 schema - join with source_types
|
|
3248
|
+
cursor = conn.execute("""
|
|
3249
|
+
SELECT st.name as source, COUNT(*) as cnt
|
|
3250
|
+
FROM people p
|
|
3251
|
+
JOIN source_types st ON p.source_id = st.id
|
|
3252
|
+
GROUP BY p.source_id
|
|
3253
|
+
""")
|
|
3254
|
+
else:
|
|
3255
|
+
# v1 schema
|
|
3256
|
+
cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
|
|
1575
3257
|
by_source = {row["source"]: row["cnt"] for row in cursor}
|
|
1576
3258
|
|
|
1577
3259
|
return {
|
|
@@ -1580,30 +3262,1564 @@ class PersonDatabase:
|
|
|
1580
3262
|
"by_source": by_source,
|
|
1581
3263
|
}
|
|
1582
3264
|
|
|
1583
|
-
def
|
|
1584
|
-
"""
|
|
3265
|
+
def ensure_scalar_table_exists(self) -> None:
|
|
3266
|
+
"""Create scalar embedding table if it doesn't exist."""
|
|
1585
3267
|
conn = self._connect()
|
|
3268
|
+
conn.execute(f"""
|
|
3269
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
|
|
3270
|
+
person_id INTEGER PRIMARY KEY,
|
|
3271
|
+
embedding int8[{self._embedding_dim}]
|
|
3272
|
+
)
|
|
3273
|
+
""")
|
|
3274
|
+
conn.commit()
|
|
3275
|
+
logger.info("Ensured person_embeddings_scalar table exists")
|
|
1586
3276
|
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
3277
|
+
def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
|
|
3278
|
+
"""
|
|
3279
|
+
Yield batches of person IDs that have float32 but missing scalar embeddings.
|
|
3280
|
+
|
|
3281
|
+
Args:
|
|
3282
|
+
batch_size: Number of IDs per batch
|
|
3283
|
+
|
|
3284
|
+
Yields:
|
|
3285
|
+
Lists of person_ids needing scalar embeddings
|
|
3286
|
+
"""
|
|
3287
|
+
conn = self._connect()
|
|
3288
|
+
|
|
3289
|
+
# Ensure scalar table exists before querying
|
|
3290
|
+
self.ensure_scalar_table_exists()
|
|
3291
|
+
|
|
3292
|
+
last_id = 0
|
|
3293
|
+
while True:
|
|
1594
3294
|
cursor = conn.execute("""
|
|
1595
|
-
SELECT
|
|
1596
|
-
|
|
1597
|
-
|
|
3295
|
+
SELECT e.person_id FROM person_embeddings e
|
|
3296
|
+
LEFT JOIN person_embeddings_scalar s ON e.person_id = s.person_id
|
|
3297
|
+
WHERE s.person_id IS NULL AND e.person_id > ?
|
|
3298
|
+
ORDER BY e.person_id
|
|
3299
|
+
LIMIT ?
|
|
3300
|
+
""", (last_id, batch_size))
|
|
3301
|
+
|
|
3302
|
+
rows = cursor.fetchall()
|
|
3303
|
+
if not rows:
|
|
3304
|
+
break
|
|
3305
|
+
|
|
3306
|
+
ids = [row["person_id"] for row in rows]
|
|
3307
|
+
yield ids
|
|
3308
|
+
last_id = ids[-1]
|
|
3309
|
+
|
|
3310
|
+
+    def get_embeddings_by_ids(self, person_ids: list[int]) -> dict[int, np.ndarray]:
+        """
+        Fetch float32 embeddings for given person IDs.
+
+        Args:
+            person_ids: List of person IDs

+        Returns:
+            Dict mapping person_id to float32 embedding array
+        """
+        conn = self._connect()
+
+        if not person_ids:
+            return {}
+
+        placeholders = ",".join("?" * len(person_ids))
+        cursor = conn.execute(f"""
+            SELECT person_id, embedding FROM person_embeddings
+            WHERE person_id IN ({placeholders})
+        """, person_ids)
+
+        result = {}
         for row in cursor:
-
-
-
-
-
-
-
-
-
-
+            embedding_blob = row["embedding"]
+            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
+            result[row["person_id"]] = embedding
+        return result
+
+    def insert_scalar_embeddings_batch(self, person_ids: list[int], embeddings: np.ndarray) -> int:
+        """
+        Insert scalar (int8) embeddings for existing people.
+
+        Args:
+            person_ids: List of person IDs
+            embeddings: Matrix of int8 embeddings (N x dim)
+
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0
+
+        for person_id, embedding in zip(person_ids, embeddings):
+            scalar_blob = embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings_scalar (person_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (person_id, scalar_blob))
+            count += 1
+
+        conn.commit()
+        return count
+
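The three methods above are the building blocks of a scalar-quantization backfill: page through IDs that only have float32 vectors, fetch those vectors, quantize, and write the int8 copies. A minimal sketch of that loop, assuming a simple symmetric max-abs quantization (the package's actual quantization step is not shown in this hunk):

```python
# Hypothetical backfill sketch, not part of the package API.
import numpy as np

def backfill_scalar_embeddings(db, batch_size: int = 1000) -> int:
    """Assumes `db` is a PersonDatabase; the quantization scheme is an assumption."""
    total = 0
    for ids in db.get_missing_scalar_embedding_ids(batch_size=batch_size):
        fp32 = db.get_embeddings_by_ids(ids)
        int8_rows = []
        for pid in ids:
            vec = fp32[pid]
            # Symmetric quantization: map the largest |value| to 127.
            scale = float(np.max(np.abs(vec))) or 1.0
            int8_rows.append(np.clip(np.round(vec / scale * 127), -128, 127).astype(np.int8))
        total += db.insert_scalar_embeddings_batch(ids, np.stack(int8_rows))
    return total
```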
+    def get_scalar_embedding_count(self) -> int:
+        """Get count of scalar embeddings."""
+        conn = self._connect()
+        if not self._has_scalar_table():
+            return 0
+        cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings_scalar")
+        return cursor.fetchone()[0]
+
+    def get_float32_embedding_count(self) -> int:
+        """Get count of float32 embeddings."""
+        conn = self._connect()
+        cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings")
+        return cursor.fetchone()[0]
+
+    def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
+        """
+        Yield batches of (person_id, name) tuples for records missing both float32 and scalar embeddings.
+
+        Args:
+            batch_size: Number of IDs per batch
+
+        Yields:
+            Lists of (person_id, name) tuples needing embeddings generated from scratch
+        """
+        conn = self._connect()
+
+        # Ensure scalar table exists
+        self.ensure_scalar_table_exists()
+
+        last_id = 0
+        while True:
+            cursor = conn.execute("""
+                SELECT p.id, p.name FROM people p
+                LEFT JOIN person_embeddings e ON p.id = e.person_id
+                WHERE e.person_id IS NULL AND p.id > ?
+                ORDER BY p.id
+                LIMIT ?
+            """, (last_id, batch_size))
+
+            rows = cursor.fetchall()
+            if not rows:
+                break
+
+            results = [(row["id"], row["name"]) for row in rows]
+            yield results
+            last_id = results[-1][0]
+
+    def insert_both_embeddings_batch(
+        self,
+        person_ids: list[int],
+        fp32_embeddings: np.ndarray,
+        int8_embeddings: np.ndarray,
+    ) -> int:
+        """
+        Insert both float32 and int8 embeddings for existing people.
+
+        Args:
+            person_ids: List of person IDs
+            fp32_embeddings: Matrix of float32 embeddings (N x dim)
+            int8_embeddings: Matrix of int8 embeddings (N x dim)
+
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0
+
+        for person_id, fp32, int8 in zip(person_ids, fp32_embeddings, int8_embeddings):
+            # Insert float32
+            fp32_blob = fp32.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings (person_id, embedding)
+                VALUES (?, ?)
+            """, (person_id, fp32_blob))
+
+            # Insert int8
+            int8_blob = int8.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings_scalar (person_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (person_id, int8_blob))
+
+            count += 1
+
+        conn.commit()
+        return count
+
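For records with no embeddings at all, `get_missing_all_embedding_ids` pairs with `insert_both_embeddings_batch`. A sketch of that path, where `encode_names` is a placeholder for whatever embedding model the CLI wires in (an assumption, not an API of this package):

```python
# Hypothetical end-to-end sketch: generate fp32 + int8 embeddings for people with none.
import numpy as np

def embed_missing_people(db, encode_names, batch_size: int = 1000) -> int:
    count = 0
    for batch in db.get_missing_all_embedding_ids(batch_size=batch_size):
        ids = [pid for pid, _name in batch]
        names = [name for _pid, name in batch]
        fp32 = np.asarray(encode_names(names), dtype=np.float32)          # shape (N, dim)
        # Per-row symmetric max-abs quantization (assumed scheme).
        scales = np.maximum(np.abs(fp32).max(axis=1, keepdims=True), 1e-8)
        int8 = np.clip(np.round(fp32 / scales * 127), -128, 127).astype(np.int8)
        count += db.insert_both_embeddings_batch(ids, fp32, int8)
    return count
```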
+    def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
+        """
+        Get all source_ids from the people table.
+
+        Useful for resume operations to skip already-imported records.
+
+        Args:
+            source: Optional source filter (e.g., "wikidata")
+
+        Returns:
+            Set of source_id strings (e.g., Q codes for Wikidata)
+        """
+        conn = self._connect()
+
+        if self._is_v2:
+            id_col = "source_identifier"
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute(
+                    f"SELECT DISTINCT {id_col} FROM people WHERE source_id = ?",
+                    (source_type_id,)
+                )
+            else:
+                cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM people")
+        else:
+            if source:
+                cursor = conn.execute(
+                    "SELECT DISTINCT source_id FROM people WHERE source = ?",
+                    (source,)
+                )
+            else:
+                cursor = conn.execute("SELECT DISTINCT source_id FROM people")
+
+        return {row[0] for row in cursor}
+
+    def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
+        """Iterate over all person records, optionally filtered by source."""
+        conn = self._connect()
+
+        if self._is_v2:
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                           v.known_for_role, v.known_for_org, v.known_for_org_id,
+                           v.birth_date, v.death_date, p.record
+                    FROM people_view v
+                    JOIN people p ON v.id = p.id
+                    WHERE p.source_id = ?
+                """, (source_type_id,))
+            else:
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                           v.known_for_role, v.known_for_org, v.known_for_org_id,
+                           v.birth_date, v.death_date, p.record
+                    FROM people_view v
+                    JOIN people p ON v.id = p.id
+                """)
+            for row in cursor:
+                yield PersonRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_identifier"],
+                    country=row["country"] or "",
+                    person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                    known_for_role=row["known_for_role"] or "",
+                    known_for_org=row["known_for_org"] or "",
+                    known_for_org_id=row["known_for_org_id"],  # Can be None
+                    birth_date=row["birth_date"] or "",
+                    death_date=row["death_date"] or "",
+                    record=json.loads(row["record"]),
+                )
+        else:
+            if source:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                    FROM people
+                    WHERE source = ?
+                """, (source,))
+            else:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                    FROM people
+                """)
+
+            for row in cursor:
+                yield PersonRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_id"],
+                    country=row["country"] or "",
+                    person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                    known_for_role=row["known_for_role"] or "",
+                    known_for_org=row["known_for_org"] or "",
+                    known_for_org_id=row["known_for_org_id"],  # Can be None
+                    birth_date=row["birth_date"] or "",
+                    death_date=row["death_date"] or "",
+                    record=json.loads(row["record"]),
+                )
+
+    def resolve_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> tuple[int, int]:
+        """
+        Update records that have QIDs instead of labels.
+
+        This is called after dump import to resolve any QIDs that were
+        stored because labels weren't available in the cache at import time.
+
+        If resolving would create a duplicate of an existing record with
+        resolved labels, the QID version is deleted instead.
+
+        Args:
+            label_map: Mapping of QID -> label for resolution
+            batch_size: Commit batch size
+
+        Returns:
+            Tuple of (updates, deletes)
+        """
+        conn = self._connect()
+
+        # v2 schema stores QIDs as integers, not text - this method doesn't apply
+        if self._is_v2:
+            logger.debug("Skipping resolve_qid_labels for v2 schema (QIDs stored as integers)")
+            return 0, 0
+
+        # Find all records with QIDs in any field (role or org - these are in unique constraint)
+        # Country is not part of unique constraint so can be updated directly
+        cursor = conn.execute("""
+            SELECT id, source, source_id, country, known_for_role, known_for_org
+            FROM people
+            WHERE (country LIKE 'Q%' AND country GLOB 'Q[0-9]*')
+               OR (known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*')
+               OR (known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*')
+        """)
+        rows = cursor.fetchall()
+
+        updates = 0
+        deletes = 0
+
+        for row in rows:
+            person_id = row["id"]
+            source = row["source"]
+            source_id = row["source_id"]
+            country = row["country"]
+            role = row["known_for_role"]
+            org = row["known_for_org"]
+
+            # Resolve QIDs to labels
+            new_country = label_map.get(country, country) if country.startswith("Q") and country[1:].isdigit() else country
+            new_role = label_map.get(role, role) if role.startswith("Q") and role[1:].isdigit() else role
+            new_org = label_map.get(org, org) if org.startswith("Q") and org[1:].isdigit() else org
+
+            # Skip if nothing changed
+            if new_country == country and new_role == role and new_org == org:
+                continue
+
+            # Check if resolved values would duplicate an existing record
+            # (unique constraint is on source, source_id, known_for_role, known_for_org)
+            if new_role != role or new_org != org:
+                cursor2 = conn.execute("""
+                    SELECT id FROM people
+                    WHERE source = ? AND source_id = ? AND known_for_role = ? AND known_for_org = ?
+                      AND id != ?
+                """, (source, source_id, new_role, new_org, person_id))
+                existing = cursor2.fetchone()
+
+                if existing:
+                    # Duplicate would exist - delete the QID version
+                    conn.execute("DELETE FROM people WHERE id = ?", (person_id,))
+                    conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
+                    deletes += 1
+                    logger.debug(f"Deleted duplicate QID record {person_id} (source_id={source_id})")
+                    continue
+
+            # No duplicate - update in place
+            conn.execute("""
+                UPDATE people SET country = ?, known_for_role = ?, known_for_org = ?
+                WHERE id = ?
+            """, (new_country, new_role, new_org, person_id))
+            updates += 1
+
+            if (updates + deletes) % batch_size == 0:
+                conn.commit()
+                logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes...")
+
+        conn.commit()
+        logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes")
+        return updates, deletes
+
+    def get_unresolved_qids(self) -> set[str]:
+        """
+        Get all QIDs that still need resolution in the database.
+
+        Returns:
+            Set of QIDs (starting with 'Q') found in country, role, or org fields
+        """
+        conn = self._connect()
+
+        # v2 schema stores QIDs as integers, not text - this method doesn't apply
+        if self._is_v2:
+            return set()
+
+        qids: set[str] = set()
+
+        # Get QIDs from country field
+        cursor = conn.execute("""
+            SELECT DISTINCT country FROM people
+            WHERE country LIKE 'Q%' AND country GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["country"])
+
+        # Get QIDs from known_for_role field
+        cursor = conn.execute("""
+            SELECT DISTINCT known_for_role FROM people
+            WHERE known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["known_for_role"])
+
+        # Get QIDs from known_for_org field
+        cursor = conn.execute("""
+            SELECT DISTINCT known_for_org FROM people
+            WHERE known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["known_for_org"])
+
+        return qids
+
+    def insert_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> int:
+        """
+        Insert QID -> label mappings into the lookup table.
+
+        Args:
+            label_map: Mapping of QID -> label
+            batch_size: Commit batch size
+
+        Returns:
+            Number of labels inserted/updated
+        """
+        conn = self._connect()
+        count = 0
+        skipped = 0
+
+        for qid, label in label_map.items():
+            # Skip non-Q IDs (e.g., property IDs like P19)
+            if not qid.startswith("Q"):
+                skipped += 1
+                continue
+
+            # v2 schema stores QID as integer without Q prefix
+            if self._is_v2:
+                try:
+                    qid_val: str | int = int(qid[1:])
+                except ValueError:
+                    skipped += 1
+                    continue
+            else:
+                qid_val = qid
+
+            conn.execute(
+                "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
+                (qid_val, label)
+            )
+            count += 1
+
+            if count % batch_size == 0:
+                conn.commit()
+                logger.debug(f"Inserted {count} QID labels...")
+
+        conn.commit()
+        logger.info(f"Inserted {count} QID labels into lookup table")
+        return count
+
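Together with `get_unresolved_qids` and `resolve_qid_labels` above, this supports a post-import resolution pass on the v1 schema. A sketch of how the pieces might be chained, with `fetch_wikidata_labels` as a hypothetical label lookup that is not part of this package:

```python
# Hypothetical resolution pass (v1 schema only; v2 stores QIDs as integers).
def resolve_pending_qids(db, fetch_wikidata_labels) -> tuple[int, int]:
    pending = db.get_unresolved_qids()                 # e.g. {"Q30", "Q484876", ...}
    if not pending:
        return 0, 0
    label_map = fetch_wikidata_labels(sorted(pending))  # {"Q30": "United States", ...}
    db.insert_qid_labels(label_map)                    # persist mappings for later lookups
    return db.resolve_qid_labels(label_map)            # -> (updates, deletes)
```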
+    def get_qid_label(self, qid: str) -> Optional[str]:
+        """
+        Get the label for a QID from the lookup table.
+
+        Args:
+            qid: Wikidata QID (e.g., 'Q30')
+
+        Returns:
+            Label string or None if not found
+        """
+        conn = self._connect()
+
+        # v2 schema stores QID as integer without Q prefix
+        if self._is_v2:
+            qid_val: str | int = int(qid[1:]) if qid.startswith("Q") else int(qid)
+        else:
+            qid_val = qid
+
+        cursor = conn.execute(
+            "SELECT label FROM qid_labels WHERE qid = ?",
+            (qid_val,)
+        )
+        row = cursor.fetchone()
+        return row["label"] if row else None
+
+    def get_all_qid_labels(self) -> dict[str, str]:
+        """
+        Get all QID -> label mappings from the lookup table.
+
+        Returns:
+            Dict mapping QID -> label
+        """
+        conn = self._connect()
+        cursor = conn.execute("SELECT qid, label FROM qid_labels")
+        return {row["qid"]: row["label"] for row in cursor}
+
+    def get_qid_labels_count(self) -> int:
+        """Get the number of QID labels in the lookup table."""
+        conn = self._connect()
+        cursor = conn.execute("SELECT COUNT(*) FROM qid_labels")
+        return cursor.fetchone()[0]
+
+    def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
+        """
+        Canonicalize person records by linking equivalent entries across sources.
+
+        Uses a multi-phase approach:
+        1. Match by normalized name + same organization (org canonical group)
+        2. Match by normalized name + overlapping date ranges
+
+        Source priority (lower = more authoritative):
+        - wikidata: 1 (curated, has Q codes)
+        - sec_edgar: 2 (US insider filings)
+        - companies_house: 3 (UK officers)
+
+        Args:
+            batch_size: Number of records to process before committing
+
+        Returns:
+            Stats dict with counts for each matching type
+        """
+        conn = self._connect()
+        stats = {
+            "total_records": 0,
+            "matched_by_org": 0,
+            "matched_by_date": 0,
+            "canonical_groups": 0,
+            "records_in_groups": 0,
+        }
+
+        logger.info("Phase 1: Building person index...")
+
+        # Load all people with their normalized names and org info
+        if self._is_v2:
+            cursor = conn.execute("""
+                SELECT p.id, p.name, p.name_normalized, s.name as source, p.source_identifier as source_id,
+                       p.known_for_org, p.known_for_org_id, p.from_date, p.to_date
+                FROM people p
+                JOIN source_types s ON p.source_id = s.id
+            """)
+        else:
+            cursor = conn.execute("""
+                SELECT id, name, name_normalized, source, source_id,
+                       known_for_org, known_for_org_id, from_date, to_date
+                FROM people
+            """)
+
+        people: list[dict] = []
+        for row in cursor:
+            people.append({
+                "id": row["id"],
+                "name": row["name"],
+                "name_normalized": row["name_normalized"],
+                "source": row["source"],
+                "source_id": row["source_id"],
+                "known_for_org": row["known_for_org"],
+                "known_for_org_id": row["known_for_org_id"],
+                "from_date": row["from_date"],
+                "to_date": row["to_date"],
+            })
+
+        stats["total_records"] = len(people)
+        logger.info(f"Loaded {len(people)} person records")
+
+        if len(people) == 0:
+            return stats
+
+        # Initialize Union-Find
+        person_ids = [p["id"] for p in people]
+        uf = UnionFind(person_ids)
+
+        # Build indexes for efficient matching
+        # Index by normalized name
+        name_to_people: dict[str, list[dict]] = {}
+        for p in people:
+            name_norm = p["name_normalized"]
+            name_to_people.setdefault(name_norm, []).append(p)
+
+        logger.info("Phase 2: Matching by normalized name + organization...")
+
+        # Match people with same normalized name and same organization
+        for name_norm, same_name in name_to_people.items():
+            if len(same_name) < 2:
+                continue
+
+            # Group by organization (using known_for_org_id if available, else known_for_org)
+            org_groups: dict[str, list[dict]] = {}
+            for p in same_name:
+                org_key = str(p["known_for_org_id"]) if p["known_for_org_id"] else p["known_for_org"]
+                if org_key:  # Only group if they have an org
+                    org_groups.setdefault(org_key, []).append(p)
+
+            # Union people with same name + same org
+            for org_key, org_people in org_groups.items():
+                if len(org_people) >= 2:
+                    first_id = org_people[0]["id"]
+                    for p in org_people[1:]:
+                        uf.union(first_id, p["id"])
+                        stats["matched_by_org"] += 1
+
+        logger.info(f"Phase 2 complete: {stats['matched_by_org']} matches by org")
+
+        logger.info("Phase 3: Matching by normalized name + overlapping dates...")
+
+        # Match people with same normalized name and overlapping date ranges
+        for name_norm, same_name in name_to_people.items():
+            if len(same_name) < 2:
+                continue
+
+            # Skip if already all unified
+            roots = set(uf.find(p["id"]) for p in same_name)
+            if len(roots) == 1:
+                continue
+
+            # Check for overlapping date ranges
+            for i, p1 in enumerate(same_name):
+                for p2 in same_name[i+1:]:
+                    # Skip if already in same group
+                    if uf.find(p1["id"]) == uf.find(p2["id"]):
+                        continue
+
+                    # Check date overlap (if both have dates)
+                    if p1["from_date"] and p2["from_date"]:
+                        # Simple overlap check: if either from_date is before other's to_date
+                        p1_from = p1["from_date"]
+                        p1_to = p1["to_date"] or "9999-12-31"
+                        p2_from = p2["from_date"]
+                        p2_to = p2["to_date"] or "9999-12-31"
+
+                        # Overlap if: p1_from <= p2_to AND p2_from <= p1_to
+                        if p1_from <= p2_to and p2_from <= p1_to:
+                            uf.union(p1["id"], p2["id"])
+                            stats["matched_by_date"] += 1
+
+        logger.info(f"Phase 3 complete: {stats['matched_by_date']} matches by date")
+
+        logger.info("Phase 4: Applying canonical updates...")
+
+        # Get all groups and select canonical record for each
+        groups = uf.groups()
+
+        # Build id -> source mapping
+        id_to_source = {p["id"]: p["source"] for p in people}
+
+        batch_updates: list[tuple[int, int, int]] = []  # (person_id, canon_id, canon_size)
+
+        for _root, group_ids in groups.items():
+            group_size = len(group_ids)
+
+            if group_size == 1:
+                # Single record is its own canonical
+                person_id = group_ids[0]
+                batch_updates.append((person_id, person_id, 1))
+            else:
+                # Multiple records - pick highest priority source as canonical
+                # Sort by source priority, then by id (for stability)
+                sorted_ids = sorted(
+                    group_ids,
+                    key=lambda pid: (PERSON_SOURCE_PRIORITY.get(id_to_source[pid], 99), pid)
+                )
+                canon_id = sorted_ids[0]
+                stats["canonical_groups"] += 1
+                stats["records_in_groups"] += group_size
+
+                for person_id in group_ids:
+                    batch_updates.append((person_id, canon_id, group_size if person_id == canon_id else 1))
+
+            # Commit in batches
+            if len(batch_updates) >= batch_size:
+                self._apply_person_canon_updates(batch_updates)
+                conn.commit()
+                logger.info(f"Applied {len(batch_updates)} canon updates...")
+                batch_updates = []
+
+        # Final batch
+        if batch_updates:
+            self._apply_person_canon_updates(batch_updates)
+            conn.commit()
+
+        logger.info(f"Canonicalization complete: {stats['canonical_groups']} groups, "
+                    f"{stats['records_in_groups']} records in multi-record groups")
+
+        return stats
+
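`canonicalize` depends on a `UnionFind` helper and a `PERSON_SOURCE_PRIORITY` map that are defined elsewhere in store.py and do not appear in this hunk. A minimal disjoint-set sketch that satisfies the `union`/`find`/`groups` calls used above (the shipped helper may differ):

```python
class UnionFind:
    """Minimal disjoint-set sketch matching the union/find/groups calls above."""

    def __init__(self, ids: list[int]) -> None:
        self.parent = {i: i for i in ids}

    def find(self, x: int) -> int:
        # Walk up to the root, compressing the path as we go.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra

    def groups(self) -> dict[int, list[int]]:
        out: dict[int, list[int]] = {}
        for i in self.parent:
            out.setdefault(self.find(i), []).append(i)
        return out
```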
+    def _apply_person_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
+        """Apply batch of canon updates: (person_id, canon_id, canon_size)."""
+        conn = self._conn
+        assert conn is not None
+
+        for person_id, canon_id, canon_size in updates:
+            conn.execute(
+                "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
+                (canon_id, canon_size, person_id)
+            )
+
+
+# =============================================================================
+# Module-level singletons for new v2 databases
+# =============================================================================
+
+_roles_database_instances: dict[str, "RolesDatabase"] = {}
+_locations_database_instances: dict[str, "LocationsDatabase"] = {}
+
+
+def get_roles_database(db_path: Optional[str | Path] = None) -> "RolesDatabase":
+    """
+    Get a singleton RolesDatabase instance for the given path.
+
+    Args:
+        db_path: Path to database file
+
+    Returns:
+        Shared RolesDatabase instance
+    """
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key not in _roles_database_instances:
+        logger.debug(f"Creating new RolesDatabase instance for {path_key}")
+        _roles_database_instances[path_key] = RolesDatabase(db_path=db_path)
+    return _roles_database_instances[path_key]
+
+
+def get_locations_database(db_path: Optional[str | Path] = None) -> "LocationsDatabase":
+    """
+    Get a singleton LocationsDatabase instance for the given path.
+
+    Args:
+        db_path: Path to database file
+
+    Returns:
+        Shared LocationsDatabase instance
+    """
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key not in _locations_database_instances:
+        logger.debug(f"Creating new LocationsDatabase instance for {path_key}")
+        _locations_database_instances[path_key] = LocationsDatabase(db_path=db_path)
+    return _locations_database_instances[path_key]
+
+
+# =============================================================================
+# ROLES DATABASE (v2)
+# =============================================================================
+
+
+class RolesDatabase:
+    """
+    SQLite database for job titles/roles.
+
+    Stores normalized role records with source tracking and supports
+    canonicalization to group equivalent roles (e.g., CEO, Chief Executive).
+    """
+
+    def __init__(self, db_path: Optional[str | Path] = None):
+        """
+        Initialize the roles database.
+
+        Args:
+            db_path: Path to database file (creates if not exists)
+        """
+        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+        self._conn: Optional[sqlite3.Connection] = None
+        self._role_cache: dict[str, int] = {}  # name_normalized -> role_id
+
+    def _connect(self) -> sqlite3.Connection:
+        """Get or create database connection using shared connection pool."""
+        if self._conn is not None:
+            return self._conn
+
+        self._conn = _get_shared_connection(self._db_path)
+        self._create_tables()
+        return self._conn
+
+    def _create_tables(self) -> None:
+        """Create roles table and indexes."""
+        conn = self._conn
+        assert conn is not None
+
+        # Check if enum tables exist, create and seed if not
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
+        )
+        if not cursor.fetchone():
+            logger.info("Creating enum tables for v2 schema...")
+            from .schema_v2 import (
+                CREATE_SOURCE_TYPES,
+                CREATE_PEOPLE_TYPES,
+                CREATE_ORGANIZATION_TYPES,
+                CREATE_SIMPLIFIED_LOCATION_TYPES,
+                CREATE_LOCATION_TYPES,
+            )
+            conn.execute(CREATE_SOURCE_TYPES)
+            conn.execute(CREATE_PEOPLE_TYPES)
+            conn.execute(CREATE_ORGANIZATION_TYPES)
+            conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
+            conn.execute(CREATE_LOCATION_TYPES)
+            seed_all_enums(conn)
+
+        # Create roles table
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS roles (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                qid INTEGER,
+                name TEXT NOT NULL,
+                name_normalized TEXT NOT NULL,
+                source_id INTEGER NOT NULL DEFAULT 4,
+                source_identifier TEXT,
+                record TEXT NOT NULL DEFAULT '{}',
+                canon_id INTEGER DEFAULT NULL,
+                canon_size INTEGER DEFAULT 1,
+                UNIQUE(name_normalized, source_id)
+            )
+        """)
+
+        # Create indexes
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name ON roles(name)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name_normalized ON roles(name_normalized)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_qid ON roles(qid)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_source_id ON roles(source_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_canon_id ON roles(canon_id)")
+
+        conn.commit()
+
+    def close(self) -> None:
+        """Clear connection reference."""
+        self._conn = None
+
+    def get_or_create(
+        self,
+        name: str,
+        source_id: int = 4,  # wikidata
+        qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+    ) -> int:
+        """
+        Get or create a role record.
+
+        Args:
+            name: Role/title name
+            source_id: FK to source_types table
+            qid: Optional Wikidata QID as integer
+            source_identifier: Optional source-specific identifier
+
+        Returns:
+            Role ID
+        """
+        if not name:
+            raise ValueError("Role name cannot be empty")
+
+        conn = self._connect()
+        name_normalized = name.lower().strip()
+
+        # Check cache
+        cache_key = f"{name_normalized}:{source_id}"
+        if cache_key in self._role_cache:
+            return self._role_cache[cache_key]
+
+        # Check database
+        cursor = conn.execute(
+            "SELECT id FROM roles WHERE name_normalized = ? AND source_id = ?",
+            (name_normalized, source_id)
+        )
+        row = cursor.fetchone()
+        if row:
+            role_id = row["id"]
+            self._role_cache[cache_key] = role_id
+            return role_id
+
+        # Create new role
+        cursor = conn.execute(
+            """
+            INSERT INTO roles (name, name_normalized, source_id, qid, source_identifier)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (name, name_normalized, source_id, qid, source_identifier)
+        )
+        role_id = cursor.lastrowid
+        assert role_id is not None
+        conn.commit()
+
+        self._role_cache[cache_key] = role_id
+        return role_id
+
+    def get_by_id(self, role_id: int) -> Optional[RoleRecord]:
+        """Get a role record by ID."""
+        conn = self._connect()
+
+        cursor = conn.execute(
+            "SELECT id, qid, name, source_id, source_identifier, record FROM roles WHERE id = ?",
+            (role_id,)
+        )
+        row = cursor.fetchone()
+        if row:
+            source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
+            return RoleRecord(
+                name=row["name"],
+                source=source_name,
+                source_id=row["source_identifier"],
+                qid=row["qid"],
+                record=json.loads(row["record"]) if row["record"] else {},
+            )
+        return None
+
+    def search(
+        self,
+        query: str,
+        top_k: int = 10,
+    ) -> list[tuple[int, str, float]]:
+        """
+        Search for roles by name.
+
+        Args:
+            query: Search query
+            top_k: Maximum results to return
+
+        Returns:
+            List of (role_id, role_name, score) tuples
+        """
+        conn = self._connect()
+        query_normalized = query.lower().strip()
+
+        # Exact match first
+        cursor = conn.execute(
+            "SELECT id, name FROM roles WHERE name_normalized = ? LIMIT 1",
+            (query_normalized,)
+        )
+        row = cursor.fetchone()
+        if row:
+            return [(row["id"], row["name"], 1.0)]
+
+        # LIKE match
+        cursor = conn.execute(
+            """
+            SELECT id, name FROM roles
+            WHERE name_normalized LIKE ?
+            ORDER BY length(name)
+            LIMIT ?
+            """,
+            (f"%{query_normalized}%", top_k)
+        )
+
+        results = []
+        for row in cursor:
+            # Simple score based on match quality
+            name_normalized = row["name"].lower()
+            if query_normalized == name_normalized:
+                score = 1.0
+            elif name_normalized.startswith(query_normalized):
+                score = 0.9
+            else:
+                score = 0.7
+            results.append((row["id"], row["name"], score))
+
+        return results
+
+    def get_stats(self) -> dict[str, int]:
+        """Get statistics about the roles table."""
+        conn = self._connect()
+
+        cursor = conn.execute("SELECT COUNT(*) FROM roles")
+        total = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(*) FROM roles WHERE canon_id IS NOT NULL")
+        canonicalized = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM roles WHERE canon_id IS NOT NULL")
+        groups = cursor.fetchone()[0]
+
+        return {
+            "total_roles": total,
+            "canonicalized": canonicalized,
+            "canonical_groups": groups,
+        }
+
+
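A usage sketch for the new roles table via the module-level singleton accessor (assumed usage; the QID value is illustrative only, not taken from the diff):

```python
from statement_extractor.database.store import get_roles_database

roles = get_roles_database()                     # shared instance for the default DB path
ceo_id = roles.get_or_create("Chief Executive Officer", source_id=4, qid=484876)
print(roles.get_by_id(ceo_id))                   # RoleRecord(name="Chief Executive Officer", ...)
print(roles.search("chief exec", top_k=5))       # [(role_id, name, score), ...]
```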
+# =============================================================================
+# LOCATIONS DATABASE (v2)
+# =============================================================================
+
+
+class LocationsDatabase:
+    """
+    SQLite database for geopolitical locations.
+
+    Stores countries, states, cities with hierarchical relationships
+    and type classification. Supports pycountry integration.
+    """
+
+    def __init__(self, db_path: Optional[str | Path] = None):
+        """
+        Initialize the locations database.
+
+        Args:
+            db_path: Path to database file (creates if not exists)
+        """
+        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+        self._conn: Optional[sqlite3.Connection] = None
+        self._location_cache: dict[str, int] = {}  # lookup_key -> location_id
+        self._location_type_cache: dict[str, int] = {}  # type_name -> type_id
+        self._location_type_qid_cache: dict[int, int] = {}  # qid -> type_id
+
+    def _connect(self) -> sqlite3.Connection:
+        """Get or create database connection using shared connection pool."""
+        if self._conn is not None:
+            return self._conn
+
+        self._conn = _get_shared_connection(self._db_path)
+        self._create_tables()
+        self._build_caches()
+        return self._conn
+
+    def _create_tables(self) -> None:
+        """Create locations table and indexes."""
+        conn = self._conn
+        assert conn is not None
+
+        # Check if enum tables exist, create and seed if not
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
+        )
+        if not cursor.fetchone():
+            logger.info("Creating enum tables for v2 schema...")
+            from .schema_v2 import (
+                CREATE_SOURCE_TYPES,
+                CREATE_PEOPLE_TYPES,
+                CREATE_ORGANIZATION_TYPES,
+                CREATE_SIMPLIFIED_LOCATION_TYPES,
+                CREATE_LOCATION_TYPES,
+            )
+            conn.execute(CREATE_SOURCE_TYPES)
+            conn.execute(CREATE_PEOPLE_TYPES)
+            conn.execute(CREATE_ORGANIZATION_TYPES)
+            conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
+            conn.execute(CREATE_LOCATION_TYPES)
+            seed_all_enums(conn)
+
+        # Create locations table
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS locations (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                qid INTEGER,
+                name TEXT NOT NULL,
+                name_normalized TEXT NOT NULL,
+                source_id INTEGER NOT NULL DEFAULT 4,
+                source_identifier TEXT,
+                parent_ids TEXT,
+                location_type_id INTEGER NOT NULL DEFAULT 2,
+                record TEXT NOT NULL DEFAULT '{}',
+                from_date TEXT DEFAULT NULL,
+                to_date TEXT DEFAULT NULL,
+                canon_id INTEGER DEFAULT NULL,
+                canon_size INTEGER DEFAULT 1,
+                UNIQUE(source_identifier, source_id)
+            )
+        """)
+
+        # Create indexes
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name ON locations(name)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name_normalized ON locations(name_normalized)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_qid ON locations(qid)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_source_id ON locations(source_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_location_type_id ON locations(location_type_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_canon_id ON locations(canon_id)")
+
+        conn.commit()
+
+    def _build_caches(self) -> None:
+        """Build lookup caches from database and seed data."""
+        # Load location type caches from seed data
+        self._location_type_cache = dict(LOCATION_TYPE_NAME_TO_ID)
+        self._location_type_qid_cache = dict(LOCATION_TYPE_QID_TO_ID)
+
+        # Load existing locations into cache
+        conn = self._conn
+        if conn:
+            cursor = conn.execute(
+                "SELECT id, name_normalized, source_identifier FROM locations"
+            )
+            for row in cursor:
+                # Cache by normalized name
+                self._location_cache[row["name_normalized"]] = row["id"]
+                # Also cache by source_identifier
+                if row["source_identifier"]:
+                    self._location_cache[row["source_identifier"].lower()] = row["id"]
+
+    def close(self) -> None:
+        """Clear connection reference."""
+        self._conn = None
+
+    def get_or_create(
+        self,
+        name: str,
+        location_type_id: int,
+        source_id: int = 4,  # wikidata
+        qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+        parent_ids: Optional[list[int]] = None,
+    ) -> int:
+        """
+        Get or create a location record.
+
+        Args:
+            name: Location name
+            location_type_id: FK to location_types table
+            source_id: FK to source_types table
+            qid: Optional Wikidata QID as integer
+            source_identifier: Optional source-specific identifier (e.g., "US", "CA")
+            parent_ids: Optional list of parent location IDs
+
+        Returns:
+            Location ID
+        """
+        if not name:
+            raise ValueError("Location name cannot be empty")
+
+        conn = self._connect()
+        name_normalized = name.lower().strip()
+
+        # Check cache by source_identifier first (more specific)
+        if source_identifier:
+            cache_key = source_identifier.lower()
+            if cache_key in self._location_cache:
+                return self._location_cache[cache_key]
+
+        # Check cache by normalized name
+        if name_normalized in self._location_cache:
+            return self._location_cache[name_normalized]
+
+        # Check database
+        if source_identifier:
+            cursor = conn.execute(
+                "SELECT id FROM locations WHERE source_identifier = ? AND source_id = ?",
+                (source_identifier, source_id)
+            )
+        else:
+            cursor = conn.execute(
+                "SELECT id FROM locations WHERE name_normalized = ? AND source_id = ?",
+                (name_normalized, source_id)
+            )
+
+        row = cursor.fetchone()
+        if row:
+            location_id = row["id"]
+            self._location_cache[name_normalized] = location_id
+            if source_identifier:
+                self._location_cache[source_identifier.lower()] = location_id
+            return location_id
+
+        # Create new location
+        parent_ids_json = json.dumps(parent_ids) if parent_ids else None
+        cursor = conn.execute(
+            """
+            INSERT INTO locations
+            (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
+            """,
+            (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids_json)
+        )
+        location_id = cursor.lastrowid
+        assert location_id is not None
+        conn.commit()
+
+        self._location_cache[name_normalized] = location_id
+        if source_identifier:
+            self._location_cache[source_identifier.lower()] = location_id
+        return location_id
+
+    def get_or_create_by_qid(
+        self,
+        name: str,
+        wikidata_type_qid: int,
+        source_id: int = 4,
+        entity_qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+        parent_ids: Optional[list[int]] = None,
+    ) -> int:
+        """
+        Get or create a location using Wikidata P31 type QID.
+
+        Args:
+            name: Location name
+            wikidata_type_qid: Wikidata instance-of QID (e.g., 515 for city)
+            source_id: FK to source_types table
+            entity_qid: Wikidata QID of the entity itself
+            source_identifier: Optional source-specific identifier
+            parent_ids: Optional list of parent location IDs
+
+        Returns:
+            Location ID
+        """
+        location_type_id = self.get_location_type_id_from_qid(wikidata_type_qid)
+        return self.get_or_create(
+            name=name,
+            location_type_id=location_type_id,
+            source_id=source_id,
+            qid=entity_qid,
+            source_identifier=source_identifier,
+            parent_ids=parent_ids,
+        )
+
+    def get_by_id(self, location_id: int) -> Optional[LocationRecord]:
+        """Get a location record by ID."""
+        conn = self._connect()
+
+        cursor = conn.execute(
+            """
+            SELECT id, qid, name, source_id, source_identifier, location_type_id,
+                   parent_ids, from_date, to_date, record
+            FROM locations WHERE id = ?
+            """,
+            (location_id,)
+        )
+        row = cursor.fetchone()
+        if row:
+            source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
+            location_type_id = row["location_type_id"]
+            location_type_name = self._get_location_type_name(location_type_id)
+            simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
+            simplified_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
+
+            parent_ids = json.loads(row["parent_ids"]) if row["parent_ids"] else []
+
+            return LocationRecord(
+                name=row["name"],
+                source=source_name,
+                source_id=row["source_identifier"],
+                qid=row["qid"],
+                location_type=location_type_name,
+                simplified_type=SimplifiedLocationType(simplified_name),
+                parent_ids=parent_ids,
+                from_date=row["from_date"],
+                to_date=row["to_date"],
+                record=json.loads(row["record"]) if row["record"] else {},
+            )
+        return None
+
+    def _get_location_type_name(self, type_id: int) -> str:
+        """Get location type name from ID."""
+        # Reverse lookup in cache
+        for name, id_ in self._location_type_cache.items():
+            if id_ == type_id:
+                return name
+        return "other"
+
+    def get_location_type_id(self, type_name: str) -> int:
+        """Get location_type_id for a type name."""
+        return self._location_type_cache.get(type_name, 36)  # default to "other"
+
+    def get_location_type_id_from_qid(self, wikidata_qid: int) -> int:
+        """Get location_type_id from Wikidata P31 QID."""
+        return self._location_type_qid_cache.get(wikidata_qid, 36)  # default to "other"
+
+    def get_simplified_type(self, location_type_id: int) -> str:
+        """Get simplified type name for a location_type_id."""
+        simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
+        return SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
+
+    def resolve_region_text(self, text: str) -> Optional[int]:
+        """
+        Resolve a region/country text to a location ID.
+
+        Uses pycountry for country resolution, then falls back to search.
+
+        Args:
+            text: Region text (country code, name, or QID)
+
+        Returns:
+            Location ID or None if not resolved
+        """
+        if not text:
+            return None
+
+        text_lower = text.lower().strip()
+
+        # Check cache first
+        if text_lower in self._location_cache:
+            return self._location_cache[text_lower]
+
+        # Try pycountry resolution
+        alpha_2 = self._resolve_via_pycountry(text)
+        if alpha_2:
+            alpha_2_lower = alpha_2.lower()
+            if alpha_2_lower in self._location_cache:
+                location_id = self._location_cache[alpha_2_lower]
+                self._location_cache[text_lower] = location_id  # Cache the input too
+                return location_id
+
+            # Country not in database yet, import it
+            try:
+                country = pycountry.countries.get(alpha_2=alpha_2)
+                if country:
+                    country_type_id = self._location_type_cache.get("country", 2)
+                    location_id = self.get_or_create(
+                        name=country.name,
+                        location_type_id=country_type_id,
+                        source_id=4,  # wikidata
+                        source_identifier=alpha_2,
+                    )
+                    self._location_cache[text_lower] = location_id
+                    return location_id
+            except Exception:
+                pass
+
+        return None
+
+    def _resolve_via_pycountry(self, region: str) -> Optional[str]:
+        """Try to resolve region via pycountry."""
+        region_clean = region.strip()
+        if not region_clean:
+            return None
+
+        # Try as 2-letter code
+        if len(region_clean) == 2:
+            country = pycountry.countries.get(alpha_2=region_clean.upper())
+            if country:
+                return country.alpha_2
+
+        # Try as 3-letter code
+        if len(region_clean) == 3:
+            country = pycountry.countries.get(alpha_3=region_clean.upper())
+            if country:
+                return country.alpha_2
+
+        # Try fuzzy search
+        try:
+            matches = pycountry.countries.search_fuzzy(region_clean)
+            if matches:
+                return matches[0].alpha_2
+        except LookupError:
+            pass
+
+        return None
+
+    def import_from_pycountry(self) -> int:
+        """
+        Import all countries from pycountry.
+
+        Returns:
+            Number of locations imported
+        """
+        conn = self._connect()
+        country_type_id = self._location_type_cache.get("country", 2)
+        count = 0
+
+        for country in pycountry.countries:
+            name = country.name
+            alpha_2 = country.alpha_2
+            name_normalized = name.lower()
+
+            # Check if already exists
+            if alpha_2.lower() in self._location_cache:
+                continue
+
+            cursor = conn.execute(
+                """
+                INSERT OR IGNORE INTO locations
+                (name, name_normalized, source_id, source_identifier, location_type_id)
+                VALUES (?, ?, 4, ?, ?)
+                """,
+                (name, name_normalized, alpha_2, country_type_id)
+            )
+
+            if cursor.lastrowid:
+                self._location_cache[name_normalized] = cursor.lastrowid
+                self._location_cache[alpha_2.lower()] = cursor.lastrowid
+                count += 1
+
+        conn.commit()
+        logger.info(f"Imported {count} countries from pycountry")
+        return count
+
+    def search(
+        self,
+        query: str,
+        top_k: int = 10,
+        simplified_type: Optional[str] = None,
+    ) -> list[tuple[int, str, float]]:
+        """
+        Search for locations by name.
+
+        Args:
+            query: Search query
+            top_k: Maximum results to return
+            simplified_type: Optional filter by simplified type (e.g., "country", "city")
+
+        Returns:
+            List of (location_id, location_name, score) tuples
+        """
+        conn = self._connect()
+        query_normalized = query.lower().strip()
+
+        # Build query with optional type filter
+        if simplified_type:
+            # Get all location_type_ids for this simplified type
+            simplified_id = {
+                name: id_ for id_, name in SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.items()
+            }.get(simplified_type)
+            if simplified_id:
+                type_ids = [
+                    type_id for type_id, simp_id in LOCATION_TYPE_TO_SIMPLIFIED.items()
+                    if simp_id == simplified_id
+                ]
+                if type_ids:
+                    placeholders = ",".join("?" * len(type_ids))
+                    cursor = conn.execute(
+                        f"""
+                        SELECT id, name FROM locations
+                        WHERE name_normalized LIKE ? AND location_type_id IN ({placeholders})
+                        ORDER BY length(name)
+                        LIMIT ?
+                        """,
+                        [f"%{query_normalized}%"] + type_ids + [top_k]
+                    )
+                else:
+                    return []
+            else:
+                return []
+        else:
+            cursor = conn.execute(
+                """
+                SELECT id, name FROM locations
+                WHERE name_normalized LIKE ?
+                ORDER BY length(name)
+                LIMIT ?
+                """,
+                (f"%{query_normalized}%", top_k)
+            )
+
+        results = []
+        for row in cursor:
+            name_normalized = row["name"].lower()
+            if query_normalized == name_normalized:
+                score = 1.0
+            elif name_normalized.startswith(query_normalized):
+                score = 0.9
+            else:
+                score = 0.7
+            results.append((row["id"], row["name"], score))
+
+        return results
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get statistics about the locations table."""
+        conn = self._connect()
+
+        cursor = conn.execute("SELECT COUNT(*) FROM locations")
+        total = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(*) FROM locations WHERE canon_id IS NOT NULL")
+        canonicalized = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM locations WHERE canon_id IS NOT NULL")
+        groups = cursor.fetchone()[0]
+
+        # Count by simplified type
+        by_type: dict[str, int] = {}
+        cursor = conn.execute("""
+            SELECT lt.simplified_id, COUNT(*) as cnt
+            FROM locations l
+            JOIN location_types lt ON l.location_type_id = lt.id
+            GROUP BY lt.simplified_id
+        """)
+        for row in cursor:
+            type_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(row["simplified_id"], "other")
+            by_type[type_name] = row["cnt"]
+
+        return {
+            "total_locations": total,
+            "canonicalized": canonicalized,
+            "canonical_groups": groups,
+            "by_type": by_type,
+        }
+
+    def insert_batch(self, records: list[LocationRecord]) -> int:
+        """
+        Insert a batch of location records.
+
+        Args:
+            records: List of LocationRecord objects to insert
+
+        Returns:
+            Number of records inserted
+        """
+        if not records:
+            return 0
+
+        conn = self._connect()
+        inserted = 0
+
+        for record in records:
+            name_normalized = record.name.lower().strip()
+            source_identifier = record.source_id  # Q code in source_id field
+
+            # Check cache first
+            cache_key = source_identifier.lower() if source_identifier else name_normalized
+            if cache_key in self._location_cache:
+                continue
+
+            # Get location_type_id from type name
+            location_type_id = self._location_type_cache.get(record.location_type, 36)  # default "other"
+            source_id = SOURCE_NAME_TO_ID.get(record.source, 4)  # default wikidata
+
+            parent_ids_json = json.dumps(record.parent_ids) if record.parent_ids else None
+            record_json = json.dumps(record.record) if record.record else "{}"
+
+            try:
+                cursor = conn.execute(
+                    """
+                    INSERT OR IGNORE INTO locations
+                    (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids, record, from_date, to_date)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                    """,
+                    (
+                        record.name,
+                        name_normalized,
+                        source_id,
+                        source_identifier,
+                        record.qid,
+                        location_type_id,
+                        parent_ids_json,
+                        record_json,
+                        record.from_date,
+                        record.to_date,
+                    )
+                )
+                if cursor.lastrowid:
+                    self._location_cache[name_normalized] = cursor.lastrowid
+                    if source_identifier:
+                        self._location_cache[source_identifier.lower()] = cursor.lastrowid
+                    inserted += 1
+            except Exception as e:
+                logger.warning(f"Failed to insert location {record.name}: {e}")
+
+        conn.commit()
+        return inserted
+
+    def get_all_source_ids(self, source: str = "wikidata") -> set[str]:
+        """
+        Get all source_identifiers for a given source.
+
+        Args:
+            source: Source name (e.g., "wikidata")
+
+        Returns:
+            Set of source_identifiers
+        """
+        conn = self._connect()
+        source_id = SOURCE_NAME_TO_ID.get(source, 4)
+        cursor = conn.execute(
+            "SELECT source_identifier FROM locations WHERE source_id = ? AND source_identifier IS NOT NULL",
+            (source_id,)
+        )
+        return {row["source_identifier"] for row in cursor}