corp-extractor 0.9.0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/METADATA +40 -9
  2. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/RECORD +29 -26
  3. statement_extractor/cli.py +866 -77
  4. statement_extractor/database/hub.py +35 -127
  5. statement_extractor/database/importers/__init__.py +10 -2
  6. statement_extractor/database/importers/companies_house.py +16 -2
  7. statement_extractor/database/importers/companies_house_officers.py +431 -0
  8. statement_extractor/database/importers/gleif.py +23 -0
  9. statement_extractor/database/importers/sec_edgar.py +17 -0
  10. statement_extractor/database/importers/sec_form4.py +512 -0
  11. statement_extractor/database/importers/wikidata.py +151 -43
  12. statement_extractor/database/importers/wikidata_dump.py +1951 -0
  13. statement_extractor/database/importers/wikidata_people.py +823 -325
  14. statement_extractor/database/models.py +30 -6
  15. statement_extractor/database/store.py +1485 -60
  16. statement_extractor/document/deduplicator.py +10 -12
  17. statement_extractor/extractor.py +1 -1
  18. statement_extractor/models/__init__.py +3 -2
  19. statement_extractor/models/statement.py +15 -17
  20. statement_extractor/models.py +1 -1
  21. statement_extractor/pipeline/context.py +5 -5
  22. statement_extractor/pipeline/orchestrator.py +12 -12
  23. statement_extractor/plugins/base.py +17 -17
  24. statement_extractor/plugins/extractors/gliner2.py +28 -28
  25. statement_extractor/plugins/qualifiers/embedding_company.py +7 -5
  26. statement_extractor/plugins/qualifiers/person.py +11 -1
  27. statement_extractor/plugins/splitters/t5_gemma.py +35 -39
  28. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/WHEEL +0 -0
  29. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.3.dist-info}/entry_points.txt +0 -0
@@ -15,6 +15,7 @@ from pathlib import Path
15
15
  from typing import Iterator, Optional
16
16
 
17
17
  import numpy as np
18
+ import pycountry
18
19
  import sqlite_vec
19
20
 
20
21
  from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType
@@ -24,12 +25,45 @@ logger = logging.getLogger(__name__)
24
25
  # Default database location
25
26
  DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
26
27
 
28
+ # Module-level shared connections by path (both databases share the same connection)
29
+ _shared_connections: dict[str, sqlite3.Connection] = {}
30
+
27
31
  # Module-level singleton for OrganizationDatabase to prevent multiple loads
28
32
  _database_instances: dict[str, "OrganizationDatabase"] = {}
29
33
 
30
34
  # Module-level singleton for PersonDatabase
31
35
  _person_database_instances: dict[str, "PersonDatabase"] = {}
32
36
 
37
+
38
+ def _get_shared_connection(db_path: Path, embedding_dim: int = 768) -> sqlite3.Connection:
39
+ """Get or create a shared database connection for the given path."""
40
+ path_key = str(db_path)
41
+ if path_key not in _shared_connections:
42
+ # Ensure directory exists
43
+ db_path.parent.mkdir(parents=True, exist_ok=True)
44
+
45
+ conn = sqlite3.connect(str(db_path))
46
+ conn.row_factory = sqlite3.Row
47
+
48
+ # Load sqlite-vec extension
49
+ conn.enable_load_extension(True)
50
+ sqlite_vec.load(conn)
51
+ conn.enable_load_extension(False)
52
+
53
+ _shared_connections[path_key] = conn
54
+ logger.debug(f"Created shared database connection for {path_key}")
55
+
56
+ return _shared_connections[path_key]
57
+
58
+
59
+ def close_shared_connection(db_path: Optional[Path] = None) -> None:
60
+ """Close a shared database connection."""
61
+ path_key = str(db_path or DEFAULT_DB_PATH)
62
+ if path_key in _shared_connections:
63
+ _shared_connections[path_key].close()
64
+ del _shared_connections[path_key]
65
+ logger.debug(f"Closed shared database connection for {path_key}")
66
+
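The database classes further down reuse these helpers so that OrganizationDatabase and PersonDatabase instances opened against the same file share one SQLite connection. A minimal sketch of the intended behaviour, using only the helpers added in this hunk:

    # Both lookups return the same sqlite3.Connection object for the same path,
    # so the organizations and people tables live on one shared connection.
    conn_a = _get_shared_connection(DEFAULT_DB_PATH)
    conn_b = _get_shared_connection(DEFAULT_DB_PATH)
    assert conn_a is conn_b

    # close_shared_connection() drops the cached connection for that path;
    # the per-instance close() methods below only clear their local reference.
    close_shared_connection(DEFAULT_DB_PATH)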
33
67
  # Comprehensive set of corporate legal suffixes (international)
34
68
  COMPANY_SUFFIXES: set[str] = {
35
69
  'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
@@ -43,6 +77,222 @@ COMPANY_SUFFIXES: set[str] = {
43
77
  'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
44
78
  }
45
79
 
80
+ # Source priority for organization canonicalization (lower = higher priority)
81
+ SOURCE_PRIORITY: dict[str, int] = {
82
+ "gleif": 1, # Gold standard LEI - globally unique legal entity identifier
83
+ "sec_edgar": 2, # Vetted US filers with CIK + ticker
84
+ "companies_house": 3, # Official UK registry
85
+ "wikipedia": 4, # Crowdsourced, less authoritative
86
+ }
87
+
88
+ # Source priority for people canonicalization (lower = higher priority)
89
+ PERSON_SOURCE_PRIORITY: dict[str, int] = {
90
+ "wikidata": 1, # Curated, has rich biographical data and Q codes
91
+ "sec_edgar": 2, # Vetted US filers (Form 4 officers/directors)
92
+ "companies_house": 3, # UK company officers
93
+ }
94
+
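These priority maps drive canonical-record selection in canonicalize() further down: within a group of equivalent records, the record whose source has the lowest priority value wins. A small illustration, where the sources dict stands in for the id-to-source mapping built during canonicalization:

    # Hypothetical group of duplicate organization ids and their sources.
    sources = {101: "wikipedia", 202: "gleif", 303: "sec_edgar"}
    group_ids = [101, 202, 303]

    # Same selection rule as canonicalize(): lowest priority value, ties broken by id.
    best_id = min(group_ids, key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid))
    # best_id == 202, because "gleif" has priority 1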
95
+ # Suffix expansions for canonical name matching
96
+ SUFFIX_EXPANSIONS: dict[str, str] = {
97
+ " ltd": " limited",
98
+ " corp": " corporation",
99
+ " inc": " incorporated",
100
+ " co": " company",
101
+ " intl": " international",
102
+ " natl": " national",
103
+ }
104
+
105
+
106
+ class UnionFind:
107
+ """Simple Union-Find (Disjoint Set Union) data structure for canonicalization."""
108
+
109
+ def __init__(self, elements: list[int]):
110
+ """Initialize with list of element IDs."""
111
+ self.parent: dict[int, int] = {e: e for e in elements}
112
+ self.rank: dict[int, int] = {e: 0 for e in elements}
113
+
114
+ def find(self, x: int) -> int:
115
+ """Find with path compression."""
116
+ if self.parent[x] != x:
117
+ self.parent[x] = self.find(self.parent[x])
118
+ return self.parent[x]
119
+
120
+ def union(self, x: int, y: int) -> None:
121
+ """Union by rank."""
122
+ px, py = self.find(x), self.find(y)
123
+ if px == py:
124
+ return
125
+ if self.rank[px] < self.rank[py]:
126
+ px, py = py, px
127
+ self.parent[py] = px
128
+ if self.rank[px] == self.rank[py]:
129
+ self.rank[px] += 1
130
+
131
+ def groups(self) -> dict[int, list[int]]:
132
+ """Return dict of root -> list of members."""
133
+ result: dict[int, list[int]] = {}
134
+ for e in self.parent:
135
+ root = self.find(e)
136
+ result.setdefault(root, []).append(e)
137
+ return result
138
+
139
+
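A quick sanity check of how the structure is used during canonicalization: union pairs of equivalent records, then read back the connected groups.

    uf = UnionFind([1, 2, 3, 4, 5])
    uf.union(1, 2)        # e.g. same LEI
    uf.union(3, 4)        # e.g. same normalized name + region
    uf.union(2, 3)        # a shared ticker bridges the two pairs
    groups = uf.groups()  # {root: [1, 2, 3, 4], other_root: [5]}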
140
+ # Common region aliases not handled well by pycountry fuzzy search
141
+ REGION_ALIASES: dict[str, str] = {
142
+ "uk": "GB",
143
+ "u.k.": "GB",
144
+ "england": "GB",
145
+ "scotland": "GB",
146
+ "wales": "GB",
147
+ "northern ireland": "GB",
148
+ "usa": "US",
149
+ "u.s.a.": "US",
150
+ "u.s.": "US",
151
+ "united states of america": "US",
152
+ "america": "US",
153
+ }
154
+
155
+ # Cache for region normalization lookups
156
+ _region_cache: dict[str, str] = {}
157
+
158
+
159
+ def _normalize_region(region: str) -> str:
160
+ """
161
+ Normalize a region string to ISO 3166-1 alpha-2 country code.
162
+
163
+ Handles:
164
+ - Country codes (2-letter, 3-letter)
165
+ - Country names (with fuzzy matching)
166
+ - US state codes (CA, NY) -> US
167
+ - US state names (California, New York) -> US
168
+ - Common aliases (UK, USA, England) -> proper codes
169
+
170
+ Returns empty string if region cannot be normalized.
171
+ """
172
+ if not region:
173
+ return ""
174
+
175
+ # Check cache first
176
+ cache_key = region.lower().strip()
177
+ if cache_key in _region_cache:
178
+ return _region_cache[cache_key]
179
+
180
+ result = _normalize_region_uncached(region)
181
+ _region_cache[cache_key] = result
182
+ return result
183
+
184
+
185
+ def _normalize_region_uncached(region: str) -> str:
186
+ """Uncached region normalization logic."""
187
+ region_clean = region.strip()
188
+
189
+ # Empty after stripping = empty result
190
+ if not region_clean:
191
+ return ""
192
+
193
+ region_lower = region_clean.lower()
194
+ region_upper = region_clean.upper()
195
+
196
+ # Check common aliases first
197
+ if region_lower in REGION_ALIASES:
198
+ return REGION_ALIASES[region_lower]
199
+
200
+ # For 2-letter codes, check country first, then US state
201
+ # This means ambiguous codes like "CA" (Canada vs California) prefer country
202
+ # But unambiguous codes like "NY" (not a country) will match as US state
203
+ if len(region_clean) == 2:
204
+ # Try as country alpha-2 first
205
+ country = pycountry.countries.get(alpha_2=region_upper)
206
+ if country:
207
+ return country.alpha_2
208
+
209
+ # If not a country, try as US state code
210
+ subdivision = pycountry.subdivisions.get(code=f"US-{region_upper}")
211
+ if subdivision:
212
+ return "US"
213
+
214
+ # Try alpha-3 lookup
215
+ if len(region_clean) == 3:
216
+ country = pycountry.countries.get(alpha_3=region_upper)
217
+ if country:
218
+ return country.alpha_2
219
+
220
+ # Try as US state name (e.g., "California", "New York")
221
+ try:
222
+ subdivisions = list(pycountry.subdivisions.search_fuzzy(region_clean))
223
+ if subdivisions:
224
+ # Check if it's a US state
225
+ if subdivisions[0].code.startswith("US-"):
226
+ return "US"
227
+ # Return the parent country code
228
+ return subdivisions[0].country_code
229
+ except LookupError:
230
+ pass
231
+
232
+ # Try country fuzzy search
233
+ try:
234
+ countries = pycountry.countries.search_fuzzy(region_clean)
235
+ if countries:
236
+ return countries[0].alpha_2
237
+ except LookupError:
238
+ pass
239
+
240
+ # Return empty if we can't normalize
241
+ return ""
242
+
243
+
244
+ def _regions_match(region1: str, region2: str) -> bool:
245
+ """
246
+ Check if two regions match after normalization.
247
+
248
+ Empty regions match anything (lenient matching for incomplete data).
249
+ """
250
+ norm1 = _normalize_region(region1)
251
+ norm2 = _normalize_region(region2)
252
+
253
+ # Empty regions match anything
254
+ if not norm1 or not norm2:
255
+ return True
256
+
257
+ return norm1 == norm2
258
+
259
+
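Illustrative inputs and outputs, based on the alias table and the docstring above; the exact fuzzy-match behaviour ultimately depends on the installed pycountry data:

    _normalize_region("UK")          # -> "GB"  (alias table)
    _normalize_region("USA")         # -> "US"  (alias table)
    _normalize_region("GBR")         # -> "GB"  (ISO alpha-3 lookup)
    _normalize_region("California")  # -> "US"  (US state name via subdivisions)
    _normalize_region("CA")          # -> "CA"  (ambiguous 2-letter codes prefer the country: Canada)
    _regions_match("UK", "england")  # -> True  (both normalize to "GB")
    _regions_match("", "DE")         # -> True  (empty regions match anything)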
260
+ def _normalize_for_canon(name: str) -> str:
261
+ """Normalize name for canonical matching (simpler than search normalization)."""
262
+ # Lowercase
263
+ result = name.lower()
264
+ # Remove trailing dots
265
+ result = result.rstrip(".")
266
+ # Remove extra whitespace
267
+ result = " ".join(result.split())
268
+ return result
269
+
270
+
271
+ def _expand_suffix(name: str) -> str:
272
+ """Expand known suffix abbreviations."""
273
+ result = name.lower().rstrip(".")
274
+ for abbrev, full in SUFFIX_EXPANSIONS.items():
275
+ if result.endswith(abbrev):
276
+ result = result[:-len(abbrev)] + full
277
+ break # Only expand one suffix
278
+ return result
279
+
280
+
281
+ def _names_match_for_canon(name1: str, name2: str) -> bool:
282
+ """Check if two names match for canonicalization."""
283
+ n1 = _normalize_for_canon(name1)
284
+ n2 = _normalize_for_canon(name2)
285
+
286
+ # Exact match after normalization
287
+ if n1 == n2:
288
+ return True
289
+
290
+ # Try with suffix expansion
291
+ if _expand_suffix(n1) == _expand_suffix(n2):
292
+ return True
293
+
294
+ return False
295
+
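A concrete example of the matching rules, using the same "Amazon Ltd" / "Amazon Limited" pair the canonicalization comments refer to further down:

    _names_match_for_canon("Amazon Ltd.", "Amazon Limited")      # True: " ltd" expands to " limited"
    _names_match_for_canon("ACME CORP", "Acme Corporation")      # True: lowercased, " corp" expanded
    _names_match_for_canon("Acme Corporation", "Acme Holdings")  # False: no rule bridges these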
46
296
  # Pre-compile the suffix pattern for performance
47
297
  _SUFFIX_PATTERN = re.compile(
48
298
  r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
@@ -256,20 +506,13 @@ class OrganizationDatabase:
256
506
  self._db_path.parent.mkdir(parents=True, exist_ok=True)
257
507
 
258
508
  def _connect(self) -> sqlite3.Connection:
259
- """Get or create database connection with sqlite-vec loaded."""
509
+ """Get or create database connection using shared connection pool."""
260
510
  if self._conn is not None:
261
511
  return self._conn
262
512
 
263
- self._ensure_dir()
264
- self._conn = sqlite3.connect(str(self._db_path))
265
- self._conn.row_factory = sqlite3.Row
266
-
267
- # Load sqlite-vec extension
268
- self._conn.enable_load_extension(True)
269
- sqlite_vec.load(self._conn)
270
- self._conn.enable_load_extension(False)
513
+ self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
271
514
 
272
- # Create tables
515
+ # Create tables (idempotent)
273
516
  self._create_tables()
274
517
 
275
518
  return self._conn
@@ -289,6 +532,8 @@ class OrganizationDatabase:
289
532
  source_id TEXT NOT NULL,
290
533
  region TEXT NOT NULL DEFAULT '',
291
534
  entity_type TEXT NOT NULL DEFAULT 'unknown',
535
+ from_date TEXT NOT NULL DEFAULT '',
536
+ to_date TEXT NOT NULL DEFAULT '',
292
537
  record TEXT NOT NULL,
293
538
  UNIQUE(source, source_id)
294
539
  )
@@ -308,6 +553,34 @@ class OrganizationDatabase:
308
553
  except sqlite3.OperationalError:
309
554
  pass # Column already exists
310
555
 
556
+ # Add from_date column if it doesn't exist (migration for existing DBs)
557
+ try:
558
+ conn.execute("ALTER TABLE organizations ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
559
+ logger.info("Added from_date column to organizations table")
560
+ except sqlite3.OperationalError:
561
+ pass # Column already exists
562
+
563
+ # Add to_date column if it doesn't exist (migration for existing DBs)
564
+ try:
565
+ conn.execute("ALTER TABLE organizations ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
566
+ logger.info("Added to_date column to organizations table")
567
+ except sqlite3.OperationalError:
568
+ pass # Column already exists
569
+
570
+ # Add canon_id column if it doesn't exist (migration for canonicalization)
571
+ try:
572
+ conn.execute("ALTER TABLE organizations ADD COLUMN canon_id INTEGER DEFAULT NULL")
573
+ logger.info("Added canon_id column to organizations table")
574
+ except sqlite3.OperationalError:
575
+ pass # Column already exists
576
+
577
+ # Add canon_size column if it doesn't exist (migration for canonicalization)
578
+ try:
579
+ conn.execute("ALTER TABLE organizations ADD COLUMN canon_size INTEGER DEFAULT 1")
580
+ logger.info("Added canon_size column to organizations table")
581
+ except sqlite3.OperationalError:
582
+ pass # Column already exists
583
+
311
584
  # Create indexes on main table
312
585
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
313
586
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
@@ -316,6 +589,7 @@ class OrganizationDatabase:
316
589
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
317
590
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
318
591
  conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
592
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")
319
593
 
320
594
  # Create sqlite-vec virtual table for embeddings
321
595
  # vec0 is the recommended virtual table type
@@ -329,10 +603,8 @@ class OrganizationDatabase:
329
603
  conn.commit()
330
604
 
331
605
  def close(self) -> None:
332
- """Close database connection."""
333
- if self._conn:
334
- self._conn.close()
335
- self._conn = None
606
+ """Clear connection reference (shared connection remains open)."""
607
+ self._conn = None
336
608
 
337
609
  def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
338
610
  """
@@ -353,8 +625,8 @@ class OrganizationDatabase:
353
625
 
354
626
  cursor = conn.execute("""
355
627
  INSERT OR REPLACE INTO organizations
356
- (name, name_normalized, source, source_id, region, entity_type, record)
357
- VALUES (?, ?, ?, ?, ?, ?, ?)
628
+ (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
629
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
358
630
  """, (
359
631
  record.name,
360
632
  name_normalized,
@@ -362,6 +634,8 @@ class OrganizationDatabase:
362
634
  record.source_id,
363
635
  record.region,
364
636
  record.entity_type.value,
637
+ record.from_date or "",
638
+ record.to_date or "",
365
639
  record_json,
366
640
  ))
367
641
 
@@ -369,10 +643,11 @@ class OrganizationDatabase:
369
643
  assert row_id is not None
370
644
 
371
645
  # Insert embedding into vec table
372
- # sqlite-vec expects the embedding as a blob
646
+ # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
373
647
  embedding_blob = embedding.astype(np.float32).tobytes()
648
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
374
649
  conn.execute("""
375
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
650
+ INSERT INTO organization_embeddings (org_id, embedding)
376
651
  VALUES (?, ?)
377
652
  """, (row_id, embedding_blob))
378
653
 
@@ -405,8 +680,8 @@ class OrganizationDatabase:
405
680
 
406
681
  cursor = conn.execute("""
407
682
  INSERT OR REPLACE INTO organizations
408
- (name, name_normalized, source, source_id, region, entity_type, record)
409
- VALUES (?, ?, ?, ?, ?, ?, ?)
683
+ (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
684
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
410
685
  """, (
411
686
  record.name,
412
687
  name_normalized,
@@ -414,16 +689,19 @@ class OrganizationDatabase:
414
689
  record.source_id,
415
690
  record.region,
416
691
  record.entity_type.value,
692
+ record.from_date or "",
693
+ record.to_date or "",
417
694
  record_json,
418
695
  ))
419
696
 
420
697
  row_id = cursor.lastrowid
421
698
  assert row_id is not None
422
699
 
423
- # Insert embedding
700
+ # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
424
701
  embedding_blob = embedding.astype(np.float32).tobytes()
702
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
425
703
  conn.execute("""
426
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
704
+ INSERT INTO organization_embeddings (org_id, embedding)
427
705
  VALUES (?, ?)
428
706
  """, (row_id, embedding_blob))
429
707
 
@@ -443,13 +721,15 @@ class OrganizationDatabase:
443
721
  source_filter: Optional[str] = None,
444
722
  query_text: Optional[str] = None,
445
723
  max_text_candidates: int = 5000,
724
+ rerank_min_candidates: int = 500,
446
725
  ) -> list[tuple[CompanyRecord, float]]:
447
726
  """
448
727
  Search for similar organizations using hybrid text + vector search.
449
728
 
450
- Two-stage approach:
729
+ Three-stage approach:
451
730
  1. If query_text provided, use SQL LIKE to find candidates containing search terms
452
731
  2. Use sqlite-vec for vector similarity ranking on filtered candidates
732
+ 3. Apply prominence-based re-ranking to boost major companies (SEC filers, tickers)
453
733
 
454
734
  Args:
455
735
  query_embedding: Query embedding vector
@@ -457,9 +737,10 @@ class OrganizationDatabase:
457
737
  source_filter: Optional filter by source (gleif, sec_edgar, etc.)
458
738
  query_text: Optional query text for text-based pre-filtering
459
739
  max_text_candidates: Max candidates to keep after text filtering
740
+ rerank_min_candidates: Minimum candidates to fetch for re-ranking (default 500)
460
741
 
461
742
  Returns:
462
- List of (CompanyRecord, similarity_score) tuples
743
+ List of (CompanyRecord, adjusted_score) tuples sorted by prominence-adjusted score
463
744
  """
464
745
  start = time.time()
465
746
  self._connect()
@@ -473,6 +754,7 @@ class OrganizationDatabase:
473
754
 
474
755
  # Stage 1: Text-based pre-filtering (if query_text provided)
475
756
  candidate_ids: Optional[set[int]] = None
757
+ query_normalized_text = ""
476
758
  if query_text:
477
759
  query_normalized_text = _normalize_name(query_text)
478
760
  if query_normalized_text:
@@ -483,24 +765,168 @@ class OrganizationDatabase:
483
765
  )
484
766
  logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
485
767
 
486
- # Stage 2: Vector search
768
+ # Stage 2: Vector search - fetch more candidates for re-ranking
487
769
  if candidate_ids is not None and len(candidate_ids) == 0:
488
770
  # No text matches, return empty
489
771
  return []
490
772
 
773
+ # Fetch enough candidates for prominence re-ranking to be effective
774
+ # Use at least rerank_min_candidates, or all text-filtered candidates if fewer
775
+ if candidate_ids is not None:
776
+ fetch_k = min(len(candidate_ids), max(rerank_min_candidates, top_k * 5))
777
+ else:
778
+ fetch_k = max(rerank_min_candidates, top_k * 5)
779
+
491
780
  if candidate_ids is not None:
492
781
  # Search within text-filtered candidates
493
782
  results = self._vector_search_filtered(
494
- query_blob, candidate_ids, top_k, source_filter
783
+ query_blob, candidate_ids, fetch_k, source_filter
495
784
  )
496
785
  else:
497
786
  # Full vector search
498
- results = self._vector_search_full(query_blob, top_k, source_filter)
787
+ results = self._vector_search_full(query_blob, fetch_k, source_filter)
788
+
789
+ # Stage 3: Prominence-based re-ranking
790
+ if results and query_normalized_text:
791
+ results = self._apply_prominence_reranking(results, query_normalized_text, top_k)
792
+ else:
793
+ # No re-ranking, just trim to top_k
794
+ results = results[:top_k]
499
795
 
500
796
  elapsed = time.time() - start
501
797
  logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
502
798
  return results
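A hedged usage sketch of the hybrid search; the embedding model is produced elsewhere in the package, so `encode` and the `db` handle below are stand-ins for whatever supplies the query vector and the OrganizationDatabase instance:

    query = "Acme Corporation"
    query_embedding = encode(query)      # hypothetical encoder returning a float32 vector
    matches = db.search(
        query_embedding,
        top_k=5,
        query_text=query,                # enables the LIKE pre-filter and re-ranking
    )
    for record, adjusted_score in matches:
        print(record.name, record.source, round(adjusted_score, 3))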
503
799
 
800
+ def _calculate_prominence_boost(
801
+ self,
802
+ record: CompanyRecord,
803
+ query_normalized: str,
804
+ canon_sources: Optional[set[str]] = None,
805
+ ) -> float:
806
+ """
807
+ Calculate prominence boost for re-ranking search results.
808
+
809
+ Boosts scores based on signals that indicate a major/prominent company:
810
+ - Has ticker symbol (publicly traded)
811
+ - GLEIF source (has LEI)
812
+ - SEC source (vetted US filers)
813
+ - Wikidata source (Wikipedia-notable)
814
+ - Exact normalized name match
815
+
816
+ When canon_sources is provided (from a canonical group), boosts are
817
+ applied for ALL sources in the canon group, not just this record's source.
818
+
819
+ Args:
820
+ record: The company record to evaluate
821
+ query_normalized: Normalized query text for exact match check
822
+ canon_sources: Optional set of sources in this record's canonical group
823
+
824
+ Returns:
825
+ Boost value to add to embedding similarity (0.0 to ~0.21)
826
+ """
827
+ boost = 0.0
828
+
829
+ # Get all sources to consider (canon group or just this record)
830
+ sources_to_check = canon_sources or {record.source}
831
+
832
+ # Has ticker symbol = publicly traded major company
833
+ # Check if ANY record in canon group has ticker
834
+ if record.record.get("ticker") or (canon_sources and "sec_edgar" in canon_sources):
835
+ boost += 0.08
836
+
837
+ # Source-based boosts - accumulate for all sources in canon group
838
+ if "gleif" in sources_to_check:
839
+ boost += 0.05 # Has LEI = verified legal entity
840
+ if "sec_edgar" in sources_to_check:
841
+ boost += 0.03 # SEC filer
842
+ if "wikipedia" in sources_to_check:
843
+ boost += 0.02 # Wikipedia notable
844
+
845
+ # Exact normalized name match bonus
846
+ record_normalized = _normalize_name(record.name)
847
+ if query_normalized == record_normalized:
848
+ boost += 0.05
849
+
850
+ return boost
851
+
852
+ def _apply_prominence_reranking(
853
+ self,
854
+ results: list[tuple[CompanyRecord, float]],
855
+ query_normalized: str,
856
+ top_k: int,
857
+ similarity_weight: float = 0.3,
858
+ ) -> list[tuple[CompanyRecord, float]]:
859
+ """
860
+ Apply prominence-based re-ranking to search results with canon group awareness.
861
+
862
+ When records have been canonicalized, boosts are applied based on ALL sources
863
+ in the canonical group, not just the matched record's source.
864
+
865
+ Args:
866
+ results: List of (record, similarity) from vector search
867
+ query_normalized: Normalized query text
868
+ top_k: Number of results to return after re-ranking
869
+ similarity_weight: Weight for similarity score (0-1), lower = prominence matters more
870
+
871
+ Returns:
872
+ Re-ranked list of (record, adjusted_score) tuples
873
+ """
874
+ conn = self._conn
875
+ assert conn is not None
876
+
877
+ # Build canon_id -> sources mapping for all results that have canon_id
878
+ canon_sources_map: dict[int, set[str]] = {}
879
+ canon_ids = [
880
+ r.record.get("canon_id")
881
+ for r, _ in results
882
+ if r.record.get("canon_id") is not None
883
+ ]
884
+
885
+ if canon_ids:
886
+ # Fetch all sources for each canon_id in one query
887
+ unique_canon_ids = list(set(canon_ids))
888
+ placeholders = ",".join("?" * len(unique_canon_ids))
889
+ rows = conn.execute(f"""
890
+ SELECT canon_id, source
891
+ FROM organizations
892
+ WHERE canon_id IN ({placeholders})
893
+ """, unique_canon_ids).fetchall()
894
+
895
+ for row in rows:
896
+ canon_id = row["canon_id"]
897
+ canon_sources_map.setdefault(canon_id, set()).add(row["source"])
898
+
899
+ # Calculate boosted scores with canon group awareness
900
+ # Formula: adjusted = (similarity * weight) + boost
901
+ # With weight=0.3, a sim=0.65 SEC+ticker (boost=0.11) beats sim=0.75 no-boost
902
+ boosted_results: list[tuple[CompanyRecord, float, float, float]] = []
903
+ for record, similarity in results:
904
+ canon_id = record.record.get("canon_id")
905
+ # Get all sources in this record's canon group (if any)
906
+ canon_sources = canon_sources_map.get(canon_id) if canon_id else None
907
+
908
+ boost = self._calculate_prominence_boost(record, query_normalized, canon_sources)
909
+ adjusted_score = (similarity * similarity_weight) + boost
910
+ boosted_results.append((record, similarity, boost, adjusted_score))
911
+
912
+ # Sort by adjusted score (descending)
913
+ boosted_results.sort(key=lambda x: x[3], reverse=True)
914
+
915
+ # Log re-ranking details for top results
916
+ logger.debug(f"Prominence re-ranking for '{query_normalized}':")
917
+ for record, sim, boost, adj in boosted_results[:10]:
918
+ ticker = record.record.get("ticker", "")
919
+ ticker_str = f" ticker={ticker}" if ticker else ""
920
+ canon_id = record.record.get("canon_id")
921
+ canon_str = f" canon={canon_id}" if canon_id else ""
922
+ logger.debug(
923
+ f" {record.name}: sim={sim:.3f} + boost={boost:.3f} = {adj:.3f} "
924
+ f"[{record.source}{ticker_str}{canon_str}]"
925
+ )
926
+
927
+ # Return top_k with adjusted scores
928
+ return [(r, adj) for r, _, _, adj in boosted_results[:top_k]]
929
+
504
930
  def _text_filter_candidates(
505
931
  self,
506
932
  query_normalized: str,
@@ -651,24 +1077,28 @@ class OrganizationDatabase:
651
1077
  return results
652
1078
 
653
1079
  def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
654
- """Get an organization record by ID."""
1080
+ """Get an organization record by ID, including db_id and canon_id in record dict."""
655
1081
  conn = self._conn
656
1082
  assert conn is not None
657
1083
 
658
1084
  cursor = conn.execute("""
659
- SELECT name, source, source_id, region, entity_type, record
1085
+ SELECT id, name, source, source_id, region, entity_type, record, canon_id
660
1086
  FROM organizations WHERE id = ?
661
1087
  """, (org_id,))
662
1088
 
663
1089
  row = cursor.fetchone()
664
1090
  if row:
1091
+ record_data = json.loads(row["record"])
1092
+ # Add db_id and canon_id to record dict for canon-aware search
1093
+ record_data["db_id"] = row["id"]
1094
+ record_data["canon_id"] = row["canon_id"]
665
1095
  return CompanyRecord(
666
1096
  name=row["name"],
667
1097
  source=row["source"],
668
1098
  source_id=row["source_id"],
669
1099
  region=row["region"] or "",
670
1100
  entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
671
- record=json.loads(row["record"]),
1101
+ record=record_data,
672
1102
  )
673
1103
  return None
674
1104
 
@@ -694,6 +1124,20 @@ class OrganizationDatabase:
694
1124
  )
695
1125
  return None
696
1126
 
1127
+ def get_id_by_source_id(self, source: str, source_id: str) -> Optional[int]:
1128
+ """Get the internal database ID for an organization by source and source_id."""
1129
+ conn = self._connect()
1130
+
1131
+ cursor = conn.execute("""
1132
+ SELECT id FROM organizations
1133
+ WHERE source = ? AND source_id = ?
1134
+ """, (source, source_id))
1135
+
1136
+ row = cursor.fetchone()
1137
+ if row:
1138
+ return row["id"]
1139
+ return None
1140
+
697
1141
  def get_stats(self) -> DatabaseStats:
698
1142
  """Get database statistics."""
699
1143
  conn = self._connect()
@@ -716,6 +1160,30 @@ class OrganizationDatabase:
716
1160
  database_size_bytes=db_size,
717
1161
  )
718
1162
 
1163
+ def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
1164
+ """
1165
+ Get all source_ids from the organizations table.
1166
+
1167
+ Useful for resume operations to skip already-imported records.
1168
+
1169
+ Args:
1170
+ source: Optional source filter (e.g., "wikipedia" for Wikidata orgs)
1171
+
1172
+ Returns:
1173
+ Set of source_id strings (e.g., Q codes for Wikidata)
1174
+ """
1175
+ conn = self._connect()
1176
+
1177
+ if source:
1178
+ cursor = conn.execute(
1179
+ "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
1180
+ (source,)
1181
+ )
1182
+ else:
1183
+ cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")
1184
+
1185
+ return {row[0] for row in cursor}
1186
+
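A typical resume pattern for the importers, sketched with a hypothetical iterable of parsed Wikidata dump entries:

    seen = db.get_all_source_ids(source="wikipedia")   # Q codes already imported
    new_entries = [entry for entry in dump_entries     # dump_entries: hypothetical importer output
                   if entry["qid"] not in seen]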
719
1187
  def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
720
1188
  """Iterate over all records, optionally filtered by source."""
721
1189
  conn = self._connect()
@@ -742,6 +1210,245 @@ class OrganizationDatabase:
742
1210
  record=json.loads(row["record"]),
743
1211
  )
744
1212
 
1213
+ def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
1214
+ """
1215
+ Canonicalize all organizations by linking equivalent records.
1216
+
1217
+ Records are considered equivalent if they match by:
1218
+ 1. Same LEI (GLEIF source_id or Wikidata P1278) - globally unique, no region check
1219
+ 2. Same ticker symbol - globally unique, no region check
1220
+ 3. Same CIK - globally unique, no region check
1221
+ 4. Same normalized name AND same normalized region
1222
+ 5. Name match with suffix expansion AND same region
1223
+
1224
+ Region normalization uses pycountry to handle:
1225
+ - Country codes/names (GB, United Kingdom, Great Britain -> GB)
1226
+ - US state codes/names (CA, California -> US)
1227
+ - Common aliases (UK -> GB, USA -> US)
1228
+
1229
+ For each group of equivalent records, the highest-priority source
1230
+ (gleif > sec_edgar > companies_house > wikipedia) becomes canonical.
1231
+
1232
+ Args:
1233
+ batch_size: Commit batch size for updates
1234
+
1235
+ Returns:
1236
+ Dict with stats: total_records, groups_found, records_updated
1237
+ """
1238
+ conn = self._connect()
1239
+ logger.info("Starting canonicalization...")
1240
+
1241
+ # Phase 1: Load all organization data and build indexes
1242
+ logger.info("Phase 1: Building indexes...")
1243
+
1244
+ lei_index: dict[str, list[int]] = {}
1245
+ ticker_index: dict[str, list[int]] = {}
1246
+ cik_index: dict[str, list[int]] = {}
1247
+ # Name indexes now keyed by (normalized_name, normalized_region)
1248
+ # Region-less matching only applies for identifier-based matching
1249
+ name_region_index: dict[tuple[str, str], list[int]] = {}
1250
+ expanded_name_region_index: dict[tuple[str, str], list[int]] = {}
1251
+
1252
+ sources: dict[int, str] = {} # org_id -> source
1253
+ all_org_ids: list[int] = []
1254
+
1255
+ cursor = conn.execute("""
1256
+ SELECT id, source, source_id, name, region, record
1257
+ FROM organizations
1258
+ """)
1259
+
1260
+ count = 0
1261
+ for row in cursor:
1262
+ org_id = row["id"]
1263
+ source = row["source"]
1264
+ name = row["name"]
1265
+ region = row["region"] or ""
1266
+ record = json.loads(row["record"])
1267
+
1268
+ all_org_ids.append(org_id)
1269
+ sources[org_id] = source
1270
+
1271
+ # Index by LEI (GLEIF source_id or Wikidata's P1278)
1272
+ # LEI is globally unique - no region check needed
1273
+ if source == "gleif":
1274
+ lei = row["source_id"]
1275
+ else:
1276
+ lei = record.get("lei")
1277
+ if lei:
1278
+ lei_index.setdefault(lei.upper(), []).append(org_id)
1279
+
1280
+ # Index by ticker - globally unique, no region check
1281
+ ticker = record.get("ticker")
1282
+ if ticker:
1283
+ ticker_index.setdefault(ticker.upper(), []).append(org_id)
1284
+
1285
+ # Index by CIK - globally unique, no region check
1286
+ if source == "sec_edgar":
1287
+ cik = row["source_id"]
1288
+ else:
1289
+ cik = record.get("cik")
1290
+ if cik:
1291
+ cik_index.setdefault(str(cik), []).append(org_id)
1292
+
1293
+ # Index by (normalized_name, normalized_region)
1294
+ # Same name in different regions = different legal entities
1295
+ norm_name = _normalize_for_canon(name)
1296
+ norm_region = _normalize_region(region)
1297
+ if norm_name:
1298
+ key = (norm_name, norm_region)
1299
+ name_region_index.setdefault(key, []).append(org_id)
1300
+
1301
+ # Index by (expanded_name, normalized_region)
1302
+ expanded_name = _expand_suffix(name)
1303
+ if expanded_name and expanded_name != norm_name:
1304
+ key = (expanded_name, norm_region)
1305
+ expanded_name_region_index.setdefault(key, []).append(org_id)
1306
+
1307
+ count += 1
1308
+ if count % 100000 == 0:
1309
+ logger.info(f" Indexed {count} organizations...")
1310
+
1311
+ logger.info(f" Indexed {count} organizations total")
1312
+ logger.info(f" LEI index: {len(lei_index)} unique LEIs")
1313
+ logger.info(f" Ticker index: {len(ticker_index)} unique tickers")
1314
+ logger.info(f" CIK index: {len(cik_index)} unique CIKs")
1315
+ logger.info(f" Name+region index: {len(name_region_index)} unique (name, region) pairs")
1316
+ logger.info(f" Expanded name+region index: {len(expanded_name_region_index)} unique pairs")
1317
+
1318
+ # Phase 2: Build equivalence groups using Union-Find
1319
+ logger.info("Phase 2: Building equivalence groups...")
1320
+
1321
+ uf = UnionFind(all_org_ids)
1322
+
1323
+ # Merge by LEI (globally unique identifier)
1324
+ for _lei, ids in lei_index.items():
1325
+ for i in range(1, len(ids)):
1326
+ uf.union(ids[0], ids[i])
1327
+
1328
+ # Merge by ticker (globally unique identifier)
1329
+ for _ticker, ids in ticker_index.items():
1330
+ for i in range(1, len(ids)):
1331
+ uf.union(ids[0], ids[i])
1332
+
1333
+ # Merge by CIK (globally unique identifier)
1334
+ for _cik, ids in cik_index.items():
1335
+ for i in range(1, len(ids)):
1336
+ uf.union(ids[0], ids[i])
1337
+
1338
+ # Merge by (normalized_name, normalized_region)
1339
+ for _name_region, ids in name_region_index.items():
1340
+ for i in range(1, len(ids)):
1341
+ uf.union(ids[0], ids[i])
1342
+
1343
+ # Merge by (expanded_name, normalized_region)
1344
+ # This connects "Amazon Ltd" with "Amazon Limited" in same region
1345
+ for key, expanded_ids in expanded_name_region_index.items():
1346
+ # Find org_ids with the expanded form as their normalized name in same region
1347
+ if key in name_region_index:
1348
+ # Link first expanded_id to first name_id
1349
+ uf.union(expanded_ids[0], name_region_index[key][0])
1350
+
1351
+ groups = uf.groups()
1352
+ logger.info(f" Found {len(groups)} equivalence groups")
1353
+
1354
+ # Count groups with multiple records
1355
+ multi_record_groups = sum(1 for ids in groups.values() if len(ids) > 1)
1356
+ logger.info(f" Groups with multiple records: {multi_record_groups}")
1357
+
1358
+ # Phase 3: Select canonical record for each group and update database
1359
+ logger.info("Phase 3: Updating database...")
1360
+
1361
+ updated_count = 0
1362
+ batch_updates: list[tuple[int, int, int]] = [] # (org_id, canon_id, canon_size)
1363
+
1364
+ for _root, group_ids in groups.items():
1365
+ if len(group_ids) == 1:
1366
+ # Single record - canonical to itself
1367
+ batch_updates.append((group_ids[0], group_ids[0], 1))
1368
+ else:
1369
+ # Multiple records - find highest priority source
1370
+ best_id = min(
1371
+ group_ids,
1372
+ key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid)
1373
+ )
1374
+ group_size = len(group_ids)
1375
+
1376
+ # All records in group point to the best one
1377
+ for oid in group_ids:
1378
+ # canon_size is only set on the canonical record
1379
+ size = group_size if oid == best_id else 1
1380
+ batch_updates.append((oid, best_id, size))
1381
+
1382
+ # Commit batch
1383
+ if len(batch_updates) >= batch_size:
1384
+ self._apply_canon_updates(batch_updates)
1385
+ updated_count += len(batch_updates)
1386
+ logger.info(f" Updated {updated_count} records...")
1387
+ batch_updates = []
1388
+
1389
+ # Final batch
1390
+ if batch_updates:
1391
+ self._apply_canon_updates(batch_updates)
1392
+ updated_count += len(batch_updates)
1393
+
1394
+ conn.commit()
1395
+ logger.info(f"Canonicalization complete: {updated_count} records updated, {multi_record_groups} multi-record groups")
1396
+
1397
+ return {
1398
+ "total_records": count,
1399
+ "groups_found": len(groups),
1400
+ "multi_record_groups": multi_record_groups,
1401
+ "records_updated": updated_count,
1402
+ }
1403
+
1404
+ def _apply_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
1405
+ """Apply batch of canon updates: (org_id, canon_id, canon_size)."""
1406
+ conn = self._conn
1407
+ assert conn is not None
1408
+
1409
+ for org_id, canon_id, canon_size in updates:
1410
+ conn.execute(
1411
+ "UPDATE organizations SET canon_id = ?, canon_size = ? WHERE id = ?",
1412
+ (canon_id, canon_size, org_id)
1413
+ )
1414
+
1415
+ conn.commit()
1416
+
1417
+ def get_canon_stats(self) -> dict[str, int]:
1418
+ """Get statistics about canonicalization status."""
1419
+ conn = self._connect()
1420
+
1421
+ # Total records
1422
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations")
1423
+ total = cursor.fetchone()[0]
1424
+
1425
+ # Records with canon_id set
1426
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_id IS NOT NULL")
1427
+ canonicalized = cursor.fetchone()[0]
1428
+
1429
+ # Number of canonical groups (unique canon_ids)
1430
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM organizations WHERE canon_id IS NOT NULL")
1431
+ groups = cursor.fetchone()[0]
1432
+
1433
+ # Multi-record groups (canon_size > 1)
1434
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_size > 1")
1435
+ multi_record_groups = cursor.fetchone()[0]
1436
+
1437
+ # Records in multi-record groups
1438
+ cursor = conn.execute("""
1439
+ SELECT COUNT(*) FROM organizations o1
1440
+ WHERE EXISTS (SELECT 1 FROM organizations o2 WHERE o2.id = o1.canon_id AND o2.canon_size > 1)
1441
+ """)
1442
+ records_in_multi = cursor.fetchone()[0]
1443
+
1444
+ return {
1445
+ "total_records": total,
1446
+ "canonicalized_records": canonicalized,
1447
+ "canonical_groups": groups,
1448
+ "multi_record_groups": multi_record_groups,
1449
+ "records_in_multi_groups": records_in_multi,
1450
+ }
1451
+
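A usage sketch, assuming an OrganizationDatabase handle `db` has already been obtained (the module-level accessor is outside this hunk):

    stats = db.canonicalize(batch_size=10_000)
    # returns total_records, groups_found, multi_record_groups, records_updated
    print(f"{stats['multi_record_groups']} groups link records from more than one source")

    canon = db.get_canon_stats()
    print(f"{canon['records_in_multi_groups']} records now share a canonical id with another record")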
745
1452
  def migrate_name_normalized(self, batch_size: int = 50000) -> int:
746
1453
  """
747
1454
  Populate the name_normalized column for all records.
@@ -867,8 +1574,9 @@ class OrganizationDatabase:
867
1574
  assert conn is not None
868
1575
 
869
1576
  for org_id, embedding_blob in batch:
1577
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
870
1578
  conn.execute("""
871
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
1579
+ INSERT INTO organization_embeddings (org_id, embedding)
872
1580
  VALUES (?, ?)
873
1581
  """, (org_id, embedding_blob))
874
1582
 
@@ -1107,8 +1815,9 @@ class OrganizationDatabase:
1107
1815
 
1108
1816
  for org_id, embedding in zip(org_ids, embeddings):
1109
1817
  embedding_blob = embedding.astype(np.float32).tobytes()
1818
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
1110
1819
  conn.execute("""
1111
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
1820
+ INSERT INTO organization_embeddings (org_id, embedding)
1112
1821
  VALUES (?, ?)
1113
1822
  """, (org_id, embedding_blob))
1114
1823
  count += 1
@@ -1116,6 +1825,68 @@ class OrganizationDatabase:
1116
1825
  conn.commit()
1117
1826
  return count
1118
1827
 
1828
+ def resolve_qid_labels(
1829
+ self,
1830
+ label_map: dict[str, str],
1831
+ batch_size: int = 1000,
1832
+ ) -> int:
1833
+ """
1834
+ Update organization records that have QIDs instead of labels in region field.
1835
+
1836
+ Args:
1837
+ label_map: Mapping of QID -> label for resolution
1838
+ batch_size: Commit batch size
1839
+
1840
+ Returns:
1841
+ Number of records updated
1842
+ """
1843
+ conn = self._connect()
1844
+
1845
+ # Find records with QIDs in region field (starts with 'Q' followed by digits)
1846
+ region_updates = 0
1847
+ cursor = conn.execute("""
1848
+ SELECT id, region FROM organizations
1849
+ WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
1850
+ """)
1851
+ rows = cursor.fetchall()
1852
+
1853
+ for row in rows:
1854
+ org_id = row["id"]
1855
+ qid = row["region"]
1856
+ if qid in label_map:
1857
+ conn.execute(
1858
+ "UPDATE organizations SET region = ? WHERE id = ?",
1859
+ (label_map[qid], org_id)
1860
+ )
1861
+ region_updates += 1
1862
+
1863
+ if region_updates % batch_size == 0:
1864
+ conn.commit()
1865
+ logger.info(f"Updated {region_updates} organization region labels...")
1866
+
1867
+ conn.commit()
1868
+ logger.info(f"Resolved QID labels: {region_updates} organization regions")
1869
+ return region_updates
1870
+
1871
+ def get_unresolved_qids(self) -> set[str]:
1872
+ """
1873
+ Get all QIDs that still need resolution in the organizations table.
1874
+
1875
+ Returns:
1876
+ Set of QIDs (starting with 'Q') found in region field
1877
+ """
1878
+ conn = self._connect()
1879
+ qids: set[str] = set()
1880
+
1881
+ cursor = conn.execute("""
1882
+ SELECT DISTINCT region FROM organizations
1883
+ WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
1884
+ """)
1885
+ for row in cursor:
1886
+ qids.add(row["region"])
1887
+
1888
+ return qids
1889
+
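The two QID helpers are meant as a round trip; the label source (for instance a label pass over the Wikidata dump) is up to the caller, so `lookup_label` below is hypothetical:

    pending = db.get_unresolved_qids()                       # e.g. {"Q30", "Q145"}
    label_map = {qid: lookup_label(qid) for qid in pending}  # hypothetical resolver
    updated = db.resolve_qid_labels(label_map)
    logger.info(f"Rewrote {updated} organization region fields")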
1119
1890
 
1120
1891
  def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
1121
1892
  """
@@ -1167,20 +1938,13 @@ class PersonDatabase:
1167
1938
  self._db_path.parent.mkdir(parents=True, exist_ok=True)
1168
1939
 
1169
1940
  def _connect(self) -> sqlite3.Connection:
1170
- """Get or create database connection with sqlite-vec loaded."""
1941
+ """Get or create database connection using shared connection pool."""
1171
1942
  if self._conn is not None:
1172
1943
  return self._conn
1173
1944
 
1174
- self._ensure_dir()
1175
- self._conn = sqlite3.connect(str(self._db_path))
1176
- self._conn.row_factory = sqlite3.Row
1177
-
1178
- # Load sqlite-vec extension
1179
- self._conn.enable_load_extension(True)
1180
- sqlite_vec.load(self._conn)
1181
- self._conn.enable_load_extension(False)
1945
+ self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
1182
1946
 
1183
- # Create tables
1947
+ # Create tables (idempotent)
1184
1948
  self._create_tables()
1185
1949
 
1186
1950
  return self._conn
@@ -1190,7 +1954,12 @@ class PersonDatabase:
1190
1954
  conn = self._conn
1191
1955
  assert conn is not None
1192
1956
 
1957
+ # Check if we need to migrate from old schema (unique on source+source_id only)
1958
+ self._migrate_people_schema_if_needed(conn)
1959
+
1193
1960
  # Main people records table
1961
+ # Unique constraint on source+source_id+role+org allows multiple records
1962
+ # for the same person with different role/org combinations
1194
1963
  conn.execute("""
1195
1964
  CREATE TABLE IF NOT EXISTS people (
1196
1965
  id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -1202,8 +1971,14 @@ class PersonDatabase:
1202
1971
  person_type TEXT NOT NULL DEFAULT 'unknown',
1203
1972
  known_for_role TEXT NOT NULL DEFAULT '',
1204
1973
  known_for_org TEXT NOT NULL DEFAULT '',
1974
+ known_for_org_id INTEGER DEFAULT NULL,
1975
+ from_date TEXT NOT NULL DEFAULT '',
1976
+ to_date TEXT NOT NULL DEFAULT '',
1977
+ birth_date TEXT NOT NULL DEFAULT '',
1978
+ death_date TEXT NOT NULL DEFAULT '',
1205
1979
  record TEXT NOT NULL,
1206
- UNIQUE(source, source_id)
1980
+ UNIQUE(source, source_id, known_for_role, known_for_org),
1981
+ FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
1207
1982
  )
1208
1983
  """)
1209
1984
 
@@ -1211,9 +1986,71 @@ class PersonDatabase:
1211
1986
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
1212
1987
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
1213
1988
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
1214
- conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
1989
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id, known_for_role, known_for_org)")
1215
1990
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
1216
1991
 
1992
+ # Add from_date column if it doesn't exist (migration for existing DBs)
1993
+ try:
1994
+ conn.execute("ALTER TABLE people ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
1995
+ logger.info("Added from_date column to people table")
1996
+ except sqlite3.OperationalError:
1997
+ pass # Column already exists
1998
+
1999
+ # Add to_date column if it doesn't exist (migration for existing DBs)
2000
+ try:
2001
+ conn.execute("ALTER TABLE people ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
2002
+ logger.info("Added to_date column to people table")
2003
+ except sqlite3.OperationalError:
2004
+ pass # Column already exists
2005
+
2006
+ # Add known_for_org_id column if it doesn't exist (migration for existing DBs)
2007
+ # This is a foreign key to the organizations table (nullable)
2008
+ try:
2009
+ conn.execute("ALTER TABLE people ADD COLUMN known_for_org_id INTEGER DEFAULT NULL")
2010
+ logger.info("Added known_for_org_id column to people table")
2011
+ except sqlite3.OperationalError:
2012
+ pass # Column already exists
2013
+
2014
+ # Create index on known_for_org_id for joins (only if column exists)
2015
+ try:
2016
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org_id ON people(known_for_org_id)")
2017
+ except sqlite3.OperationalError:
2018
+ pass # Column doesn't exist yet (will be added on next connection)
2019
+
2020
+ # Add birth_date column if it doesn't exist (migration for existing DBs)
2021
+ try:
2022
+ conn.execute("ALTER TABLE people ADD COLUMN birth_date TEXT NOT NULL DEFAULT ''")
2023
+ logger.info("Added birth_date column to people table")
2024
+ except sqlite3.OperationalError:
2025
+ pass # Column already exists
2026
+
2027
+ # Add death_date column if it doesn't exist (migration for existing DBs)
2028
+ try:
2029
+ conn.execute("ALTER TABLE people ADD COLUMN death_date TEXT NOT NULL DEFAULT ''")
2030
+ logger.info("Added death_date column to people table")
2031
+ except sqlite3.OperationalError:
2032
+ pass # Column already exists
2033
+
2034
+ # Add canon_id column if it doesn't exist (migration for canonicalization)
2035
+ try:
2036
+ conn.execute("ALTER TABLE people ADD COLUMN canon_id INTEGER DEFAULT NULL")
2037
+ logger.info("Added canon_id column to people table")
2038
+ except sqlite3.OperationalError:
2039
+ pass # Column already exists
2040
+
2041
+ # Add canon_size column if it doesn't exist (migration for canonicalization)
2042
+ try:
2043
+ conn.execute("ALTER TABLE people ADD COLUMN canon_size INTEGER DEFAULT 1")
2044
+ logger.info("Added canon_size column to people table")
2045
+ except sqlite3.OperationalError:
2046
+ pass # Column already exists
2047
+
2048
+ # Create index on canon_id for joins
2049
+ try:
2050
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_canon_id ON people(canon_id)")
2051
+ except sqlite3.OperationalError:
2052
+ pass # Column doesn't exist yet
2053
+
1217
2054
  # Create sqlite-vec virtual table for embeddings
1218
2055
  conn.execute(f"""
1219
2056
  CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
@@ -1222,13 +2059,94 @@ class PersonDatabase:
1222
2059
  )
1223
2060
  """)
1224
2061
 
2062
+ # Create QID labels lookup table for Wikidata QID -> label mappings
2063
+ conn.execute("""
2064
+ CREATE TABLE IF NOT EXISTS qid_labels (
2065
+ qid TEXT PRIMARY KEY,
2066
+ label TEXT NOT NULL
2067
+ )
2068
+ """)
2069
+
1225
2070
  conn.commit()
1226
2071
 
2072
+ def _migrate_people_schema_if_needed(self, conn: sqlite3.Connection) -> None:
2073
+ """Migrate people table from old schema if needed."""
2074
+ # Check if people table exists
2075
+ cursor = conn.execute(
2076
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='people'"
2077
+ )
2078
+ if not cursor.fetchone():
2079
+ return # Table doesn't exist, no migration needed
2080
+
2081
+ # Check the unique constraint - look at index info
2082
+ # Old schema: UNIQUE(source, source_id)
2083
+ # New schema: UNIQUE(source, source_id, known_for_role, known_for_org)
2084
+ cursor = conn.execute("PRAGMA index_list(people)")
2085
+ indexes = cursor.fetchall()
2086
+
2087
+ needs_migration = False
2088
+ for idx in indexes:
2089
+ idx_name = idx[1]
2090
+ if "sqlite_autoindex_people" in idx_name:
2091
+ # Check columns in this unique index
2092
+ cursor = conn.execute(f"PRAGMA index_info('{idx_name}')")
2093
+ cols = [row[2] for row in cursor.fetchall()]
2094
+ # Old schema has only 2 columns in unique constraint
2095
+ if cols == ["source", "source_id"]:
2096
+ needs_migration = True
2097
+ logger.info("Detected old people schema, migrating to new unique constraint...")
2098
+ break
2099
+
2100
+ if not needs_migration:
2101
+ return
2102
+
2103
+ # Migrate: create new table, copy data, drop old, rename new
2104
+ logger.info("Migrating people table to new schema with (source, source_id, role, org) unique constraint...")
2105
+
2106
+ conn.execute("""
2107
+ CREATE TABLE people_new (
2108
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
2109
+ name TEXT NOT NULL,
2110
+ name_normalized TEXT NOT NULL,
2111
+ source TEXT NOT NULL DEFAULT 'wikidata',
2112
+ source_id TEXT NOT NULL,
2113
+ country TEXT NOT NULL DEFAULT '',
2114
+ person_type TEXT NOT NULL DEFAULT 'unknown',
2115
+ known_for_role TEXT NOT NULL DEFAULT '',
2116
+ known_for_org TEXT NOT NULL DEFAULT '',
2117
+ known_for_org_id INTEGER DEFAULT NULL,
2118
+ from_date TEXT NOT NULL DEFAULT '',
2119
+ to_date TEXT NOT NULL DEFAULT '',
2120
+ record TEXT NOT NULL,
2121
+ UNIQUE(source, source_id, known_for_role, known_for_org),
2122
+ FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
2123
+ )
2124
+ """)
2125
+
2126
+ # Copy data (old IDs will change, but embeddings table references them)
2127
+ # Note: old table may not have from_date/to_date columns, so use defaults
2128
+ conn.execute("""
2129
+ INSERT INTO people_new (name, name_normalized, source, source_id, country,
2130
+ person_type, known_for_role, known_for_org, record)
2131
+ SELECT name, name_normalized, source, source_id, country,
2132
+ person_type, known_for_role, known_for_org, record
2133
+ FROM people
2134
+ """)
2135
+
2136
+ # Drop old table and embeddings (IDs changed, embeddings are invalid)
2137
+ conn.execute("DROP TABLE IF EXISTS person_embeddings")
2138
+ conn.execute("DROP TABLE people")
2139
+ conn.execute("ALTER TABLE people_new RENAME TO people")
2140
+
2141
+ # Drop old index if it exists
2142
+ conn.execute("DROP INDEX IF EXISTS idx_people_source_id")
2143
+
2144
+ conn.commit()
2145
+ logger.info("Migration complete. Note: person embeddings were cleared and need to be regenerated.")
2146
+
1227
2147
  def close(self) -> None:
1228
- """Close database connection."""
1229
- if self._conn:
1230
- self._conn.close()
1231
- self._conn = None
2148
+ """Clear connection reference (shared connection remains open)."""
2149
+ self._conn = None
1232
2150
 
1233
2151
  def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
1234
2152
  """
@@ -1249,8 +2167,10 @@ class PersonDatabase:
1249
2167
 
1250
2168
  cursor = conn.execute("""
1251
2169
  INSERT OR REPLACE INTO people
1252
- (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
1253
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
2170
+ (name, name_normalized, source, source_id, country, person_type,
2171
+ known_for_role, known_for_org, known_for_org_id, from_date, to_date,
2172
+ birth_date, death_date, record)
2173
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
1254
2174
  """, (
1255
2175
  record.name,
1256
2176
  name_normalized,
@@ -1260,16 +2180,22 @@ class PersonDatabase:
1260
2180
  record.person_type.value,
1261
2181
  record.known_for_role,
1262
2182
  record.known_for_org,
2183
+ record.known_for_org_id, # Can be None
2184
+ record.from_date or "",
2185
+ record.to_date or "",
2186
+ record.birth_date or "",
2187
+ record.death_date or "",
1263
2188
  record_json,
1264
2189
  ))
1265
2190
 
1266
2191
  row_id = cursor.lastrowid
1267
2192
  assert row_id is not None
1268
2193
 
1269
- # Insert embedding into vec table
2194
+ # Insert embedding into vec table (delete first since sqlite-vec doesn't support REPLACE)
1270
2195
  embedding_blob = embedding.astype(np.float32).tobytes()
2196
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
1271
2197
  conn.execute("""
1272
- INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
2198
+ INSERT INTO person_embeddings (person_id, embedding)
1273
2199
  VALUES (?, ?)
1274
2200
  """, (row_id, embedding_blob))
1275
2201
 
@@ -1302,8 +2228,10 @@ class PersonDatabase:
1302
2228
 
1303
2229
  cursor = conn.execute("""
1304
2230
  INSERT OR REPLACE INTO people
1305
- (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
1306
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
2231
+ (name, name_normalized, source, source_id, country, person_type,
2232
+ known_for_role, known_for_org, known_for_org_id, from_date, to_date,
2233
+ birth_date, death_date, record)
2234
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
1307
2235
  """, (
1308
2236
  record.name,
1309
2237
  name_normalized,
@@ -1313,16 +2241,22 @@ class PersonDatabase:
1313
2241
  record.person_type.value,
1314
2242
  record.known_for_role,
1315
2243
  record.known_for_org,
2244
+ record.known_for_org_id, # Can be None
2245
+ record.from_date or "",
2246
+ record.to_date or "",
2247
+ record.birth_date or "",
2248
+ record.death_date or "",
1316
2249
  record_json,
1317
2250
  ))
1318
2251
 
1319
2252
  row_id = cursor.lastrowid
1320
2253
  assert row_id is not None
1321
2254
 
1322
- # Insert embedding
2255
+ # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
1323
2256
  embedding_blob = embedding.astype(np.float32).tobytes()
2257
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
1324
2258
  conn.execute("""
1325
- INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
2259
+ INSERT INTO person_embeddings (person_id, embedding)
1326
2260
  VALUES (?, ?)
1327
2261
  """, (row_id, embedding_blob))
1328
2262
 
@@ -1335,6 +2269,88 @@ class PersonDatabase:
1335
2269
  conn.commit()
1336
2270
  return count
1337
2271
 
2272
+ def update_dates(self, source: str, source_id: str, from_date: Optional[str], to_date: Optional[str]) -> bool:
2273
+ """
2274
+ Update the from_date and to_date for a person record.
2275
+
2276
+ Args:
2277
+ source: Data source (e.g., 'wikidata')
2278
+ source_id: Source identifier (e.g., QID)
2279
+ from_date: Start date in ISO format or None
2280
+ to_date: End date in ISO format or None
2281
+
2282
+ Returns:
2283
+ True if record was updated, False if not found
2284
+ """
2285
+ conn = self._connect()
2286
+
2287
+ cursor = conn.execute("""
2288
+ UPDATE people SET from_date = ?, to_date = ?
2289
+ WHERE source = ? AND source_id = ?
2290
+ """, (from_date or "", to_date or "", source, source_id))
2291
+
2292
+ conn.commit()
2293
+ return cursor.rowcount > 0
2294
+
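For example, an importer that later learns a person's tenure dates can backfill them without re-inserting the record; `people_db` is assumed to be a PersonDatabase handle and the QID is a placeholder:

    updated = people_db.update_dates(
        source="wikidata",
        source_id="Q42",          # placeholder QID
        from_date="2014-02-04",
        to_date=None,             # still in the role
    )
    if not updated:
        logger.warning("No person record found for wikidata Q42")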
2295
+ def update_role_org(
2296
+ self,
2297
+ source: str,
2298
+ source_id: str,
2299
+ known_for_role: str,
2300
+ known_for_org: str,
2301
+ known_for_org_id: Optional[int],
2302
+ new_embedding: np.ndarray,
2303
+ from_date: Optional[str] = None,
2304
+ to_date: Optional[str] = None,
2305
+ ) -> bool:
2306
+ """
2307
+ Update the role/org/dates data for a person record and re-embed.
2308
+
2309
+ Args:
2310
+ source: Data source (e.g., 'wikidata')
2311
+ source_id: Source identifier (e.g., QID)
2312
+ known_for_role: Role/position title
2313
+ known_for_org: Organization name
2314
+ known_for_org_id: Organization internal ID (FK) or None
2315
+ new_embedding: New embedding vector based on updated data
2316
+ from_date: Start date in ISO format or None
2317
+ to_date: End date in ISO format or None
2318
+
2319
+ Returns:
2320
+ True if record was updated, False if not found
2321
+ """
2322
+ conn = self._connect()
2323
+
2324
+ # First get the person's internal ID
2325
+ row = conn.execute(
2326
+ "SELECT id FROM people WHERE source = ? AND source_id = ?",
2327
+ (source, source_id)
2328
+ ).fetchone()
2329
+
2330
+ if not row:
2331
+ return False
2332
+
2333
+ person_id = row[0]
2334
+
2335
+ # Update the person record (including dates)
2336
+ conn.execute("""
2337
+ UPDATE people SET
2338
+ known_for_role = ?, known_for_org = ?, known_for_org_id = ?,
2339
+ from_date = COALESCE(?, from_date, ''),
2340
+ to_date = COALESCE(?, to_date, '')
2341
+ WHERE id = ?
2342
+ """, (known_for_role, known_for_org, known_for_org_id, from_date, to_date, person_id))
2343
+
2344
+ # Update the embedding
2345
+ embedding_bytes = new_embedding.astype(np.float32).tobytes()
2346
+ conn.execute("""
2347
+            UPDATE person_embeddings SET embedding = ?
2348
+            WHERE person_id = ?
2349
+ """, (embedding_bytes, person_id))
2350
+
2351
+ conn.commit()
2352
+ return True
2353
+
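A usage sketch for the two new helpers (database construction, the QID, and the replacement embedding are illustrative assumptions, not values taken from this diff):

import numpy as np

from statement_extractor.database.store import PersonDatabase

db = PersonDatabase()  # assumed default construction; the module also caches singletons
new_emb = np.random.rand(768).astype(np.float32)  # stand-in for a real role/org embedding

# Correct only the tenure dates of an existing record.
db.update_dates("wikidata", "Q0000000", from_date="2018-05-01", to_date=None)

# Replace the role/org (and the stored embedding) when better data becomes available.
db.update_role_org(
    source="wikidata",
    source_id="Q0000000",
    known_for_role="chief executive officer",
    known_for_org="Example Corp",
    known_for_org_id=None,
    new_embedding=new_emb,
    from_date="2018-05-01",
)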
1338
2354
  def search(
1339
2355
  self,
1340
2356
  query_embedding: np.ndarray,
@@ -1516,7 +2532,7 @@ class PersonDatabase:
1516
2532
  assert conn is not None
1517
2533
 
1518
2534
  cursor = conn.execute("""
1519
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
2535
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
1520
2536
  FROM people WHERE id = ?
1521
2537
  """, (person_id,))
1522
2538
 
@@ -1530,6 +2546,9 @@ class PersonDatabase:
1530
2546
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1531
2547
  known_for_role=row["known_for_role"] or "",
1532
2548
  known_for_org=row["known_for_org"] or "",
2549
+ known_for_org_id=row["known_for_org_id"], # Can be None
2550
+ birth_date=row["birth_date"] or "",
2551
+ death_date=row["death_date"] or "",
1533
2552
  record=json.loads(row["record"]),
1534
2553
  )
1535
2554
  return None
@@ -1539,7 +2558,7 @@ class PersonDatabase:
1539
2558
  conn = self._connect()
1540
2559
 
1541
2560
  cursor = conn.execute("""
1542
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
2561
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
1543
2562
  FROM people
1544
2563
  WHERE source = ? AND source_id = ?
1545
2564
  """, (source, source_id))
@@ -1554,6 +2573,9 @@ class PersonDatabase:
1554
2573
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1555
2574
  known_for_role=row["known_for_role"] or "",
1556
2575
  known_for_org=row["known_for_org"] or "",
2576
+ known_for_org_id=row["known_for_org_id"], # Can be None
2577
+ birth_date=row["birth_date"] or "",
2578
+ death_date=row["death_date"] or "",
1557
2579
  record=json.loads(row["record"]),
1558
2580
  )
1559
2581
  return None
@@ -1580,19 +2602,43 @@ class PersonDatabase:
1580
2602
  "by_source": by_source,
1581
2603
  }
1582
2604
 
2605
+ def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
2606
+ """
2607
+ Get all source_ids from the people table.
2608
+
2609
+ Useful for resume operations to skip already-imported records.
2610
+
2611
+ Args:
2612
+ source: Optional source filter (e.g., "wikidata")
2613
+
2614
+ Returns:
2615
+ Set of source_id strings (e.g., Q codes for Wikidata)
2616
+ """
2617
+ conn = self._connect()
2618
+
2619
+ if source:
2620
+ cursor = conn.execute(
2621
+ "SELECT DISTINCT source_id FROM people WHERE source = ?",
2622
+ (source,)
2623
+ )
2624
+ else:
2625
+ cursor = conn.execute("SELECT DISTINCT source_id FROM people")
2626
+
2627
+ return {row[0] for row in cursor}
2628
+
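A resume-style sketch built on get_all_source_ids (the list of candidate QIDs stands in for whatever the dump importer yields):

from statement_extractor.database.store import PersonDatabase

db = PersonDatabase()  # assumed default construction
seen = db.get_all_source_ids(source="wikidata")

for qid in ["Q42", "Q64"]:  # placeholder ids; a real run would iterate the dump
    if qid in seen:
        continue            # already imported on a previous run
    ...                     # parse and insert the record here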
1583
2629
  def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
1584
2630
  """Iterate over all person records, optionally filtered by source."""
1585
2631
  conn = self._connect()
1586
2632
 
1587
2633
  if source:
1588
2634
  cursor = conn.execute("""
1589
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
2635
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
1590
2636
  FROM people
1591
2637
  WHERE source = ?
1592
2638
  """, (source,))
1593
2639
  else:
1594
2640
  cursor = conn.execute("""
1595
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
2641
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
1596
2642
  FROM people
1597
2643
  """)
1598
2644
 
@@ -1605,5 +2651,384 @@ class PersonDatabase:
1605
2651
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1606
2652
  known_for_role=row["known_for_role"] or "",
1607
2653
  known_for_org=row["known_for_org"] or "",
2654
+ known_for_org_id=row["known_for_org_id"], # Can be None
2655
+ birth_date=row["birth_date"] or "",
2656
+ death_date=row["death_date"] or "",
1608
2657
  record=json.loads(row["record"]),
1609
2658
  )
2659
+
2660
+ def resolve_qid_labels(
2661
+ self,
2662
+ label_map: dict[str, str],
2663
+ batch_size: int = 1000,
2664
+ ) -> tuple[int, int]:
2665
+ """
2666
+ Update records that have QIDs instead of labels.
2667
+
2668
+ This is called after dump import to resolve any QIDs that were
2669
+ stored because labels weren't available in the cache at import time.
2670
+
2671
+ If resolving would create a duplicate of an existing record with
2672
+ resolved labels, the QID version is deleted instead.
2673
+
2674
+ Args:
2675
+ label_map: Mapping of QID -> label for resolution
2676
+ batch_size: Commit batch size
2677
+
2678
+ Returns:
2679
+ Tuple of (updates, deletes)
2680
+ """
2681
+ conn = self._connect()
2682
+
2683
+ # Find all records with QIDs in any field (role or org - these are in unique constraint)
2684
+ # Country is not part of unique constraint so can be updated directly
2685
+ cursor = conn.execute("""
2686
+ SELECT id, source, source_id, country, known_for_role, known_for_org
2687
+ FROM people
2688
+ WHERE (country LIKE 'Q%' AND country GLOB 'Q[0-9]*')
2689
+ OR (known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*')
2690
+ OR (known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*')
2691
+ """)
2692
+ rows = cursor.fetchall()
2693
+
2694
+ updates = 0
2695
+ deletes = 0
2696
+
2697
+ for row in rows:
2698
+ person_id = row["id"]
2699
+ source = row["source"]
2700
+ source_id = row["source_id"]
2701
+ country = row["country"]
2702
+ role = row["known_for_role"]
2703
+ org = row["known_for_org"]
2704
+
2705
+ # Resolve QIDs to labels
2706
+ new_country = label_map.get(country, country) if country.startswith("Q") and country[1:].isdigit() else country
2707
+ new_role = label_map.get(role, role) if role.startswith("Q") and role[1:].isdigit() else role
2708
+ new_org = label_map.get(org, org) if org.startswith("Q") and org[1:].isdigit() else org
2709
+
2710
+ # Skip if nothing changed
2711
+ if new_country == country and new_role == role and new_org == org:
2712
+ continue
2713
+
2714
+ # Check if resolved values would duplicate an existing record
2715
+ # (unique constraint is on source, source_id, known_for_role, known_for_org)
2716
+ if new_role != role or new_org != org:
2717
+ cursor2 = conn.execute("""
2718
+ SELECT id FROM people
2719
+ WHERE source = ? AND source_id = ? AND known_for_role = ? AND known_for_org = ?
2720
+ AND id != ?
2721
+ """, (source, source_id, new_role, new_org, person_id))
2722
+ existing = cursor2.fetchone()
2723
+
2724
+ if existing:
2725
+ # Duplicate would exist - delete the QID version
2726
+ conn.execute("DELETE FROM people WHERE id = ?", (person_id,))
2727
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
2728
+ deletes += 1
2729
+ logger.debug(f"Deleted duplicate QID record {person_id} (source_id={source_id})")
2730
+ continue
2731
+
2732
+ # No duplicate - update in place
2733
+ conn.execute("""
2734
+ UPDATE people SET country = ?, known_for_role = ?, known_for_org = ?
2735
+ WHERE id = ?
2736
+ """, (new_country, new_role, new_org, person_id))
2737
+ updates += 1
2738
+
2739
+ if (updates + deletes) % batch_size == 0:
2740
+ conn.commit()
2741
+ logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes...")
2742
+
2743
+ conn.commit()
2744
+ logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes")
2745
+ return updates, deletes
2746
+
2747
+ def get_unresolved_qids(self) -> set[str]:
2748
+ """
2749
+ Get all QIDs that still need resolution in the database.
2750
+
2751
+ Returns:
2752
+ Set of QIDs (starting with 'Q') found in country, role, or org fields
2753
+ """
2754
+ conn = self._connect()
2755
+ qids: set[str] = set()
2756
+
2757
+ # Get QIDs from country field
2758
+ cursor = conn.execute("""
2759
+ SELECT DISTINCT country FROM people
2760
+ WHERE country LIKE 'Q%' AND country GLOB 'Q[0-9]*'
2761
+ """)
2762
+ for row in cursor:
2763
+ qids.add(row["country"])
2764
+
2765
+ # Get QIDs from known_for_role field
2766
+ cursor = conn.execute("""
2767
+ SELECT DISTINCT known_for_role FROM people
2768
+ WHERE known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*'
2769
+ """)
2770
+ for row in cursor:
2771
+ qids.add(row["known_for_role"])
2772
+
2773
+ # Get QIDs from known_for_org field
2774
+ cursor = conn.execute("""
2775
+ SELECT DISTINCT known_for_org FROM people
2776
+ WHERE known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*'
2777
+ """)
2778
+ for row in cursor:
2779
+ qids.add(row["known_for_org"])
2780
+
2781
+ return qids
2782
+
2783
+ def insert_qid_labels(
2784
+ self,
2785
+ label_map: dict[str, str],
2786
+ batch_size: int = 1000,
2787
+ ) -> int:
2788
+ """
2789
+ Insert QID -> label mappings into the lookup table.
2790
+
2791
+ Args:
2792
+ label_map: Mapping of QID -> label
2793
+ batch_size: Commit batch size
2794
+
2795
+ Returns:
2796
+ Number of labels inserted/updated
2797
+ """
2798
+ conn = self._connect()
2799
+ count = 0
2800
+
2801
+ for qid, label in label_map.items():
2802
+ conn.execute(
2803
+ "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
2804
+ (qid, label)
2805
+ )
2806
+ count += 1
2807
+
2808
+ if count % batch_size == 0:
2809
+ conn.commit()
2810
+ logger.debug(f"Inserted {count} QID labels...")
2811
+
2812
+ conn.commit()
2813
+ logger.info(f"Inserted {count} QID labels into lookup table")
2814
+ return count
2815
+
2816
+ def get_qid_label(self, qid: str) -> Optional[str]:
2817
+ """
2818
+ Get the label for a QID from the lookup table.
2819
+
2820
+ Args:
2821
+ qid: Wikidata QID (e.g., 'Q30')
2822
+
2823
+ Returns:
2824
+ Label string or None if not found
2825
+ """
2826
+ conn = self._connect()
2827
+ cursor = conn.execute(
2828
+ "SELECT label FROM qid_labels WHERE qid = ?",
2829
+ (qid,)
2830
+ )
2831
+ row = cursor.fetchone()
2832
+ return row["label"] if row else None
2833
+
2834
+ def get_all_qid_labels(self) -> dict[str, str]:
2835
+ """
2836
+ Get all QID -> label mappings from the lookup table.
2837
+
2838
+ Returns:
2839
+ Dict mapping QID -> label
2840
+ """
2841
+ conn = self._connect()
2842
+ cursor = conn.execute("SELECT qid, label FROM qid_labels")
2843
+ return {row["qid"]: row["label"] for row in cursor}
2844
+
2845
+ def get_qid_labels_count(self) -> int:
2846
+ """Get the number of QID labels in the lookup table."""
2847
+ conn = self._connect()
2848
+ cursor = conn.execute("SELECT COUNT(*) FROM qid_labels")
2849
+ return cursor.fetchone()[0]
2850
+
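Taken together, the QID helpers support a post-import cleanup pass along these lines (the label_map literal stands in for labels gathered from the dump or a cache):

from statement_extractor.database.store import PersonDatabase

db = PersonDatabase()                      # assumed default construction
pending = db.get_unresolved_qids()         # QIDs still stored where labels should be
label_map = {"Q30": "United States"}       # normally built for every QID in `pending`

db.insert_qid_labels(label_map)            # persist for later lookups
updates, deletes = db.resolve_qid_labels(label_map)
print(f"resolved {updates} records, removed {deletes} duplicates")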
2851
+ def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
2852
+ """
2853
+ Canonicalize person records by linking equivalent entries across sources.
2854
+
2855
+ Uses a multi-phase approach:
2856
+ 1. Match by normalized name + same organization (org canonical group)
2857
+ 2. Match by normalized name + overlapping date ranges
2858
+
2859
+ Source priority (lower = more authoritative):
2860
+ - wikidata: 1 (curated, has Q codes)
2861
+ - sec_edgar: 2 (US insider filings)
2862
+ - companies_house: 3 (UK officers)
2863
+
2864
+ Args:
2865
+ batch_size: Number of records to process before committing
2866
+
2867
+ Returns:
2868
+ Stats dict with counts for each matching type
2869
+ """
2870
+ conn = self._connect()
2871
+ stats = {
2872
+ "total_records": 0,
2873
+ "matched_by_org": 0,
2874
+ "matched_by_date": 0,
2875
+ "canonical_groups": 0,
2876
+ "records_in_groups": 0,
2877
+ }
2878
+
2879
+ logger.info("Phase 1: Building person index...")
2880
+
2881
+ # Load all people with their normalized names and org info
2882
+ cursor = conn.execute("""
2883
+ SELECT id, name, name_normalized, source, source_id,
2884
+ known_for_org, known_for_org_id, from_date, to_date
2885
+ FROM people
2886
+ """)
2887
+
2888
+ people: list[dict] = []
2889
+ for row in cursor:
2890
+ people.append({
2891
+ "id": row["id"],
2892
+ "name": row["name"],
2893
+ "name_normalized": row["name_normalized"],
2894
+ "source": row["source"],
2895
+ "source_id": row["source_id"],
2896
+ "known_for_org": row["known_for_org"],
2897
+ "known_for_org_id": row["known_for_org_id"],
2898
+ "from_date": row["from_date"],
2899
+ "to_date": row["to_date"],
2900
+ })
2901
+
2902
+ stats["total_records"] = len(people)
2903
+ logger.info(f"Loaded {len(people)} person records")
2904
+
2905
+ if len(people) == 0:
2906
+ return stats
2907
+
2908
+ # Initialize Union-Find
2909
+ person_ids = [p["id"] for p in people]
2910
+ uf = UnionFind(person_ids)
2911
+
2912
+ # Build indexes for efficient matching
2913
+ # Index by normalized name
2914
+ name_to_people: dict[str, list[dict]] = {}
2915
+ for p in people:
2916
+ name_norm = p["name_normalized"]
2917
+ name_to_people.setdefault(name_norm, []).append(p)
2918
+
2919
+ logger.info("Phase 2: Matching by normalized name + organization...")
2920
+
2921
+ # Match people with same normalized name and same organization
2922
+ for name_norm, same_name in name_to_people.items():
2923
+ if len(same_name) < 2:
2924
+ continue
2925
+
2926
+ # Group by organization (using known_for_org_id if available, else known_for_org)
2927
+ org_groups: dict[str, list[dict]] = {}
2928
+ for p in same_name:
2929
+ org_key = str(p["known_for_org_id"]) if p["known_for_org_id"] else p["known_for_org"]
2930
+ if org_key: # Only group if they have an org
2931
+ org_groups.setdefault(org_key, []).append(p)
2932
+
2933
+ # Union people with same name + same org
2934
+ for org_key, org_people in org_groups.items():
2935
+ if len(org_people) >= 2:
2936
+ first_id = org_people[0]["id"]
2937
+ for p in org_people[1:]:
2938
+ uf.union(first_id, p["id"])
2939
+ stats["matched_by_org"] += 1
2940
+
2941
+ logger.info(f"Phase 2 complete: {stats['matched_by_org']} matches by org")
2942
+
2943
+ logger.info("Phase 3: Matching by normalized name + overlapping dates...")
2944
+
2945
+ # Match people with same normalized name and overlapping date ranges
2946
+ for name_norm, same_name in name_to_people.items():
2947
+ if len(same_name) < 2:
2948
+ continue
2949
+
2950
+ # Skip if already all unified
2951
+ roots = set(uf.find(p["id"]) for p in same_name)
2952
+ if len(roots) == 1:
2953
+ continue
2954
+
2955
+ # Check for overlapping date ranges
2956
+ for i, p1 in enumerate(same_name):
2957
+ for p2 in same_name[i+1:]:
2958
+ # Skip if already in same group
2959
+ if uf.find(p1["id"]) == uf.find(p2["id"]):
2960
+ continue
2961
+
2962
+ # Check date overlap (if both have dates)
2963
+ if p1["from_date"] and p2["from_date"]:
2964
+                    # Date intervals overlap when each one starts on or before the other's end
2965
+ p1_from = p1["from_date"]
2966
+ p1_to = p1["to_date"] or "9999-12-31"
2967
+ p2_from = p2["from_date"]
2968
+ p2_to = p2["to_date"] or "9999-12-31"
2969
+
2970
+ # Overlap if: p1_from <= p2_to AND p2_from <= p1_to
2971
+ if p1_from <= p2_to and p2_from <= p1_to:
2972
+ uf.union(p1["id"], p2["id"])
2973
+ stats["matched_by_date"] += 1
2974
+
2975
+ logger.info(f"Phase 3 complete: {stats['matched_by_date']} matches by date")
2976
+
2977
+ logger.info("Phase 4: Applying canonical updates...")
2978
+
2979
+ # Get all groups and select canonical record for each
2980
+ groups = uf.groups()
2981
+
2982
+ # Build id -> source mapping
2983
+ id_to_source = {p["id"]: p["source"] for p in people}
2984
+
2985
+ batch_updates: list[tuple[int, int, int]] = [] # (person_id, canon_id, canon_size)
2986
+
2987
+ for _root, group_ids in groups.items():
2988
+ group_size = len(group_ids)
2989
+
2990
+ if group_size == 1:
2991
+ # Single record is its own canonical
2992
+ person_id = group_ids[0]
2993
+ batch_updates.append((person_id, person_id, 1))
2994
+ else:
2995
+ # Multiple records - pick highest priority source as canonical
2996
+ # Sort by source priority, then by id (for stability)
2997
+ sorted_ids = sorted(
2998
+ group_ids,
2999
+ key=lambda pid: (PERSON_SOURCE_PRIORITY.get(id_to_source[pid], 99), pid)
3000
+ )
3001
+ canon_id = sorted_ids[0]
3002
+ stats["canonical_groups"] += 1
3003
+ stats["records_in_groups"] += group_size
3004
+
3005
+ for person_id in group_ids:
3006
+ batch_updates.append((person_id, canon_id, group_size if person_id == canon_id else 1))
3007
+
3008
+ # Commit in batches
3009
+ if len(batch_updates) >= batch_size:
3010
+ self._apply_person_canon_updates(batch_updates)
3011
+ conn.commit()
3012
+ logger.info(f"Applied {len(batch_updates)} canon updates...")
3013
+ batch_updates = []
3014
+
3015
+ # Final batch
3016
+ if batch_updates:
3017
+ self._apply_person_canon_updates(batch_updates)
3018
+ conn.commit()
3019
+
3020
+ logger.info(f"Canonicalization complete: {stats['canonical_groups']} groups, "
3021
+ f"{stats['records_in_groups']} records in multi-record groups")
3022
+
3023
+ return stats
3024
+
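canonicalize() leans on a UnionFind helper and a PERSON_SOURCE_PRIORITY map defined elsewhere in store.py. A minimal sketch of the union-find interface it exercises (construction from ids, union, find, groups); the real implementation may differ, e.g. by using union-by-rank:

class UnionFind:
    """Disjoint-set over a fixed collection of integer ids (assumed interface only)."""

    def __init__(self, ids: list[int]) -> None:
        self.parent = {i: i for i in ids}

    def find(self, x: int) -> int:
        # Path halving: walk to the root, shortening links along the way.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra

    def groups(self) -> dict[int, list[int]]:
        # Map each root id to every member of its set, mirroring uf.groups() above.
        out: dict[int, list[int]] = {}
        for i in self.parent:
            out.setdefault(self.find(i), []).append(i)
        return out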
3025
+ def _apply_person_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
3026
+ """Apply batch of canon updates: (person_id, canon_id, canon_size)."""
3027
+ conn = self._conn
3028
+ assert conn is not None
3029
+
3030
+ for person_id, canon_id, canon_size in updates:
3031
+ conn.execute(
3032
+ "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
3033
+ (canon_id, canon_size, person_id)
3034
+ )