corp-extractor 0.9.0-py3-none-any.whl → 0.9.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/METADATA +72 -11
  2. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/RECORD +34 -27
  3. statement_extractor/cli.py +1317 -101
  4. statement_extractor/database/embeddings.py +45 -0
  5. statement_extractor/database/hub.py +86 -136
  6. statement_extractor/database/importers/__init__.py +10 -2
  7. statement_extractor/database/importers/companies_house.py +16 -2
  8. statement_extractor/database/importers/companies_house_officers.py +431 -0
  9. statement_extractor/database/importers/gleif.py +23 -0
  10. statement_extractor/database/importers/import_utils.py +264 -0
  11. statement_extractor/database/importers/sec_edgar.py +17 -0
  12. statement_extractor/database/importers/sec_form4.py +512 -0
  13. statement_extractor/database/importers/wikidata.py +151 -43
  14. statement_extractor/database/importers/wikidata_dump.py +2282 -0
  15. statement_extractor/database/importers/wikidata_people.py +867 -325
  16. statement_extractor/database/migrate_v2.py +852 -0
  17. statement_extractor/database/models.py +155 -7
  18. statement_extractor/database/schema_v2.py +409 -0
  19. statement_extractor/database/seed_data.py +359 -0
  20. statement_extractor/database/store.py +3449 -233
  21. statement_extractor/document/deduplicator.py +10 -12
  22. statement_extractor/extractor.py +1 -1
  23. statement_extractor/models/__init__.py +3 -2
  24. statement_extractor/models/statement.py +15 -17
  25. statement_extractor/models.py +1 -1
  26. statement_extractor/pipeline/context.py +5 -5
  27. statement_extractor/pipeline/orchestrator.py +12 -12
  28. statement_extractor/plugins/base.py +17 -17
  29. statement_extractor/plugins/extractors/gliner2.py +28 -28
  30. statement_extractor/plugins/qualifiers/embedding_company.py +7 -5
  31. statement_extractor/plugins/qualifiers/person.py +120 -53
  32. statement_extractor/plugins/splitters/t5_gemma.py +35 -39
  33. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/WHEEL +0 -0
  34. {corp_extractor-0.9.0.dist-info → corp_extractor-0.9.4.dist-info}/entry_points.txt +0 -0
@@ -12,17 +12,44 @@ import re
12
12
  import sqlite3
13
13
  import time
14
14
  from pathlib import Path
15
- from typing import Iterator, Optional
15
+ from typing import Any, Iterator, Optional
16
16
 
17
17
  import numpy as np
18
+ import pycountry
18
19
  import sqlite_vec
19
20
 
20
- from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType
21
+ from .models import (
22
+ CompanyRecord,
23
+ DatabaseStats,
24
+ EntityType,
25
+ LocationRecord,
26
+ PersonRecord,
27
+ PersonType,
28
+ RoleRecord,
29
+ SimplifiedLocationType,
30
+ )
31
+ from .seed_data import (
32
+ LOCATION_TYPE_NAME_TO_ID,
33
+ LOCATION_TYPE_QID_TO_ID,
34
+ LOCATION_TYPE_TO_SIMPLIFIED,
35
+ ORG_TYPE_ID_TO_NAME,
36
+ ORG_TYPE_NAME_TO_ID,
37
+ PEOPLE_TYPE_ID_TO_NAME,
38
+ PEOPLE_TYPE_NAME_TO_ID,
39
+ SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME,
40
+ SOURCE_ID_TO_NAME,
41
+ SOURCE_NAME_TO_ID,
42
+ seed_all_enums,
43
+ seed_pycountry_locations,
44
+ )
21
45
 
22
46
  logger = logging.getLogger(__name__)
23
47
 
24
48
  # Default database location
25
- DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
49
+ DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities-v2.db"
50
+
51
+ # Module-level shared connections by path (both databases share the same connection)
52
+ _shared_connections: dict[str, sqlite3.Connection] = {}
26
53
 
27
54
  # Module-level singleton for OrganizationDatabase to prevent multiple loads
28
55
  _database_instances: dict[str, "OrganizationDatabase"] = {}
@@ -30,6 +57,36 @@ _database_instances: dict[str, "OrganizationDatabase"] = {}
30
57
  # Module-level singleton for PersonDatabase
31
58
  _person_database_instances: dict[str, "PersonDatabase"] = {}
32
59
 
60
+
61
+ def _get_shared_connection(db_path: Path, embedding_dim: int = 768) -> sqlite3.Connection:
62
+ """Get or create a shared database connection for the given path."""
63
+ path_key = str(db_path)
64
+ if path_key not in _shared_connections:
65
+ # Ensure directory exists
66
+ db_path.parent.mkdir(parents=True, exist_ok=True)
67
+
68
+ conn = sqlite3.connect(str(db_path))
69
+ conn.row_factory = sqlite3.Row
70
+
71
+ # Load sqlite-vec extension
72
+ conn.enable_load_extension(True)
73
+ sqlite_vec.load(conn)
74
+ conn.enable_load_extension(False)
75
+
76
+ _shared_connections[path_key] = conn
77
+ logger.debug(f"Created shared database connection for {path_key}")
78
+
79
+ return _shared_connections[path_key]
80
+
81
+
82
+ def close_shared_connection(db_path: Optional[Path] = None) -> None:
83
+ """Close a shared database connection."""
84
+ path_key = str(db_path or DEFAULT_DB_PATH)
85
+ if path_key in _shared_connections:
86
+ _shared_connections[path_key].close()
87
+ del _shared_connections[path_key]
88
+ logger.debug(f"Closed shared database connection for {path_key}")
89
+
33
90
  # Comprehensive set of corporate legal suffixes (international)
34
91
  COMPANY_SUFFIXES: set[str] = {
35
92
  'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
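A minimal usage sketch of the shared-connection helpers introduced above: repeated calls with the same path return the same sqlite3.Connection, and close_shared_connection() evicts it. The import path (the store module this diff modifies) is an assumption.

from pathlib import Path

from statement_extractor.database.store import (  # import path assumed
    _get_shared_connection,
    close_shared_connection,
)

db_path = Path.home() / ".cache" / "corp-extractor" / "entities-v2.db"

conn_a = _get_shared_connection(db_path)
conn_b = _get_shared_connection(db_path)
assert conn_a is conn_b  # same path -> the same pooled sqlite3.Connection is reused

close_shared_connection(db_path)  # closes and drops the pooled connection for this path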
@@ -43,6 +100,222 @@ COMPANY_SUFFIXES: set[str] = {
43
100
  'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
44
101
  }
45
102
 
103
+ # Source priority for organization canonicalization (lower = higher priority)
104
+ SOURCE_PRIORITY: dict[str, int] = {
105
+ "gleif": 1, # Gold standard LEI - globally unique legal entity identifier
106
+ "sec_edgar": 2, # Vetted US filers with CIK + ticker
107
+ "companies_house": 3, # Official UK registry
108
+ "wikipedia": 4, # Crowdsourced, less authoritative
109
+ }
110
+
111
+ # Source priority for people canonicalization (lower = higher priority)
112
+ PERSON_SOURCE_PRIORITY: dict[str, int] = {
113
+ "wikidata": 1, # Curated, has rich biographical data and Q codes
114
+ "sec_edgar": 2, # Vetted US filers (Form 4 officers/directors)
115
+ "companies_house": 3, # UK company officers
116
+ }
117
+
118
+ # Suffix expansions for canonical name matching
119
+ SUFFIX_EXPANSIONS: dict[str, str] = {
120
+ " ltd": " limited",
121
+ " corp": " corporation",
122
+ " inc": " incorporated",
123
+ " co": " company",
124
+ " intl": " international",
125
+ " natl": " national",
126
+ }
127
+
128
+
129
+ class UnionFind:
130
+ """Simple Union-Find (Disjoint Set Union) data structure for canonicalization."""
131
+
132
+ def __init__(self, elements: list[int]):
133
+ """Initialize with list of element IDs."""
134
+ self.parent: dict[int, int] = {e: e for e in elements}
135
+ self.rank: dict[int, int] = {e: 0 for e in elements}
136
+
137
+ def find(self, x: int) -> int:
138
+ """Find with path compression."""
139
+ if self.parent[x] != x:
140
+ self.parent[x] = self.find(self.parent[x])
141
+ return self.parent[x]
142
+
143
+ def union(self, x: int, y: int) -> None:
144
+ """Union by rank."""
145
+ px, py = self.find(x), self.find(y)
146
+ if px == py:
147
+ return
148
+ if self.rank[px] < self.rank[py]:
149
+ px, py = py, px
150
+ self.parent[py] = px
151
+ if self.rank[px] == self.rank[py]:
152
+ self.rank[px] += 1
153
+
154
+ def groups(self) -> dict[int, list[int]]:
155
+ """Return dict of root -> list of members."""
156
+ result: dict[int, list[int]] = {}
157
+ for e in self.parent:
158
+ root = self.find(e)
159
+ result.setdefault(root, []).append(e)
160
+ return result
161
+
162
+
163
+ # Common region aliases not handled well by pycountry fuzzy search
164
+ REGION_ALIASES: dict[str, str] = {
165
+ "uk": "GB",
166
+ "u.k.": "GB",
167
+ "england": "GB",
168
+ "scotland": "GB",
169
+ "wales": "GB",
170
+ "northern ireland": "GB",
171
+ "usa": "US",
172
+ "u.s.a.": "US",
173
+ "u.s.": "US",
174
+ "united states of america": "US",
175
+ "america": "US",
176
+ }
177
+
178
+ # Cache for region normalization lookups
179
+ _region_cache: dict[str, str] = {}
180
+
181
+
182
+ def _normalize_region(region: str) -> str:
183
+ """
184
+ Normalize a region string to ISO 3166-1 alpha-2 country code.
185
+
186
+ Handles:
187
+ - Country codes (2-letter, 3-letter)
188
+ - Country names (with fuzzy matching)
189
+ - US state codes (CA, NY) -> US
190
+ - US state names (California, New York) -> US
191
+ - Common aliases (UK, USA, England) -> proper codes
192
+
193
+ Returns empty string if region cannot be normalized.
194
+ """
195
+ if not region:
196
+ return ""
197
+
198
+ # Check cache first
199
+ cache_key = region.lower().strip()
200
+ if cache_key in _region_cache:
201
+ return _region_cache[cache_key]
202
+
203
+ result = _normalize_region_uncached(region)
204
+ _region_cache[cache_key] = result
205
+ return result
206
+
207
+
208
+ def _normalize_region_uncached(region: str) -> str:
209
+ """Uncached region normalization logic."""
210
+ region_clean = region.strip()
211
+
212
+ # Empty after stripping = empty result
213
+ if not region_clean:
214
+ return ""
215
+
216
+ region_lower = region_clean.lower()
217
+ region_upper = region_clean.upper()
218
+
219
+ # Check common aliases first
220
+ if region_lower in REGION_ALIASES:
221
+ return REGION_ALIASES[region_lower]
222
+
223
+ # For 2-letter codes, check country first, then US state
224
+ # This means ambiguous codes like "CA" (Canada vs California) prefer country
225
+ # But unambiguous codes like "NY" (not a country) will match as US state
226
+ if len(region_clean) == 2:
227
+ # Try as country alpha-2 first
228
+ country = pycountry.countries.get(alpha_2=region_upper)
229
+ if country:
230
+ return country.alpha_2
231
+
232
+ # If not a country, try as US state code
233
+ subdivision = pycountry.subdivisions.get(code=f"US-{region_upper}")
234
+ if subdivision:
235
+ return "US"
236
+
237
+ # Try alpha-3 lookup
238
+ if len(region_clean) == 3:
239
+ country = pycountry.countries.get(alpha_3=region_upper)
240
+ if country:
241
+ return country.alpha_2
242
+
243
+ # Try as US state name (e.g., "California", "New York")
244
+ try:
245
+ subdivisions = list(pycountry.subdivisions.search_fuzzy(region_clean))
246
+ if subdivisions:
247
+ # Check if it's a US state
248
+ if subdivisions[0].code.startswith("US-"):
249
+ return "US"
250
+ # Return the parent country code
251
+ return subdivisions[0].country_code
252
+ except LookupError:
253
+ pass
254
+
255
+ # Try country fuzzy search
256
+ try:
257
+ countries = pycountry.countries.search_fuzzy(region_clean)
258
+ if countries:
259
+ return countries[0].alpha_2
260
+ except LookupError:
261
+ pass
262
+
263
+ # Return empty if we can't normalize
264
+ return ""
265
+
266
+
267
+ def _regions_match(region1: str, region2: str) -> bool:
268
+ """
269
+ Check if two regions match after normalization.
270
+
271
+ Empty regions match anything (lenient matching for incomplete data).
272
+ """
273
+ norm1 = _normalize_region(region1)
274
+ norm2 = _normalize_region(region2)
275
+
276
+ # Empty regions match anything
277
+ if not norm1 or not norm2:
278
+ return True
279
+
280
+ return norm1 == norm2
281
+
282
+
283
+ def _normalize_for_canon(name: str) -> str:
284
+ """Normalize name for canonical matching (simpler than search normalization)."""
285
+ # Lowercase
286
+ result = name.lower()
287
+ # Remove trailing dots
288
+ result = result.rstrip(".")
289
+ # Remove extra whitespace
290
+ result = " ".join(result.split())
291
+ return result
292
+
293
+
294
+ def _expand_suffix(name: str) -> str:
295
+ """Expand known suffix abbreviations."""
296
+ result = name.lower().rstrip(".")
297
+ for abbrev, full in SUFFIX_EXPANSIONS.items():
298
+ if result.endswith(abbrev):
299
+ result = result[:-len(abbrev)] + full
300
+ break # Only expand one suffix
301
+ return result
302
+
303
+
304
+ def _names_match_for_canon(name1: str, name2: str) -> bool:
305
+ """Check if two names match for canonicalization."""
306
+ n1 = _normalize_for_canon(name1)
307
+ n2 = _normalize_for_canon(name2)
308
+
309
+ # Exact match after normalization
310
+ if n1 == n2:
311
+ return True
312
+
313
+ # Try with suffix expansion
314
+ if _expand_suffix(n1) == _expand_suffix(n2):
315
+ return True
316
+
317
+ return False
318
+
46
319
  # Pre-compile the suffix pattern for performance
47
320
  _SUFFIX_PATTERN = re.compile(
48
321
  r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
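Together these helpers drive the matching rules used by canonicalize() later in this file: suffix-expanded name comparison, region normalization to ISO alpha-2 codes, and UnionFind grouping of equivalent record IDs. A small sketch of that interplay, assuming the helpers can be imported from the store module:

from statement_extractor.database.store import (  # import path assumed
    UnionFind,
    _names_match_for_canon,
    _normalize_region,
)

# Region aliases and alpha-3 codes collapse to alpha-2 country codes.
assert _normalize_region("UK") == "GB"
assert _normalize_region("GBR") == "GB"
assert _normalize_region("USA") == "US"

# Suffix expansion makes the "Ltd" and "Limited" spellings equivalent.
assert _names_match_for_canon("Amazon Ltd.", "Amazon Limited")

# UnionFind groups equivalent record IDs: 1 and 2 merge, 3 stays alone.
uf = UnionFind([1, 2, 3])
uf.union(1, 2)
assert uf.groups() == {1: [1, 2], 3: [3]}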
@@ -250,30 +523,40 @@ class OrganizationDatabase:
250
523
  self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
251
524
  self._embedding_dim = embedding_dim
252
525
  self._conn: Optional[sqlite3.Connection] = None
526
+ self._is_v2: Optional[bool] = None # Detected on first connect
253
527
 
254
528
  def _ensure_dir(self) -> None:
255
529
  """Ensure database directory exists."""
256
530
  self._db_path.parent.mkdir(parents=True, exist_ok=True)
257
531
 
258
532
  def _connect(self) -> sqlite3.Connection:
259
- """Get or create database connection with sqlite-vec loaded."""
533
+ """Get or create database connection using shared connection pool."""
260
534
  if self._conn is not None:
261
535
  return self._conn
262
536
 
263
- self._ensure_dir()
264
- self._conn = sqlite3.connect(str(self._db_path))
265
- self._conn.row_factory = sqlite3.Row
537
+ self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
266
538
 
267
- # Load sqlite-vec extension
268
- self._conn.enable_load_extension(True)
269
- sqlite_vec.load(self._conn)
270
- self._conn.enable_load_extension(False)
539
+ # Detect schema version BEFORE creating tables
540
+ # v2 has entity_type_id (FK) instead of entity_type (TEXT)
541
+ if self._is_v2 is None:
542
+ cursor = self._conn.execute("PRAGMA table_info(organizations)")
543
+ columns = {row["name"] for row in cursor}
544
+ self._is_v2 = "entity_type_id" in columns
545
+ if self._is_v2:
546
+ logger.debug("Detected v2 schema for organizations")
271
547
 
272
- # Create tables
273
- self._create_tables()
548
+ # Create tables (idempotent) - only for v1 schema or fresh databases
549
+ # v2 databases already have their schema from migration
550
+ if not self._is_v2:
551
+ self._create_tables()
274
552
 
275
553
  return self._conn
276
554
 
555
+ @property
556
+ def _org_table(self) -> str:
557
+ """Return table/view name for organization queries needing text fields."""
558
+ return "organizations_view" if self._is_v2 else "organizations"
559
+
277
560
  def _create_tables(self) -> None:
278
561
  """Create database tables including sqlite-vec virtual table."""
279
562
  conn = self._conn
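_connect() above distinguishes the v1 and v2 layouts by probing the organizations table: v2 (from schema_v2.py / migrate_v2.py) carries entity_type_id as a foreign key where v1 has an entity_type TEXT column. A standalone sketch of that detection idiom; the helper name is illustrative:

import sqlite3

def detect_org_schema(db_path: str) -> str:
    """Return 'v2' if organizations has entity_type_id, 'v1' if it has entity_type, else 'empty'."""
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        columns = {row["name"] for row in conn.execute("PRAGMA table_info(organizations)")}
    finally:
        conn.close()
    if "entity_type_id" in columns:
        return "v2"
    if "entity_type" in columns:
        return "v1"
    return "empty"  # table absent: PRAGMA table_info returns no rows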
@@ -289,6 +572,8 @@ class OrganizationDatabase:
289
572
  source_id TEXT NOT NULL,
290
573
  region TEXT NOT NULL DEFAULT '',
291
574
  entity_type TEXT NOT NULL DEFAULT 'unknown',
575
+ from_date TEXT NOT NULL DEFAULT '',
576
+ to_date TEXT NOT NULL DEFAULT '',
292
577
  record TEXT NOT NULL,
293
578
  UNIQUE(source, source_id)
294
579
  )
@@ -308,6 +593,34 @@ class OrganizationDatabase:
308
593
  except sqlite3.OperationalError:
309
594
  pass # Column already exists
310
595
 
596
+ # Add from_date column if it doesn't exist (migration for existing DBs)
597
+ try:
598
+ conn.execute("ALTER TABLE organizations ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
599
+ logger.info("Added from_date column to organizations table")
600
+ except sqlite3.OperationalError:
601
+ pass # Column already exists
602
+
603
+ # Add to_date column if it doesn't exist (migration for existing DBs)
604
+ try:
605
+ conn.execute("ALTER TABLE organizations ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
606
+ logger.info("Added to_date column to organizations table")
607
+ except sqlite3.OperationalError:
608
+ pass # Column already exists
609
+
610
+ # Add canon_id column if it doesn't exist (migration for canonicalization)
611
+ try:
612
+ conn.execute("ALTER TABLE organizations ADD COLUMN canon_id INTEGER DEFAULT NULL")
613
+ logger.info("Added canon_id column to organizations table")
614
+ except sqlite3.OperationalError:
615
+ pass # Column already exists
616
+
617
+ # Add canon_size column if it doesn't exist (migration for canonicalization)
618
+ try:
619
+ conn.execute("ALTER TABLE organizations ADD COLUMN canon_size INTEGER DEFAULT 1")
620
+ logger.info("Added canon_size column to organizations table")
621
+ except sqlite3.OperationalError:
622
+ pass # Column already exists
623
+
311
624
  # Create indexes on main table
312
625
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
313
626
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
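The column additions above follow SQLite's usual add-if-missing idiom: ALTER TABLE ... ADD COLUMN raises sqlite3.OperationalError (e.g. duplicate column name) when the column already exists, so swallowing that error keeps the migration idempotent. A compact generic form of the same pattern:

import sqlite3

def add_column_if_missing(conn: sqlite3.Connection, table: str, column_ddl: str) -> bool:
    """Return True if the column was added, False if it already existed."""
    try:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column_ddl}")
        return True
    except sqlite3.OperationalError:
        return False  # most commonly: duplicate column name

# e.g. add_column_if_missing(conn, "organizations", "canon_size INTEGER DEFAULT 1")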
@@ -316,8 +629,9 @@ class OrganizationDatabase:
316
629
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
317
630
  conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
318
631
  conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
632
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")
319
633
 
320
- # Create sqlite-vec virtual table for embeddings
634
+ # Create sqlite-vec virtual table for embeddings (float32)
321
635
  # vec0 is the recommended virtual table type
322
636
  conn.execute(f"""
323
637
  CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
@@ -326,21 +640,34 @@ class OrganizationDatabase:
326
640
  )
327
641
  """)
328
642
 
643
+ # Create sqlite-vec virtual table for scalar embeddings (int8)
644
+ # Provides 75% storage reduction with ~92% recall at top-100
645
+ conn.execute(f"""
646
+ CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
647
+ org_id INTEGER PRIMARY KEY,
648
+ embedding int8[{self._embedding_dim}]
649
+ )
650
+ """)
651
+
329
652
  conn.commit()
330
653
 
331
654
  def close(self) -> None:
332
- """Close database connection."""
333
- if self._conn:
334
- self._conn.close()
335
- self._conn = None
655
+ """Clear connection reference (shared connection remains open)."""
656
+ self._conn = None
336
657
 
337
- def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
658
+ def insert(
659
+ self,
660
+ record: CompanyRecord,
661
+ embedding: np.ndarray,
662
+ scalar_embedding: Optional[np.ndarray] = None,
663
+ ) -> int:
338
664
  """
339
665
  Insert an organization record with its embedding.
340
666
 
341
667
  Args:
342
668
  record: Organization record to insert
343
- embedding: Embedding vector for the organization name
669
+ embedding: Embedding vector for the organization name (float32)
670
+ scalar_embedding: Optional int8 scalar embedding for compact storage
344
671
 
345
672
  Returns:
346
673
  Row ID of inserted record
@@ -353,8 +680,8 @@ class OrganizationDatabase:
353
680
 
354
681
  cursor = conn.execute("""
355
682
  INSERT OR REPLACE INTO organizations
356
- (name, name_normalized, source, source_id, region, entity_type, record)
357
- VALUES (?, ?, ?, ?, ?, ?, ?)
683
+ (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
684
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
358
685
  """, (
359
686
  record.name,
360
687
  name_normalized,
@@ -362,20 +689,32 @@ class OrganizationDatabase:
362
689
  record.source_id,
363
690
  record.region,
364
691
  record.entity_type.value,
692
+ record.from_date or "",
693
+ record.to_date or "",
365
694
  record_json,
366
695
  ))
367
696
 
368
697
  row_id = cursor.lastrowid
369
698
  assert row_id is not None
370
699
 
371
- # Insert embedding into vec table
372
- # sqlite-vec expects the embedding as a blob
700
+ # Insert embedding into vec table (float32)
701
+ # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
373
702
  embedding_blob = embedding.astype(np.float32).tobytes()
703
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
374
704
  conn.execute("""
375
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
705
+ INSERT INTO organization_embeddings (org_id, embedding)
376
706
  VALUES (?, ?)
377
707
  """, (row_id, embedding_blob))
378
708
 
709
+ # Insert scalar embedding if provided (int8)
710
+ if scalar_embedding is not None:
711
+ scalar_blob = scalar_embedding.astype(np.int8).tobytes()
712
+ conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
713
+ conn.execute("""
714
+ INSERT INTO organization_embeddings_scalar (org_id, embedding)
715
+ VALUES (?, vec_int8(?))
716
+ """, (row_id, scalar_blob))
717
+
379
718
  conn.commit()
380
719
  return row_id
381
720
 
@@ -384,14 +723,16 @@ class OrganizationDatabase:
384
723
  records: list[CompanyRecord],
385
724
  embeddings: np.ndarray,
386
725
  batch_size: int = 1000,
726
+ scalar_embeddings: Optional[np.ndarray] = None,
387
727
  ) -> int:
388
728
  """
389
729
  Insert multiple organization records with embeddings.
390
730
 
391
731
  Args:
392
732
  records: List of organization records
393
- embeddings: Matrix of embeddings (N x dim)
733
+ embeddings: Matrix of embeddings (N x dim) - float32
394
734
  batch_size: Commit batch size
735
+ scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)
395
736
 
396
737
  Returns:
397
738
  Number of records inserted
@@ -399,34 +740,75 @@ class OrganizationDatabase:
399
740
  conn = self._connect()
400
741
  count = 0
401
742
 
402
- for record, embedding in zip(records, embeddings):
743
+ for i, (record, embedding) in enumerate(zip(records, embeddings)):
403
744
  record_json = json.dumps(record.record)
404
745
  name_normalized = _normalize_name(record.name)
405
746
 
406
- cursor = conn.execute("""
407
- INSERT OR REPLACE INTO organizations
408
- (name, name_normalized, source, source_id, region, entity_type, record)
409
- VALUES (?, ?, ?, ?, ?, ?, ?)
410
- """, (
411
- record.name,
412
- name_normalized,
413
- record.source,
414
- record.source_id,
415
- record.region,
416
- record.entity_type.value,
417
- record_json,
418
- ))
747
+ if self._is_v2:
748
+ # v2 schema: use FK IDs instead of TEXT columns
749
+ source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
750
+ entity_type_id = ORG_TYPE_NAME_TO_ID.get(record.entity_type.value, 17) # 17 = unknown
751
+
752
+ # Resolve region to location_id if provided
753
+ region_id = None
754
+ if record.region:
755
+ # Use locations database to resolve region
756
+ locations_db = get_locations_database(db_path=self._db_path)
757
+ region_id = locations_db.resolve_region_text(record.region)
758
+
759
+ cursor = conn.execute("""
760
+ INSERT OR REPLACE INTO organizations
761
+ (name, name_normalized, source_id, source_identifier, region_id, entity_type_id, from_date, to_date, record)
762
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
763
+ """, (
764
+ record.name,
765
+ name_normalized,
766
+ source_type_id,
767
+ record.source_id,
768
+ region_id,
769
+ entity_type_id,
770
+ record.from_date or "",
771
+ record.to_date or "",
772
+ record_json,
773
+ ))
774
+ else:
775
+ # v1 schema: use TEXT columns
776
+ cursor = conn.execute("""
777
+ INSERT OR REPLACE INTO organizations
778
+ (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
779
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
780
+ """, (
781
+ record.name,
782
+ name_normalized,
783
+ record.source,
784
+ record.source_id,
785
+ record.region,
786
+ record.entity_type.value,
787
+ record.from_date or "",
788
+ record.to_date or "",
789
+ record_json,
790
+ ))
419
791
 
420
792
  row_id = cursor.lastrowid
421
793
  assert row_id is not None
422
794
 
423
- # Insert embedding
795
+ # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
424
796
  embedding_blob = embedding.astype(np.float32).tobytes()
797
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
425
798
  conn.execute("""
426
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
799
+ INSERT INTO organization_embeddings (org_id, embedding)
427
800
  VALUES (?, ?)
428
801
  """, (row_id, embedding_blob))
429
802
 
803
+ # Insert scalar embedding if provided (int8)
804
+ if scalar_embeddings is not None:
805
+ scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
806
+ conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
807
+ conn.execute("""
808
+ INSERT INTO organization_embeddings_scalar (org_id, embedding)
809
+ VALUES (?, vec_int8(?))
810
+ """, (row_id, scalar_blob))
811
+
430
812
  count += 1
431
813
 
432
814
  if count % batch_size == 0:
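Both insert paths above delete from the vec0 virtual tables before inserting because, as the comments note, sqlite-vec virtual tables do not support INSERT OR REPLACE. A minimal standalone sketch of that delete-then-insert upsert pattern, assuming the sqlite-vec Python package is installed:

import sqlite3
import numpy as np
import sqlite_vec

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)

conn.execute("CREATE VIRTUAL TABLE vecs USING vec0(id INTEGER PRIMARY KEY, embedding float[4])")

def upsert(conn: sqlite3.Connection, row_id: int, vec) -> None:
    # vec0 rejects INSERT OR REPLACE, so clear the row first, then insert the new blob.
    blob = np.asarray(vec, dtype=np.float32).tobytes()
    conn.execute("DELETE FROM vecs WHERE id = ?", (row_id,))
    conn.execute("INSERT INTO vecs (id, embedding) VALUES (?, ?)", (row_id, blob))

upsert(conn, 1, [0.1, 0.2, 0.3, 0.4])
upsert(conn, 1, [0.4, 0.3, 0.2, 0.1])  # second call replaces the first row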
@@ -443,13 +825,15 @@ class OrganizationDatabase:
443
825
  source_filter: Optional[str] = None,
444
826
  query_text: Optional[str] = None,
445
827
  max_text_candidates: int = 5000,
828
+ rerank_min_candidates: int = 500,
446
829
  ) -> list[tuple[CompanyRecord, float]]:
447
830
  """
448
831
  Search for similar organizations using hybrid text + vector search.
449
832
 
450
- Two-stage approach:
833
+ Three-stage approach:
451
834
  1. If query_text provided, use SQL LIKE to find candidates containing search terms
452
835
  2. Use sqlite-vec for vector similarity ranking on filtered candidates
836
+ 3. Apply prominence-based re-ranking to boost major companies (SEC filers, tickers)
453
837
 
454
838
  Args:
455
839
  query_embedding: Query embedding vector
@@ -457,9 +841,10 @@ class OrganizationDatabase:
457
841
  source_filter: Optional filter by source (gleif, sec_edgar, etc.)
458
842
  query_text: Optional query text for text-based pre-filtering
459
843
  max_text_candidates: Max candidates to keep after text filtering
844
+ rerank_min_candidates: Minimum candidates to fetch for re-ranking (default 500)
460
845
 
461
846
  Returns:
462
- List of (CompanyRecord, similarity_score) tuples
847
+ List of (CompanyRecord, adjusted_score) tuples sorted by prominence-adjusted score
463
848
  """
464
849
  start = time.time()
465
850
  self._connect()
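A hedged usage sketch of the hybrid search entry point; the enclosing method's name is not visible in this hunk, so search() is assumed, as are the zero-argument constructor and the random stand-in query vector (a real caller would pass the name encoder's output):

import numpy as np
from statement_extractor.database.store import OrganizationDatabase  # import path assumed

db = OrganizationDatabase()                          # assumes defaults resolve to DEFAULT_DB_PATH
query_vec = np.random.rand(768).astype(np.float32)   # stand-in for a real 768-dim name embedding

matches = db.search(                                 # method name assumed; parameters as above
    query_vec,
    top_k=5,
    query_text="Acme Corporation",                   # enables the stage-1 SQL LIKE pre-filter
)
for record, score in matches:
    print(record.name, record.source, round(score, 3))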
@@ -469,10 +854,17 @@ class OrganizationDatabase:
469
854
  if query_norm == 0:
470
855
  return []
471
856
  query_normalized = query_embedding / query_norm
472
- query_blob = query_normalized.astype(np.float32).tobytes()
857
+
858
+ # Use int8 quantized query if scalar table is available (75% storage savings)
859
+ if self._has_scalar_table():
860
+ query_int8 = self._quantize_query(query_normalized)
861
+ query_blob = query_int8.tobytes()
862
+ else:
863
+ query_blob = query_normalized.astype(np.float32).tobytes()
473
864
 
474
865
  # Stage 1: Text-based pre-filtering (if query_text provided)
475
866
  candidate_ids: Optional[set[int]] = None
867
+ query_normalized_text = ""
476
868
  if query_text:
477
869
  query_normalized_text = _normalize_name(query_text)
478
870
  if query_normalized_text:
@@ -483,24 +875,168 @@ class OrganizationDatabase:
483
875
  )
484
876
  logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
485
877
 
486
- # Stage 2: Vector search
878
+ # Stage 2: Vector search - fetch more candidates for re-ranking
487
879
  if candidate_ids is not None and len(candidate_ids) == 0:
488
880
  # No text matches, return empty
489
881
  return []
490
882
 
883
+ # Fetch enough candidates for prominence re-ranking to be effective
884
+ # Use at least rerank_min_candidates, or all text-filtered candidates if fewer
885
+ if candidate_ids is not None:
886
+ fetch_k = min(len(candidate_ids), max(rerank_min_candidates, top_k * 5))
887
+ else:
888
+ fetch_k = max(rerank_min_candidates, top_k * 5)
889
+
491
890
  if candidate_ids is not None:
492
891
  # Search within text-filtered candidates
493
892
  results = self._vector_search_filtered(
494
- query_blob, candidate_ids, top_k, source_filter
893
+ query_blob, candidate_ids, fetch_k, source_filter
495
894
  )
496
895
  else:
497
896
  # Full vector search
498
- results = self._vector_search_full(query_blob, top_k, source_filter)
897
+ results = self._vector_search_full(query_blob, fetch_k, source_filter)
898
+
899
+ # Stage 3: Prominence-based re-ranking
900
+ if results and query_normalized_text:
901
+ results = self._apply_prominence_reranking(results, query_normalized_text, top_k)
902
+ else:
903
+ # No re-ranking, just trim to top_k
904
+ results = results[:top_k]
499
905
 
500
906
  elapsed = time.time() - start
501
907
  logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
502
908
  return results
503
909
 
910
+ def _calculate_prominence_boost(
911
+ self,
912
+ record: CompanyRecord,
913
+ query_normalized: str,
914
+ canon_sources: Optional[set[str]] = None,
915
+ ) -> float:
916
+ """
917
+ Calculate prominence boost for re-ranking search results.
918
+
919
+ Boosts scores based on signals that indicate a major/prominent company:
920
+ - Has ticker symbol (publicly traded)
921
+ - GLEIF source (has LEI)
922
+ - SEC source (vetted US filers)
923
+ - Wikidata source (Wikipedia-notable)
924
+ - Exact normalized name match
925
+
926
+ When canon_sources is provided (from a canonical group), boosts are
927
+ applied for ALL sources in the canon group, not just this record's source.
928
+
929
+ Args:
930
+ record: The company record to evaluate
931
+ query_normalized: Normalized query text for exact match check
932
+ canon_sources: Optional set of sources in this record's canonical group
933
+
934
+ Returns:
935
+ Boost value to add to embedding similarity (0.0 to ~0.21)
936
+ """
937
+ boost = 0.0
938
+
939
+ # Get all sources to consider (canon group or just this record)
940
+ sources_to_check = canon_sources or {record.source}
941
+
942
+ # Has ticker symbol = publicly traded major company
943
+ # Check if ANY record in canon group has ticker
944
+ if record.record.get("ticker") or (canon_sources and "sec_edgar" in canon_sources):
945
+ boost += 0.08
946
+
947
+ # Source-based boosts - accumulate for all sources in canon group
948
+ if "gleif" in sources_to_check:
949
+ boost += 0.05 # Has LEI = verified legal entity
950
+ if "sec_edgar" in sources_to_check:
951
+ boost += 0.03 # SEC filer
952
+ if "wikipedia" in sources_to_check:
953
+ boost += 0.02 # Wikipedia notable
954
+
955
+ # Exact normalized name match bonus
956
+ record_normalized = _normalize_name(record.name)
957
+ if query_normalized == record_normalized:
958
+ boost += 0.05
959
+
960
+ return boost
961
+
962
+ def _apply_prominence_reranking(
963
+ self,
964
+ results: list[tuple[CompanyRecord, float]],
965
+ query_normalized: str,
966
+ top_k: int,
967
+ similarity_weight: float = 0.3,
968
+ ) -> list[tuple[CompanyRecord, float]]:
969
+ """
970
+ Apply prominence-based re-ranking to search results with canon group awareness.
971
+
972
+ When records have been canonicalized, boosts are applied based on ALL sources
973
+ in the canonical group, not just the matched record's source.
974
+
975
+ Args:
976
+ results: List of (record, similarity) from vector search
977
+ query_normalized: Normalized query text
978
+ top_k: Number of results to return after re-ranking
979
+ similarity_weight: Weight for similarity score (0-1), lower = prominence matters more
980
+
981
+ Returns:
982
+ Re-ranked list of (record, adjusted_score) tuples
983
+ """
984
+ conn = self._conn
985
+ assert conn is not None
986
+
987
+ # Build canon_id -> sources mapping for all results that have canon_id
988
+ canon_sources_map: dict[int, set[str]] = {}
989
+ canon_ids = [
990
+ r.record.get("canon_id")
991
+ for r, _ in results
992
+ if r.record.get("canon_id") is not None
993
+ ]
994
+
995
+ if canon_ids:
996
+ # Fetch all sources for each canon_id in one query
997
+ unique_canon_ids = list(set(canon_ids))
998
+ placeholders = ",".join("?" * len(unique_canon_ids))
999
+ rows = conn.execute(f"""
1000
+ SELECT canon_id, source
1001
+ FROM organizations
1002
+ WHERE canon_id IN ({placeholders})
1003
+ """, unique_canon_ids).fetchall()
1004
+
1005
+ for row in rows:
1006
+ canon_id = row["canon_id"]
1007
+ canon_sources_map.setdefault(canon_id, set()).add(row["source"])
1008
+
1009
+ # Calculate boosted scores with canon group awareness
1010
+ # Formula: adjusted = (similarity * weight) + boost
1011
+ # With weight=0.3, a sim=0.65 SEC+ticker (boost=0.11) beats sim=0.75 no-boost
1012
+ boosted_results: list[tuple[CompanyRecord, float, float, float]] = []
1013
+ for record, similarity in results:
1014
+ canon_id = record.record.get("canon_id")
1015
+ # Get all sources in this record's canon group (if any)
1016
+ canon_sources = canon_sources_map.get(canon_id) if canon_id else None
1017
+
1018
+ boost = self._calculate_prominence_boost(record, query_normalized, canon_sources)
1019
+ adjusted_score = (similarity * similarity_weight) + boost
1020
+ boosted_results.append((record, similarity, boost, adjusted_score))
1021
+
1022
+ # Sort by adjusted score (descending)
1023
+ boosted_results.sort(key=lambda x: x[3], reverse=True)
1024
+
1025
+ # Log re-ranking details for top results
1026
+ logger.debug(f"Prominence re-ranking for '{query_normalized}':")
1027
+ for record, sim, boost, adj in boosted_results[:10]:
1028
+ ticker = record.record.get("ticker", "")
1029
+ ticker_str = f" ticker={ticker}" if ticker else ""
1030
+ canon_id = record.record.get("canon_id")
1031
+ canon_str = f" canon={canon_id}" if canon_id else ""
1032
+ logger.debug(
1033
+ f" {record.name}: sim={sim:.3f} + boost={boost:.3f} = {adj:.3f} "
1034
+ f"[{record.source}{ticker_str}{canon_str}]"
1035
+ )
1036
+
1037
+ # Return top_k with adjusted scores
1038
+ return [(r, adj) for r, _, _, adj in boosted_results[:top_k]]
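The inline comment above can be checked by hand: with similarity_weight = 0.3, an SEC filer with a ticker (boost 0.08 + 0.03 = 0.11) at similarity 0.65 outranks an unboosted candidate at similarity 0.75.

similarity_weight = 0.3

sec_filer_with_ticker = 0.65 * similarity_weight + 0.11   # = 0.305
unboosted_competitor  = 0.75 * similarity_weight + 0.00    # = 0.225

assert sec_filer_with_ticker > unboosted_competitor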
1039
+
504
1040
  def _text_filter_candidates(
505
1041
  self,
506
1042
  query_normalized: str,
@@ -554,6 +1090,19 @@ class OrganizationDatabase:
554
1090
  cursor = conn.execute(query, params)
555
1091
  return set(row["id"] for row in cursor)
556
1092
 
1093
+ def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
1094
+ """Quantize query embedding to int8 for scalar search."""
1095
+ return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
1096
+
1097
+ def _has_scalar_table(self) -> bool:
1098
+ """Check if scalar embedding table exists."""
1099
+ conn = self._conn
1100
+ assert conn is not None
1101
+ cursor = conn.execute(
1102
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='organization_embeddings_scalar'"
1103
+ )
1104
+ return cursor.fetchone() is not None
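The 75% saving quoted for the scalar table follows directly from the dtypes: 768 float32 values occupy 3,072 bytes while the int8 copy takes 768. A standalone sketch of the same symmetric quantization used by _quantize_query above:

import numpy as np

def quantize_int8(embedding: np.ndarray) -> np.ndarray:
    """Map a unit-normalized float32 vector onto int8 values in [-127, 127]."""
    return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)

rng = np.random.default_rng(0)
vec = rng.normal(size=768).astype(np.float32)
vec /= np.linalg.norm(vec)

q = quantize_int8(vec)
print(vec.nbytes, q.nbytes)  # 3072 -> 768 bytes, i.e. a 75% reduction per vector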
1105
+
557
1106
  def _vector_search_filtered(
558
1107
  self,
559
1108
  query_blob: bytes,
@@ -561,7 +1110,7 @@ class OrganizationDatabase:
561
1110
  top_k: int,
562
1111
  source_filter: Optional[str],
563
1112
  ) -> list[tuple[CompanyRecord, float]]:
564
- """Vector search within a filtered set of candidates."""
1113
+ """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
565
1114
  conn = self._conn
566
1115
  assert conn is not None
567
1116
 
@@ -571,18 +1120,29 @@ class OrganizationDatabase:
571
1120
  # Build IN clause for candidate IDs
572
1121
  placeholders = ",".join("?" * len(candidate_ids))
573
1122
 
574
- # Query sqlite-vec with KNN search, filtered by candidate IDs
575
- # Using distance function - lower is more similar for L2
576
- # We'll use cosine distance
577
- query = f"""
578
- SELECT
579
- e.org_id,
580
- vec_distance_cosine(e.embedding, ?) as distance
581
- FROM organization_embeddings e
582
- WHERE e.org_id IN ({placeholders})
583
- ORDER BY distance
584
- LIMIT ?
585
- """
1123
+ # Use scalar embedding table if available (75% storage reduction)
1124
+ if self._has_scalar_table():
1125
+ # Query uses int8 embeddings with vec_int8() wrapper
1126
+ query = f"""
1127
+ SELECT
1128
+ e.org_id,
1129
+ vec_distance_cosine(e.embedding, vec_int8(?)) as distance
1130
+ FROM organization_embeddings_scalar e
1131
+ WHERE e.org_id IN ({placeholders})
1132
+ ORDER BY distance
1133
+ LIMIT ?
1134
+ """
1135
+ else:
1136
+ # Fall back to float32 embeddings
1137
+ query = f"""
1138
+ SELECT
1139
+ e.org_id,
1140
+ vec_distance_cosine(e.embedding, ?) as distance
1141
+ FROM organization_embeddings e
1142
+ WHERE e.org_id IN ({placeholders})
1143
+ ORDER BY distance
1144
+ LIMIT ?
1145
+ """
586
1146
 
587
1147
  cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
588
1148
 
@@ -609,33 +1169,58 @@ class OrganizationDatabase:
609
1169
  top_k: int,
610
1170
  source_filter: Optional[str],
611
1171
  ) -> list[tuple[CompanyRecord, float]]:
612
- """Full vector search without text pre-filtering."""
1172
+ """Full vector search without text pre-filtering using scalar (int8) embeddings."""
613
1173
  conn = self._conn
614
1174
  assert conn is not None
615
1175
 
1176
+ # Use scalar embedding table if available (75% storage reduction)
1177
+ use_scalar = self._has_scalar_table()
1178
+
616
1179
  # KNN search with sqlite-vec
617
1180
  if source_filter:
618
1181
  # Need to join with organizations table for source filter
619
- query = """
620
- SELECT
621
- e.org_id,
622
- vec_distance_cosine(e.embedding, ?) as distance
623
- FROM organization_embeddings e
624
- JOIN organizations c ON e.org_id = c.id
625
- WHERE c.source = ?
626
- ORDER BY distance
627
- LIMIT ?
628
- """
1182
+ if use_scalar:
1183
+ query = """
1184
+ SELECT
1185
+ e.org_id,
1186
+ vec_distance_cosine(e.embedding, vec_int8(?)) as distance
1187
+ FROM organization_embeddings_scalar e
1188
+ JOIN organizations c ON e.org_id = c.id
1189
+ WHERE c.source = ?
1190
+ ORDER BY distance
1191
+ LIMIT ?
1192
+ """
1193
+ else:
1194
+ query = """
1195
+ SELECT
1196
+ e.org_id,
1197
+ vec_distance_cosine(e.embedding, ?) as distance
1198
+ FROM organization_embeddings e
1199
+ JOIN organizations c ON e.org_id = c.id
1200
+ WHERE c.source = ?
1201
+ ORDER BY distance
1202
+ LIMIT ?
1203
+ """
629
1204
  cursor = conn.execute(query, (query_blob, source_filter, top_k))
630
1205
  else:
631
- query = """
632
- SELECT
633
- org_id,
634
- vec_distance_cosine(embedding, ?) as distance
635
- FROM organization_embeddings
636
- ORDER BY distance
637
- LIMIT ?
638
- """
1206
+ if use_scalar:
1207
+ query = """
1208
+ SELECT
1209
+ org_id,
1210
+ vec_distance_cosine(embedding, vec_int8(?)) as distance
1211
+ FROM organization_embeddings_scalar
1212
+ ORDER BY distance
1213
+ LIMIT ?
1214
+ """
1215
+ else:
1216
+ query = """
1217
+ SELECT
1218
+ org_id,
1219
+ vec_distance_cosine(embedding, ?) as distance
1220
+ FROM organization_embeddings
1221
+ ORDER BY distance
1222
+ LIMIT ?
1223
+ """
639
1224
  cursor = conn.execute(query, (query_blob, top_k))
640
1225
 
641
1226
  results = []
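A minimal standalone sketch of the KNN pattern both branches above rely on: cosine distance over a vec0 table, ordered ascending and limited to the requested count; the float32 form is shown and sqlite-vec is assumed to be installed.

import sqlite3
import numpy as np
import sqlite_vec

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)

conn.execute("CREATE VIRTUAL TABLE org_vecs USING vec0(org_id INTEGER PRIMARY KEY, embedding float[3])")
for org_id, vec in [(1, [1.0, 0.0, 0.0]), (2, [0.0, 1.0, 0.0]), (3, [0.7, 0.7, 0.0])]:
    blob = np.array(vec, dtype=np.float32).tobytes()
    conn.execute("INSERT INTO org_vecs (org_id, embedding) VALUES (?, ?)", (org_id, blob))

query = np.array([1.0, 0.1, 0.0], dtype=np.float32).tobytes()
rows = conn.execute(
    "SELECT org_id, vec_distance_cosine(embedding, ?) AS distance "
    "FROM org_vecs ORDER BY distance LIMIT 2",
    (query,),
).fetchall()
print(rows)  # org 1 is closest, then org 3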
@@ -651,24 +1236,38 @@ class OrganizationDatabase:
651
1236
  return results
652
1237
 
653
1238
  def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
654
- """Get an organization record by ID."""
1239
+ """Get an organization record by ID, including db_id and canon_id in record dict."""
655
1240
  conn = self._conn
656
1241
  assert conn is not None
657
1242
 
658
- cursor = conn.execute("""
659
- SELECT name, source, source_id, region, entity_type, record
660
- FROM organizations WHERE id = ?
661
- """, (org_id,))
1243
+ if self._is_v2:
1244
+ # v2 schema: use view for text fields, but need record from base table
1245
+ cursor = conn.execute("""
1246
+ SELECT v.id, v.name, v.source, v.source_identifier, v.region, v.entity_type, v.canon_id, o.record
1247
+ FROM organizations_view v
1248
+ JOIN organizations o ON v.id = o.id
1249
+ WHERE v.id = ?
1250
+ """, (org_id,))
1251
+ else:
1252
+ cursor = conn.execute("""
1253
+ SELECT id, name, source, source_id, region, entity_type, record, canon_id
1254
+ FROM organizations WHERE id = ?
1255
+ """, (org_id,))
662
1256
 
663
1257
  row = cursor.fetchone()
664
1258
  if row:
1259
+ record_data = json.loads(row["record"])
1260
+ # Add db_id and canon_id to record dict for canon-aware search
1261
+ record_data["db_id"] = row["id"]
1262
+ record_data["canon_id"] = row["canon_id"]
1263
+ source_id_field = "source_identifier" if self._is_v2 else "source_id"
665
1264
  return CompanyRecord(
666
1265
  name=row["name"],
667
1266
  source=row["source"],
668
- source_id=row["source_id"],
1267
+ source_id=row[source_id_field],
669
1268
  region=row["region"] or "",
670
1269
  entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
671
- record=json.loads(row["record"]),
1270
+ record=record_data,
672
1271
  )
673
1272
  return None
674
1273
 
@@ -676,24 +1275,56 @@ class OrganizationDatabase:
676
1275
  """Get an organization record by source and source_id."""
677
1276
  conn = self._connect()
678
1277
 
679
- cursor = conn.execute("""
680
- SELECT name, source, source_id, region, entity_type, record
681
- FROM organizations
682
- WHERE source = ? AND source_id = ?
683
- """, (source, source_id))
1278
+ if self._is_v2:
1279
+ # v2 schema: join view with base table for record
1280
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
1281
+ cursor = conn.execute("""
1282
+ SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
1283
+ FROM organizations_view v
1284
+ JOIN organizations o ON v.id = o.id
1285
+ WHERE o.source_id = ? AND o.source_identifier = ?
1286
+ """, (source_type_id, source_id))
1287
+ else:
1288
+ cursor = conn.execute("""
1289
+ SELECT name, source, source_id, region, entity_type, record
1290
+ FROM organizations
1291
+ WHERE source = ? AND source_id = ?
1292
+ """, (source, source_id))
684
1293
 
685
1294
  row = cursor.fetchone()
686
1295
  if row:
1296
+ source_id_field = "source_identifier" if self._is_v2 else "source_id"
687
1297
  return CompanyRecord(
688
1298
  name=row["name"],
689
1299
  source=row["source"],
690
- source_id=row["source_id"],
1300
+ source_id=row[source_id_field],
691
1301
  region=row["region"] or "",
692
1302
  entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
693
1303
  record=json.loads(row["record"]),
694
1304
  )
695
1305
  return None
696
1306
 
1307
+ def get_id_by_source_id(self, source: str, source_id: str) -> Optional[int]:
1308
+ """Get the internal database ID for an organization by source and source_id."""
1309
+ conn = self._connect()
1310
+
1311
+ if self._is_v2:
1312
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
1313
+ cursor = conn.execute("""
1314
+ SELECT id FROM organizations
1315
+ WHERE source_id = ? AND source_identifier = ?
1316
+ """, (source_type_id, source_id))
1317
+ else:
1318
+ cursor = conn.execute("""
1319
+ SELECT id FROM organizations
1320
+ WHERE source = ? AND source_id = ?
1321
+ """, (source, source_id))
1322
+
1323
+ row = cursor.fetchone()
1324
+ if row:
1325
+ return row["id"]
1326
+ return None
1327
+
697
1328
  def get_stats(self) -> DatabaseStats:
698
1329
  """Get database statistics."""
699
1330
  conn = self._connect()
@@ -702,8 +1333,18 @@ class OrganizationDatabase:
702
1333
  cursor = conn.execute("SELECT COUNT(*) FROM organizations")
703
1334
  total = cursor.fetchone()[0]
704
1335
 
705
- # Count by source
706
- cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
1336
+ # Count by source - handle both v1 and v2 schema
1337
+ if self._is_v2:
1338
+ # v2 schema - join with source_types
1339
+ cursor = conn.execute("""
1340
+ SELECT st.name as source, COUNT(*) as cnt
1341
+ FROM organizations o
1342
+ JOIN source_types st ON o.source_id = st.id
1343
+ GROUP BY o.source_id
1344
+ """)
1345
+ else:
1346
+ # v1 schema
1347
+ cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
707
1348
  by_source = {row["source"]: row["cnt"] for row in cursor}
708
1349
 
709
1350
  # Database file size
@@ -716,44 +1357,350 @@ class OrganizationDatabase:
716
1357
  database_size_bytes=db_size,
717
1358
  )
718
1359
 
1360
+ def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
1361
+ """
1362
+ Get all source_ids from the organizations table.
1363
+
1364
+ Useful for resume operations to skip already-imported records.
1365
+
1366
+ Args:
1367
+ source: Optional source filter (e.g., "wikidata" for Wikidata orgs)
1368
+
1369
+ Returns:
1370
+ Set of source_id strings (e.g., Q codes for Wikidata)
1371
+ """
1372
+ conn = self._connect()
1373
+
1374
+ if self._is_v2:
1375
+ id_col = "source_identifier"
1376
+ if source:
1377
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
1378
+ cursor = conn.execute(
1379
+ f"SELECT DISTINCT {id_col} FROM organizations WHERE source_id = ?",
1380
+ (source_type_id,)
1381
+ )
1382
+ else:
1383
+ cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM organizations")
1384
+ else:
1385
+ if source:
1386
+ cursor = conn.execute(
1387
+ "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
1388
+ (source,)
1389
+ )
1390
+ else:
1391
+ cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")
1392
+
1393
+ return {row[0] for row in cursor}
1394
+
719
1395
  def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
720
1396
  """Iterate over all records, optionally filtered by source."""
721
1397
  conn = self._connect()
722
1398
 
723
- if source:
724
- cursor = conn.execute("""
725
- SELECT name, source, source_id, region, entity_type, record
726
- FROM organizations
727
- WHERE source = ?
728
- """, (source,))
1399
+ if self._is_v2:
1400
+ if source:
1401
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
1402
+ cursor = conn.execute("""
1403
+ SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
1404
+ FROM organizations_view v
1405
+ JOIN organizations o ON v.id = o.id
1406
+ WHERE o.source_id = ?
1407
+ """, (source_type_id,))
1408
+ else:
1409
+ cursor = conn.execute("""
1410
+ SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
1411
+ FROM organizations_view v
1412
+ JOIN organizations o ON v.id = o.id
1413
+ """)
1414
+ for row in cursor:
1415
+ yield CompanyRecord(
1416
+ name=row["name"],
1417
+ source=row["source"],
1418
+ source_id=row["source_identifier"],
1419
+ region=row["region"] or "",
1420
+ entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
1421
+ record=json.loads(row["record"]),
1422
+ )
729
1423
  else:
730
- cursor = conn.execute("""
731
- SELECT name, source, source_id, region, entity_type, record
732
- FROM organizations
733
- """)
734
-
735
- for row in cursor:
736
- yield CompanyRecord(
737
- name=row["name"],
738
- source=row["source"],
739
- source_id=row["source_id"],
740
- region=row["region"] or "",
741
- entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
742
- record=json.loads(row["record"]),
743
- )
1424
+ if source:
1425
+ cursor = conn.execute("""
1426
+ SELECT name, source, source_id, region, entity_type, record
1427
+ FROM organizations
1428
+ WHERE source = ?
1429
+ """, (source,))
1430
+ else:
1431
+ cursor = conn.execute("""
1432
+ SELECT name, source, source_id, region, entity_type, record
1433
+ FROM organizations
1434
+ """)
1435
+ for row in cursor:
1436
+ yield CompanyRecord(
1437
+ name=row["name"],
1438
+ source=row["source"],
1439
+ source_id=row["source_id"],
1440
+ region=row["region"] or "",
1441
+ entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
1442
+ record=json.loads(row["record"]),
1443
+ )
744
1444
 
745
- def migrate_name_normalized(self, batch_size: int = 50000) -> int:
1445
+ def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
746
1446
  """
747
- Populate the name_normalized column for all records.
1447
+ Canonicalize all organizations by linking equivalent records.
748
1448
 
749
- This is a one-time migration for databases that don't have
750
- normalized names populated.
1449
+ Records are considered equivalent if they match by:
1450
+ 1. Same LEI (GLEIF source_id or Wikidata P1278) - globally unique, no region check
1451
+ 2. Same ticker symbol - globally unique, no region check
1452
+ 3. Same CIK - globally unique, no region check
1453
+ 4. Same normalized name AND same normalized region
1454
+ 5. Name match with suffix expansion AND same region
1455
+
1456
+ Region normalization uses pycountry to handle:
1457
+ - Country codes/names (GB, United Kingdom, Great Britain -> GB)
1458
+ - US state codes/names (CA, California -> US)
1459
+ - Common aliases (UK -> GB, USA -> US)
1460
+
1461
+ For each group of equivalent records, the highest-priority source
1462
+ (gleif > sec_edgar > companies_house > wikipedia) becomes canonical.
751
1463
 
752
1464
  Args:
753
- batch_size: Number of records to process per batch
1465
+ batch_size: Commit batch size for updates
754
1466
 
755
1467
  Returns:
756
- Number of records updated
1468
+ Dict with stats: total_records, groups_found, records_updated
1469
+ """
1470
+ conn = self._connect()
1471
+ logger.info("Starting canonicalization...")
1472
+
1473
+ # Phase 1: Load all organization data and build indexes
1474
+ logger.info("Phase 1: Building indexes...")
1475
+
1476
+ lei_index: dict[str, list[int]] = {}
1477
+ ticker_index: dict[str, list[int]] = {}
1478
+ cik_index: dict[str, list[int]] = {}
1479
+ # Name indexes now keyed by (normalized_name, normalized_region)
1480
+ # Region-less matching only applies for identifier-based matching
1481
+ name_region_index: dict[tuple[str, str], list[int]] = {}
1482
+ expanded_name_region_index: dict[tuple[str, str], list[int]] = {}
1483
+
1484
+ sources: dict[int, str] = {} # org_id -> source
1485
+ all_org_ids: list[int] = []
1486
+
1487
+ if self._is_v2:
1488
+ cursor = conn.execute("""
1489
+ SELECT o.id, s.name as source, o.source_identifier as source_id, o.name, l.name as region, o.record
1490
+ FROM organizations o
1491
+ JOIN source_types s ON o.source_id = s.id
1492
+ LEFT JOIN locations l ON o.region_id = l.id
1493
+ """)
1494
+ else:
1495
+ cursor = conn.execute("""
1496
+ SELECT id, source, source_id, name, region, record
1497
+ FROM organizations
1498
+ """)
1499
+
1500
+ count = 0
1501
+ for row in cursor:
1502
+ org_id = row["id"]
1503
+ source = row["source"]
1504
+ name = row["name"]
1505
+ region = row["region"] or ""
1506
+ record = json.loads(row["record"])
1507
+
1508
+ all_org_ids.append(org_id)
1509
+ sources[org_id] = source
1510
+
1511
+ # Index by LEI (GLEIF source_id or Wikidata's P1278)
1512
+ # LEI is globally unique - no region check needed
1513
+ if source == "gleif":
1514
+ lei = row["source_id"]
1515
+ else:
1516
+ lei = record.get("lei")
1517
+ if lei:
1518
+ lei_index.setdefault(lei.upper(), []).append(org_id)
1519
+
1520
+ # Index by ticker - globally unique, no region check
1521
+ ticker = record.get("ticker")
1522
+ if ticker:
1523
+ ticker_index.setdefault(ticker.upper(), []).append(org_id)
1524
+
1525
+ # Index by CIK - globally unique, no region check
1526
+ if source == "sec_edgar":
1527
+ cik = row["source_id"]
1528
+ else:
1529
+ cik = record.get("cik")
1530
+ if cik:
1531
+ cik_index.setdefault(str(cik), []).append(org_id)
1532
+
1533
+ # Index by (normalized_name, normalized_region)
1534
+ # Same name in different regions = different legal entities
1535
+ norm_name = _normalize_for_canon(name)
1536
+ norm_region = _normalize_region(region)
1537
+ if norm_name:
1538
+ key = (norm_name, norm_region)
1539
+ name_region_index.setdefault(key, []).append(org_id)
1540
+
1541
+ # Index by (expanded_name, normalized_region)
1542
+ expanded_name = _expand_suffix(name)
1543
+ if expanded_name and expanded_name != norm_name:
1544
+ key = (expanded_name, norm_region)
1545
+ expanded_name_region_index.setdefault(key, []).append(org_id)
1546
+
1547
+ count += 1
1548
+ if count % 100000 == 0:
1549
+ logger.info(f" Indexed {count} organizations...")
1550
+
1551
+ logger.info(f" Indexed {count} organizations total")
1552
+ logger.info(f" LEI index: {len(lei_index)} unique LEIs")
1553
+ logger.info(f" Ticker index: {len(ticker_index)} unique tickers")
1554
+ logger.info(f" CIK index: {len(cik_index)} unique CIKs")
1555
+ logger.info(f" Name+region index: {len(name_region_index)} unique (name, region) pairs")
1556
+ logger.info(f" Expanded name+region index: {len(expanded_name_region_index)} unique pairs")
1557
+
1558
+ # Phase 2: Build equivalence groups using Union-Find
1559
+ logger.info("Phase 2: Building equivalence groups...")
1560
+
1561
+ uf = UnionFind(all_org_ids)
1562
+
1563
+ # Merge by LEI (globally unique identifier)
1564
+ for _lei, ids in lei_index.items():
1565
+ for i in range(1, len(ids)):
1566
+ uf.union(ids[0], ids[i])
1567
+
1568
+ # Merge by ticker (globally unique identifier)
1569
+ for _ticker, ids in ticker_index.items():
1570
+ for i in range(1, len(ids)):
1571
+ uf.union(ids[0], ids[i])
1572
+
1573
+ # Merge by CIK (globally unique identifier)
1574
+ for _cik, ids in cik_index.items():
1575
+ for i in range(1, len(ids)):
1576
+ uf.union(ids[0], ids[i])
1577
+
1578
+ # Merge by (normalized_name, normalized_region)
1579
+ for _name_region, ids in name_region_index.items():
1580
+ for i in range(1, len(ids)):
1581
+ uf.union(ids[0], ids[i])
1582
+
1583
+ # Merge by (expanded_name, normalized_region)
1584
+ # This connects "Amazon Ltd" with "Amazon Limited" in same region
1585
+ for key, expanded_ids in expanded_name_region_index.items():
1586
+ # Find org_ids with the expanded form as their normalized name in same region
1587
+ if key in name_region_index:
1588
+ # Link first expanded_id to first name_id
1589
+ uf.union(expanded_ids[0], name_region_index[key][0])
1590
+
1591
+ groups = uf.groups()
1592
+ logger.info(f" Found {len(groups)} equivalence groups")
1593
+
1594
+ # Count groups with multiple records
1595
+ multi_record_groups = sum(1 for ids in groups.values() if len(ids) > 1)
1596
+ logger.info(f" Groups with multiple records: {multi_record_groups}")
1597
+
1598
+ # Phase 3: Select canonical record for each group and update database
1599
+ logger.info("Phase 3: Updating database...")
1600
+
1601
+ updated_count = 0
1602
+ batch_updates: list[tuple[int, int, int]] = [] # (org_id, canon_id, canon_size)
1603
+
1604
+ for _root, group_ids in groups.items():
1605
+ if len(group_ids) == 1:
1606
+ # Single record - canonical to itself
1607
+ batch_updates.append((group_ids[0], group_ids[0], 1))
1608
+ else:
1609
+ # Multiple records - find highest priority source
1610
+ best_id = min(
1611
+ group_ids,
1612
+ key=lambda oid: (SOURCE_PRIORITY.get(sources[oid], 99), oid)
1613
+ )
1614
+ group_size = len(group_ids)
1615
+
1616
+ # All records in group point to the best one
1617
+ for oid in group_ids:
1618
+ # canon_size is only set on the canonical record
1619
+ size = group_size if oid == best_id else 1
1620
+ batch_updates.append((oid, best_id, size))
1621
+
1622
+ # Commit batch
1623
+ if len(batch_updates) >= batch_size:
1624
+ self._apply_canon_updates(batch_updates)
1625
+ updated_count += len(batch_updates)
1626
+ logger.info(f" Updated {updated_count} records...")
1627
+ batch_updates = []
1628
+
1629
+ # Final batch
1630
+ if batch_updates:
1631
+ self._apply_canon_updates(batch_updates)
1632
+ updated_count += len(batch_updates)
1633
+
1634
+ conn.commit()
1635
+ logger.info(f"Canonicalization complete: {updated_count} records updated, {multi_record_groups} multi-record groups")
1636
+
1637
+ return {
1638
+ "total_records": count,
1639
+ "groups_found": len(groups),
1640
+ "multi_record_groups": multi_record_groups,
1641
+ "records_updated": updated_count,
1642
+ }
1643
+
1644
+ def _apply_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
1645
+ """Apply batch of canon updates: (org_id, canon_id, canon_size)."""
1646
+ conn = self._conn
1647
+ assert conn is not None
1648
+
1649
+ for org_id, canon_id, canon_size in updates:
1650
+ conn.execute(
1651
+ "UPDATE organizations SET canon_id = ?, canon_size = ? WHERE id = ?",
1652
+ (canon_id, canon_size, org_id)
1653
+ )
1654
+
1655
+ conn.commit()
1656
+
1657
+ def get_canon_stats(self) -> dict[str, int]:
1658
+ """Get statistics about canonicalization status."""
1659
+ conn = self._connect()
1660
+
1661
+ # Total records
1662
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations")
1663
+ total = cursor.fetchone()[0]
1664
+
1665
+ # Records with canon_id set
1666
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_id IS NOT NULL")
1667
+ canonicalized = cursor.fetchone()[0]
1668
+
1669
+ # Number of canonical groups (unique canon_ids)
1670
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM organizations WHERE canon_id IS NOT NULL")
1671
+ groups = cursor.fetchone()[0]
1672
+
1673
+ # Multi-record groups (canon_size > 1)
1674
+ cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE canon_size > 1")
1675
+ multi_record_groups = cursor.fetchone()[0]
1676
+
1677
+ # Records in multi-record groups
1678
+ cursor = conn.execute("""
1679
+ SELECT COUNT(*) FROM organizations o1
1680
+ WHERE EXISTS (SELECT 1 FROM organizations o2 WHERE o2.id = o1.canon_id AND o2.canon_size > 1)
1681
+ """)
1682
+ records_in_multi = cursor.fetchone()[0]
1683
+
1684
+ return {
1685
+ "total_records": total,
1686
+ "canonicalized_records": canonicalized,
1687
+ "canonical_groups": groups,
1688
+ "multi_record_groups": multi_record_groups,
1689
+ "records_in_multi_groups": records_in_multi,
1690
+ }
1691
+
1692
+ def migrate_name_normalized(self, batch_size: int = 50000) -> int:
1693
+ """
1694
+ Populate the name_normalized column for all records.
1695
+
1696
+ This is a one-time migration for databases that don't have
1697
+ normalized names populated.
1698
+
1699
+ Args:
1700
+ batch_size: Number of records to process per batch
1701
+
1702
+ Returns:
1703
+ Number of records updated
757
1704
  """
758
1705
  conn = self._connect()
759
1706
 
@@ -867,8 +1814,9 @@ class OrganizationDatabase:
867
1814
  assert conn is not None
868
1815
 
869
1816
  for org_id, embedding_blob in batch:
1817
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
870
1818
  conn.execute("""
871
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
1819
+ INSERT INTO organization_embeddings (org_id, embedding)
872
1820
  VALUES (?, ?)
873
1821
  """, (org_id, embedding_blob))
874
1822
 
@@ -878,17 +1826,32 @@ class OrganizationDatabase:
878
1826
  """Delete all records from a specific source."""
879
1827
  conn = self._connect()
880
1828
 
881
- # First get IDs to delete from vec table
882
- cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
883
- ids_to_delete = [row["id"] for row in cursor]
1829
+ if self._is_v2:
1830
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
1831
+ # First get IDs to delete from vec table
1832
+ cursor = conn.execute("SELECT id FROM organizations WHERE source_id = ?", (source_type_id,))
1833
+ ids_to_delete = [row["id"] for row in cursor]
1834
+
1835
+ # Delete from vec table
1836
+ if ids_to_delete:
1837
+ placeholders = ",".join("?" * len(ids_to_delete))
1838
+ conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
1839
+
1840
+ # Delete from main table
1841
+ cursor = conn.execute("DELETE FROM organizations WHERE source_id = ?", (source_type_id,))
1842
+ else:
1843
+ # First get IDs to delete from vec table
1844
+ cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
1845
+ ids_to_delete = [row["id"] for row in cursor]
1846
+
1847
+ # Delete from vec table
1848
+ if ids_to_delete:
1849
+ placeholders = ",".join("?" * len(ids_to_delete))
1850
+ conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
884
1851
 
885
- # Delete from vec table
886
- if ids_to_delete:
887
- placeholders = ",".join("?" * len(ids_to_delete))
888
- conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
1852
+ # Delete from main table
1853
+ cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
889
1854
 
890
- # Delete from main table
891
- cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
892
1855
  deleted = cursor.rowcount
893
1856
 
894
1857
  conn.commit()
@@ -1107,8 +2070,9 @@ class OrganizationDatabase:
1107
2070
 
1108
2071
  for org_id, embedding in zip(org_ids, embeddings):
1109
2072
  embedding_blob = embedding.astype(np.float32).tobytes()
2073
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
1110
2074
  conn.execute("""
1111
- INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
2075
+ INSERT INTO organization_embeddings (org_id, embedding)
1112
2076
  VALUES (?, ?)
1113
2077
  """, (org_id, embedding_blob))
1114
2078
  count += 1
@@ -1116,6 +2080,287 @@ class OrganizationDatabase:
1116
2080
  conn.commit()
1117
2081
  return count
1118
2082
 
2083
+ def ensure_scalar_table_exists(self) -> None:
2084
+ """Create scalar embedding table if it doesn't exist."""
2085
+ conn = self._connect()
2086
+ conn.execute(f"""
2087
+ CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
2088
+ org_id INTEGER PRIMARY KEY,
2089
+ embedding int8[{self._embedding_dim}]
2090
+ )
2091
+ """)
2092
+ conn.commit()
2093
+ logger.info("Ensured organization_embeddings_scalar table exists")
2094
+
2095
+ def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
2096
+ """
2097
+ Yield batches of org IDs that have float32 but missing scalar embeddings.
2098
+
2099
+ Args:
2100
+ batch_size: Number of IDs per batch
2101
+
2102
+ Yields:
2103
+ Lists of org_ids needing scalar embeddings
2104
+ """
2105
+ conn = self._connect()
2106
+
2107
+ # Ensure scalar table exists before querying
2108
+ self.ensure_scalar_table_exists()
2109
+
2110
+ last_id = 0
2111
+ while True:
2112
+ cursor = conn.execute("""
2113
+ SELECT e.org_id FROM organization_embeddings e
2114
+ LEFT JOIN organization_embeddings_scalar s ON e.org_id = s.org_id
2115
+ WHERE s.org_id IS NULL AND e.org_id > ?
2116
+ ORDER BY e.org_id
2117
+ LIMIT ?
2118
+ """, (last_id, batch_size))
2119
+
2120
+ rows = cursor.fetchall()
2121
+ if not rows:
2122
+ break
2123
+
2124
+ ids = [row["org_id"] for row in rows]
2125
+ yield ids
2126
+ last_id = ids[-1]
2127
+
2128
+ def get_embeddings_by_ids(self, org_ids: list[int]) -> dict[int, np.ndarray]:
2129
+ """
2130
+ Fetch float32 embeddings for given org IDs.
2131
+
2132
+ Args:
2133
+ org_ids: List of organization IDs
2134
+
2135
+ Returns:
2136
+ Dict mapping org_id to float32 embedding array
2137
+ """
2138
+ conn = self._connect()
2139
+
2140
+ if not org_ids:
2141
+ return {}
2142
+
2143
+ placeholders = ",".join("?" * len(org_ids))
2144
+ cursor = conn.execute(f"""
2145
+ SELECT org_id, embedding FROM organization_embeddings
2146
+ WHERE org_id IN ({placeholders})
2147
+ """, org_ids)
2148
+
2149
+ result = {}
2150
+ for row in cursor:
2151
+ embedding_blob = row["embedding"]
2152
+ embedding = np.frombuffer(embedding_blob, dtype=np.float32)
2153
+ result[row["org_id"]] = embedding
2154
+ return result
2155
+
2156
+ def insert_scalar_embeddings_batch(self, org_ids: list[int], embeddings: np.ndarray) -> int:
2157
+ """
2158
+ Insert scalar (int8) embeddings for existing orgs.
2159
+
2160
+ Args:
2161
+ org_ids: List of organization IDs
2162
+ embeddings: Matrix of int8 embeddings (N x dim)
2163
+
2164
+ Returns:
2165
+ Number of embeddings inserted
2166
+ """
2167
+ conn = self._connect()
2168
+ count = 0
2169
+
2170
+ for org_id, embedding in zip(org_ids, embeddings):
2171
+ scalar_blob = embedding.astype(np.int8).tobytes()
2172
+ conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
2173
+ conn.execute("""
2174
+ INSERT INTO organization_embeddings_scalar (org_id, embedding)
2175
+ VALUES (?, vec_int8(?))
2176
+ """, (org_id, scalar_blob))
2177
+ count += 1
2178
+
2179
+ conn.commit()
2180
+ return count
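Taken together, the three methods above support a simple backfill loop for databases that already hold float32 vectors but no int8 copies. A sketch, assuming db is an already-open OrganizationDatabase and that the stored vectors are unit-normalized (the same scale-and-clip scheme the query path uses):

import numpy as np

def backfill_scalar_embeddings(db) -> int:
    total = 0
    for org_ids in db.get_missing_scalar_embedding_ids(batch_size=1000):
        fp32 = db.get_embeddings_by_ids(org_ids)
        matrix = np.stack([fp32[oid] for oid in org_ids])  # keep rows aligned with org_ids
        int8 = np.clip(np.round(matrix * 127), -127, 127).astype(np.int8)
        total += db.insert_scalar_embeddings_batch(org_ids, int8)
    return total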
2181
+
2182
+ def get_scalar_embedding_count(self) -> int:
2183
+ """Get count of scalar embeddings."""
2184
+ conn = self._connect()
2185
+ if not self._has_scalar_table():
2186
+ return 0
2187
+ cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings_scalar")
2188
+ return cursor.fetchone()[0]
2189
+
2190
+ def get_float32_embedding_count(self) -> int:
2191
+ """Get count of float32 embeddings."""
2192
+ conn = self._connect()
2193
+ cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
2194
+ return cursor.fetchone()[0]
2195
+
2196
+ def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
2197
+ """
2198
+ Yield batches of (org_id, name) tuples for records missing both float32 and scalar embeddings.
2199
+
2200
+ Args:
2201
+ batch_size: Number of IDs per batch
2202
+
2203
+ Yields:
2204
+ Lists of (org_id, name) tuples needing embeddings generated from scratch
2205
+ """
2206
+ conn = self._connect()
2207
+
2208
+ # Ensure scalar table exists
2209
+ self.ensure_scalar_table_exists()
2210
+
2211
+ last_id = 0
2212
+ while True:
2213
+ cursor = conn.execute("""
2214
+ SELECT o.id, o.name FROM organizations o
2215
+ LEFT JOIN organization_embeddings e ON o.id = e.org_id
2216
+ WHERE e.org_id IS NULL AND o.id > ?
2217
+ ORDER BY o.id
2218
+ LIMIT ?
2219
+ """, (last_id, batch_size))
2220
+
2221
+ rows = cursor.fetchall()
2222
+ if not rows:
2223
+ break
2224
+
2225
+ results = [(row["id"], row["name"]) for row in rows]
2226
+ yield results
2227
+ last_id = results[-1][0]
2228
+
2229
+ def insert_both_embeddings_batch(
2230
+ self,
2231
+ org_ids: list[int],
2232
+ fp32_embeddings: np.ndarray,
2233
+ int8_embeddings: np.ndarray,
2234
+ ) -> int:
2235
+ """
2236
+ Insert both float32 and int8 embeddings for existing orgs.
2237
+
2238
+ Args:
2239
+ org_ids: List of organization IDs
2240
+ fp32_embeddings: Matrix of float32 embeddings (N x dim)
2241
+ int8_embeddings: Matrix of int8 embeddings (N x dim)
2242
+
2243
+ Returns:
2244
+ Number of embeddings inserted
2245
+ """
2246
+ conn = self._connect()
2247
+ count = 0
2248
+
2249
+ for org_id, fp32, int8 in zip(org_ids, fp32_embeddings, int8_embeddings):
2250
+ # Insert float32
2251
+ fp32_blob = fp32.astype(np.float32).tobytes()
2252
+ conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
2253
+ conn.execute("""
2254
+ INSERT INTO organization_embeddings (org_id, embedding)
2255
+ VALUES (?, ?)
2256
+ """, (org_id, fp32_blob))
2257
+
2258
+ # Insert int8
2259
+ int8_blob = int8.astype(np.int8).tobytes()
2260
+ conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
2261
+ conn.execute("""
2262
+ INSERT INTO organization_embeddings_scalar (org_id, embedding)
2263
+ VALUES (?, vec_int8(?))
2264
+ """, (org_id, int8_blob))
2265
+
2266
+ count += 1
2267
+
2268
+ conn.commit()
2269
+ return count
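The companion pair above covers organizations with no embeddings at all; the loop has the same shape as the previous sketch except that the names must be encoded first. embed_names is a stand-in for whatever sentence-embedding model the pipeline uses and is assumed to return unit-normalized float32 rows:

import numpy as np

def backfill_missing_embeddings(db, embed_names) -> int:
    total = 0
    for batch in db.get_missing_all_embedding_ids(batch_size=1000):
        org_ids = [org_id for org_id, _name in batch]
        names = [name for _org_id, name in batch]
        fp32 = np.asarray(embed_names(names), dtype=np.float32)
        int8 = np.clip(np.round(fp32 * 127), -127, 127).astype(np.int8)
        total += db.insert_both_embeddings_batch(org_ids, fp32, int8)
    return total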
2270
+
2271
+ def resolve_qid_labels(
2272
+ self,
2273
+ label_map: dict[str, str],
2274
+ batch_size: int = 1000,
2275
+ ) -> tuple[int, int]:
2276
+ """
2277
+ Update organization records that have QIDs instead of labels in region field.
2278
+
2279
+ If resolving would create a duplicate of an existing record with
2280
+ resolved labels, the QID version is deleted instead.
2281
+
2282
+ Args:
2283
+ label_map: Mapping of QID -> label for resolution
2284
+ batch_size: Commit batch size
2285
+
2286
+ Returns:
2287
+ Tuple of (records updated, duplicates deleted)
2288
+ """
2289
+ conn = self._connect()
2290
+
2291
+ # Find records with QIDs in region field (starts with 'Q' followed by digits)
2292
+ region_updates = 0
2293
+ cursor = conn.execute("""
2294
+ SELECT id, region FROM organizations
2295
+ WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
2296
+ """)
2297
+ rows = cursor.fetchall()
2298
+
2299
+ duplicates_deleted = 0
2300
+ for row in rows:
2301
+ org_id = row["id"]
2302
+ qid = row["region"]
2303
+ if qid in label_map:
2304
+ resolved_region = label_map[qid]
2305
+ # Check if this update would create a duplicate
2306
+ # Get the name and source of the current record
2307
+ org_cursor = conn.execute(
2308
+ "SELECT name, source FROM organizations WHERE id = ?",
2309
+ (org_id,)
2310
+ )
2311
+ org_row = org_cursor.fetchone()
2312
+ if org_row is None:
2313
+ continue
2314
+
2315
+ org_name = org_row["name"]
2316
+ org_source = org_row["source"]
2317
+
2318
+ # Check if a record with the resolved region already exists
2319
+ existing_cursor = conn.execute(
2320
+ "SELECT id FROM organizations WHERE name = ? AND region = ? AND source = ? AND id != ?",
2321
+ (org_name, resolved_region, org_source, org_id)
2322
+ )
2323
+ existing = existing_cursor.fetchone()
2324
+
2325
+ if existing is not None:
2326
+ # Duplicate would be created - delete the QID-based record
2327
+ conn.execute("DELETE FROM organizations WHERE id = ?", (org_id,))
2328
+ duplicates_deleted += 1
2329
+ else:
2330
+ # Safe to update
2331
+ conn.execute(
2332
+ "UPDATE organizations SET region = ? WHERE id = ?",
2333
+ (resolved_region, org_id)
2334
+ )
2335
+ region_updates += 1
2336
+
2337
+ if (region_updates + duplicates_deleted) % batch_size == 0:
2338
+ conn.commit()
2339
+ logger.info(f"Resolved QID labels: {region_updates} updates, {duplicates_deleted} deletes...")
2340
+
2341
+ conn.commit()
2342
+ logger.info(f"Resolved QID labels: {region_updates} organization regions, {duplicates_deleted} duplicates deleted")
2343
+ return region_updates, duplicates_deleted
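Note that the SQL prefilter above is loose: LIKE 'Q%' narrows to values starting with 'Q', and GLOB 'Q[0-9]*' only requires a digit immediately after the 'Q', so a value like 'Q12abc' still passes while 'Quebec' does not; anything that slips through is filtered again by the "if qid in label_map" check. A strict full-match version of the test, with a made-up label_map, looks like this (Q145 is the Wikidata item for the United Kingdom):

import re

QID_RE = re.compile(r"^Q\d+$")
label_map = {"Q145": "United Kingdom"}

for region in ("Q145", "Q12abc", "Quebec", "California"):
    if QID_RE.fullmatch(region):
        print(region, "->", label_map.get(region, region))
    else:
        print(region, "-> kept as-is")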
2344
+
2345
+ def get_unresolved_qids(self) -> set[str]:
2346
+ """
2347
+ Get all QIDs that still need resolution in the organizations table.
2348
+
2349
+ Returns:
2350
+ Set of QIDs (starting with 'Q') found in region field
2351
+ """
2352
+ conn = self._connect()
2353
+ qids: set[str] = set()
2354
+
2355
+ cursor = conn.execute("""
2356
+ SELECT DISTINCT region FROM organizations
2357
+ WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
2358
+ """)
2359
+ for row in cursor:
2360
+ qids.add(row["region"])
2361
+
2362
+ return qids
2363
+
1119
2364
 
1120
2365
  def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
1121
2366
  """
@@ -1161,36 +2406,51 @@ class PersonDatabase:
1161
2406
  self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
1162
2407
  self._embedding_dim = embedding_dim
1163
2408
  self._conn: Optional[sqlite3.Connection] = None
2409
+ self._is_v2: Optional[bool] = None # Detected on first connect
1164
2410
 
1165
2411
  def _ensure_dir(self) -> None:
1166
2412
  """Ensure database directory exists."""
1167
2413
  self._db_path.parent.mkdir(parents=True, exist_ok=True)
1168
2414
 
1169
2415
  def _connect(self) -> sqlite3.Connection:
1170
- """Get or create database connection with sqlite-vec loaded."""
2416
+ """Get or create database connection using shared connection pool."""
1171
2417
  if self._conn is not None:
1172
2418
  return self._conn
1173
2419
 
1174
- self._ensure_dir()
1175
- self._conn = sqlite3.connect(str(self._db_path))
1176
- self._conn.row_factory = sqlite3.Row
2420
+ self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
1177
2421
 
1178
- # Load sqlite-vec extension
1179
- self._conn.enable_load_extension(True)
1180
- sqlite_vec.load(self._conn)
1181
- self._conn.enable_load_extension(False)
2422
+ # Detect schema version BEFORE creating tables
2423
+ # v2 has person_type_id (FK) instead of person_type (TEXT)
2424
+ if self._is_v2 is None:
2425
+ cursor = self._conn.execute("PRAGMA table_info(people)")
2426
+ columns = {row["name"] for row in cursor}
2427
+ self._is_v2 = "person_type_id" in columns
2428
+ if self._is_v2:
2429
+ logger.debug("Detected v2 schema for people")
1182
2430
 
1183
- # Create tables
1184
- self._create_tables()
2431
+ # Create tables (idempotent) - only for v1 schema or fresh databases
2432
+ # v2 databases already have their schema from migration
2433
+ if not self._is_v2:
2434
+ self._create_tables()
1185
2435
 
1186
2436
  return self._conn
1187
2437
 
2438
+ @property
2439
+ def _people_table(self) -> str:
2440
+ """Return table/view name for people queries needing text fields."""
2441
+ return "people_view" if self._is_v2 else "people"
2442
+
1188
2443
  def _create_tables(self) -> None:
1189
2444
  """Create database tables including sqlite-vec virtual table."""
1190
2445
  conn = self._conn
1191
2446
  assert conn is not None
1192
2447
 
2448
+ # Check if we need to migrate from old schema (unique on source+source_id only)
2449
+ self._migrate_people_schema_if_needed(conn)
2450
+
1193
2451
  # Main people records table
2452
+ # Unique constraint on source+source_id+role+org allows multiple records
2453
+ # for the same person with different role/org combinations
1194
2454
  conn.execute("""
1195
2455
  CREATE TABLE IF NOT EXISTS people (
1196
2456
  id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -1202,8 +2462,14 @@ class PersonDatabase:
1202
2462
  person_type TEXT NOT NULL DEFAULT 'unknown',
1203
2463
  known_for_role TEXT NOT NULL DEFAULT '',
1204
2464
  known_for_org TEXT NOT NULL DEFAULT '',
2465
+ known_for_org_id INTEGER DEFAULT NULL,
2466
+ from_date TEXT NOT NULL DEFAULT '',
2467
+ to_date TEXT NOT NULL DEFAULT '',
2468
+ birth_date TEXT NOT NULL DEFAULT '',
2469
+ death_date TEXT NOT NULL DEFAULT '',
1205
2470
  record TEXT NOT NULL,
1206
- UNIQUE(source, source_id)
2471
+ UNIQUE(source, source_id, known_for_role, known_for_org),
2472
+ FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
1207
2473
  )
1208
2474
  """)
1209
2475
 
@@ -1211,10 +2477,72 @@ class PersonDatabase:
1211
2477
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
1212
2478
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
1213
2479
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
1214
- conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
2480
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id, known_for_role, known_for_org)")
1215
2481
  conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
1216
2482
 
1217
- # Create sqlite-vec virtual table for embeddings
2483
+ # Add from_date column if it doesn't exist (migration for existing DBs)
2484
+ try:
2485
+ conn.execute("ALTER TABLE people ADD COLUMN from_date TEXT NOT NULL DEFAULT ''")
2486
+ logger.info("Added from_date column to people table")
2487
+ except sqlite3.OperationalError:
2488
+ pass # Column already exists
2489
+
2490
+ # Add to_date column if it doesn't exist (migration for existing DBs)
2491
+ try:
2492
+ conn.execute("ALTER TABLE people ADD COLUMN to_date TEXT NOT NULL DEFAULT ''")
2493
+ logger.info("Added to_date column to people table")
2494
+ except sqlite3.OperationalError:
2495
+ pass # Column already exists
2496
+
2497
+ # Add known_for_org_id column if it doesn't exist (migration for existing DBs)
2498
+ # This is a foreign key to the organizations table (nullable)
2499
+ try:
2500
+ conn.execute("ALTER TABLE people ADD COLUMN known_for_org_id INTEGER DEFAULT NULL")
2501
+ logger.info("Added known_for_org_id column to people table")
2502
+ except sqlite3.OperationalError:
2503
+ pass # Column already exists
2504
+
2505
+ # Create index on known_for_org_id for joins (only if column exists)
2506
+ try:
2507
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org_id ON people(known_for_org_id)")
2508
+ except sqlite3.OperationalError:
2509
+ pass # Column doesn't exist yet (will be added on next connection)
2510
+
2511
+ # Add birth_date column if it doesn't exist (migration for existing DBs)
2512
+ try:
2513
+ conn.execute("ALTER TABLE people ADD COLUMN birth_date TEXT NOT NULL DEFAULT ''")
2514
+ logger.info("Added birth_date column to people table")
2515
+ except sqlite3.OperationalError:
2516
+ pass # Column already exists
2517
+
2518
+ # Add death_date column if it doesn't exist (migration for existing DBs)
2519
+ try:
2520
+ conn.execute("ALTER TABLE people ADD COLUMN death_date TEXT NOT NULL DEFAULT ''")
2521
+ logger.info("Added death_date column to people table")
2522
+ except sqlite3.OperationalError:
2523
+ pass # Column already exists
2524
+
2525
+ # Add canon_id column if it doesn't exist (migration for canonicalization)
2526
+ try:
2527
+ conn.execute("ALTER TABLE people ADD COLUMN canon_id INTEGER DEFAULT NULL")
2528
+ logger.info("Added canon_id column to people table")
2529
+ except sqlite3.OperationalError:
2530
+ pass # Column already exists
2531
+
2532
+ # Add canon_size column if it doesn't exist (migration for canonicalization)
2533
+ try:
2534
+ conn.execute("ALTER TABLE people ADD COLUMN canon_size INTEGER DEFAULT 1")
2535
+ logger.info("Added canon_size column to people table")
2536
+ except sqlite3.OperationalError:
2537
+ pass # Column already exists
2538
+
2539
+ # Create index on canon_id for joins
2540
+ try:
2541
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_people_canon_id ON people(canon_id)")
2542
+ except sqlite3.OperationalError:
2543
+ pass # Column doesn't exist yet
2544
+
2545
+ # Create sqlite-vec virtual table for embeddings (float32)
1218
2546
  conn.execute(f"""
1219
2547
  CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
1220
2548
  person_id INTEGER PRIMARY KEY,
@@ -1222,21 +2550,117 @@ class PersonDatabase:
1222
2550
  )
1223
2551
  """)
1224
2552
 
1225
- conn.commit()
1226
-
2553
+ # Create sqlite-vec virtual table for scalar embeddings (int8)
2554
+ # Provides 75% storage reduction with ~92% recall at top-100
2555
+ conn.execute(f"""
2556
+ CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
2557
+ person_id INTEGER PRIMARY KEY,
2558
+ embedding int8[{self._embedding_dim}]
2559
+ )
2560
+ """)
2561
+
2562
+ # Create QID labels lookup table for Wikidata QID -> label mappings
2563
+ conn.execute("""
2564
+ CREATE TABLE IF NOT EXISTS qid_labels (
2565
+ qid TEXT PRIMARY KEY,
2566
+ label TEXT NOT NULL
2567
+ )
2568
+ """)
2569
+
2570
+ conn.commit()
2571
+
2572
+ def _migrate_people_schema_if_needed(self, conn: sqlite3.Connection) -> None:
2573
+ """Migrate people table from old schema if needed."""
2574
+ # Check if people table exists
2575
+ cursor = conn.execute(
2576
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='people'"
2577
+ )
2578
+ if not cursor.fetchone():
2579
+ return # Table doesn't exist, no migration needed
2580
+
2581
+ # Check the unique constraint - look at index info
2582
+ # Old schema: UNIQUE(source, source_id)
2583
+ # New schema: UNIQUE(source, source_id, known_for_role, known_for_org)
2584
+ cursor = conn.execute("PRAGMA index_list(people)")
2585
+ indexes = cursor.fetchall()
2586
+
2587
+ needs_migration = False
2588
+ for idx in indexes:
2589
+ idx_name = idx[1]
2590
+ if "sqlite_autoindex_people" in idx_name:
2591
+ # Check columns in this unique index
2592
+ cursor = conn.execute(f"PRAGMA index_info('{idx_name}')")
2593
+ cols = [row[2] for row in cursor.fetchall()]
2594
+ # Old schema has only 2 columns in unique constraint
2595
+ if cols == ["source", "source_id"]:
2596
+ needs_migration = True
2597
+ logger.info("Detected old people schema, migrating to new unique constraint...")
2598
+ break
2599
+
2600
+ if not needs_migration:
2601
+ return
2602
+
2603
+ # Migrate: create new table, copy data, drop old, rename new
2604
+ logger.info("Migrating people table to new schema with (source, source_id, role, org) unique constraint...")
2605
+
2606
+ conn.execute("""
2607
+ CREATE TABLE people_new (
2608
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
2609
+ name TEXT NOT NULL,
2610
+ name_normalized TEXT NOT NULL,
2611
+ source TEXT NOT NULL DEFAULT 'wikidata',
2612
+ source_id TEXT NOT NULL,
2613
+ country TEXT NOT NULL DEFAULT '',
2614
+ person_type TEXT NOT NULL DEFAULT 'unknown',
2615
+ known_for_role TEXT NOT NULL DEFAULT '',
2616
+ known_for_org TEXT NOT NULL DEFAULT '',
2617
+ known_for_org_id INTEGER DEFAULT NULL,
2618
+ from_date TEXT NOT NULL DEFAULT '',
2619
+ to_date TEXT NOT NULL DEFAULT '',
2620
+ record TEXT NOT NULL,
2621
+ UNIQUE(source, source_id, known_for_role, known_for_org),
2622
+ FOREIGN KEY (known_for_org_id) REFERENCES organizations(id)
2623
+ )
2624
+ """)
2625
+
2626
+ # Copy data (old IDs will change, but embeddings table references them)
2627
+ # Note: old table may not have from_date/to_date columns, so use defaults
2628
+ conn.execute("""
2629
+ INSERT INTO people_new (name, name_normalized, source, source_id, country,
2630
+ person_type, known_for_role, known_for_org, record)
2631
+ SELECT name, name_normalized, source, source_id, country,
2632
+ person_type, known_for_role, known_for_org, record
2633
+ FROM people
2634
+ """)
2635
+
2636
+ # Drop old table and embeddings (IDs changed, embeddings are invalid)
2637
+ conn.execute("DROP TABLE IF EXISTS person_embeddings")
2638
+ conn.execute("DROP TABLE people")
2639
+ conn.execute("ALTER TABLE people_new RENAME TO people")
2640
+
2641
+ # Drop old index if it exists
2642
+ conn.execute("DROP INDEX IF EXISTS idx_people_source_id")
2643
+
2644
+ conn.commit()
2645
+ logger.info("Migration complete. Note: person embeddings were cleared and need to be regenerated.")
2646
+
1227
2647
  def close(self) -> None:
1228
- """Close database connection."""
1229
- if self._conn:
1230
- self._conn.close()
1231
- self._conn = None
2648
+ """Clear connection reference (shared connection remains open)."""
2649
+ self._conn = None
1232
2650
 
1233
- def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
2651
+ def insert(
2652
+ self,
2653
+ record: PersonRecord,
2654
+ embedding: np.ndarray,
2655
+ scalar_embedding: Optional[np.ndarray] = None,
2656
+ ) -> int:
1234
2657
  """
1235
2658
  Insert a person record with its embedding.
1236
2659
 
1237
2660
  Args:
1238
2661
  record: Person record to insert
1239
- embedding: Embedding vector for the person name
2662
+ embedding: Embedding vector for the person name (float32)
2663
+ scalar_embedding: Optional int8 scalar embedding for compact storage
1240
2664
 
1241
2665
  Returns:
1242
2666
  Row ID of inserted record
@@ -1249,8 +2673,10 @@ class PersonDatabase:
1249
2673
 
1250
2674
  cursor = conn.execute("""
1251
2675
  INSERT OR REPLACE INTO people
1252
- (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
1253
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
2676
+ (name, name_normalized, source, source_id, country, person_type,
2677
+ known_for_role, known_for_org, known_for_org_id, from_date, to_date,
2678
+ birth_date, death_date, record)
2679
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
1254
2680
  """, (
1255
2681
  record.name,
1256
2682
  name_normalized,
@@ -1260,19 +2686,34 @@ class PersonDatabase:
1260
2686
  record.person_type.value,
1261
2687
  record.known_for_role,
1262
2688
  record.known_for_org,
2689
+ record.known_for_org_id, # Can be None
2690
+ record.from_date or "",
2691
+ record.to_date or "",
2692
+ record.birth_date or "",
2693
+ record.death_date or "",
1263
2694
  record_json,
1264
2695
  ))
1265
2696
 
1266
2697
  row_id = cursor.lastrowid
1267
2698
  assert row_id is not None
1268
2699
 
1269
- # Insert embedding into vec table
2700
+ # Insert embedding into vec table (float32)
1270
2701
  embedding_blob = embedding.astype(np.float32).tobytes()
2702
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
1271
2703
  conn.execute("""
1272
- INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
2704
+ INSERT INTO person_embeddings (person_id, embedding)
1273
2705
  VALUES (?, ?)
1274
2706
  """, (row_id, embedding_blob))
1275
2707
 
2708
+ # Insert scalar embedding if provided (int8)
2709
+ if scalar_embedding is not None:
2710
+ scalar_blob = scalar_embedding.astype(np.int8).tobytes()
2711
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
2712
+ conn.execute("""
2713
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
2714
+ VALUES (?, vec_int8(?))
2715
+ """, (row_id, scalar_blob))
2716
+
1276
2717
  conn.commit()
1277
2718
  return row_id
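As a usage sketch, inserting a single person with both precisions might look like the following; the record values and the random vector are purely illustrative, and the optional PersonRecord fields (dates, known_for_org_id) are assumed to default when omitted:

import numpy as np
from statement_extractor.database.models import PersonRecord, PersonType
from statement_extractor.database.store import get_person_database

db = get_person_database()
record = PersonRecord(
    name="Jane Doe",
    source="wikidata",
    source_id="Q0000000",                  # made-up QID
    country="United States",
    person_type=PersonType.UNKNOWN,
    known_for_role="Chief Executive Officer",
    known_for_org="Example Corp",
    record={"qid": "Q0000000"},
)
fp32 = np.random.rand(768).astype(np.float32)  # stand-in for a real name embedding
fp32 /= np.linalg.norm(fp32)
int8 = np.clip(np.round(fp32 * 127), -127, 127).astype(np.int8)
row_id = db.insert(record, fp32, scalar_embedding=int8)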
1278
2719
 
@@ -1281,14 +2722,16 @@ class PersonDatabase:
1281
2722
  records: list[PersonRecord],
1282
2723
  embeddings: np.ndarray,
1283
2724
  batch_size: int = 1000,
2725
+ scalar_embeddings: Optional[np.ndarray] = None,
1284
2726
  ) -> int:
1285
2727
  """
1286
2728
  Insert multiple person records with embeddings.
1287
2729
 
1288
2730
  Args:
1289
2731
  records: List of person records
1290
- embeddings: Matrix of embeddings (N x dim)
2732
+ embeddings: Matrix of embeddings (N x dim) - float32
1291
2733
  batch_size: Commit batch size
2734
+ scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)
1292
2735
 
1293
2736
  Returns:
1294
2737
  Number of records inserted
@@ -1296,36 +2739,87 @@ class PersonDatabase:
1296
2739
  conn = self._connect()
1297
2740
  count = 0
1298
2741
 
1299
- for record, embedding in zip(records, embeddings):
2742
+ for i, (record, embedding) in enumerate(zip(records, embeddings)):
1300
2743
  record_json = json.dumps(record.record)
1301
2744
  name_normalized = _normalize_person_name(record.name)
1302
2745
 
1303
- cursor = conn.execute("""
1304
- INSERT OR REPLACE INTO people
1305
- (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
1306
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
1307
- """, (
1308
- record.name,
1309
- name_normalized,
1310
- record.source,
1311
- record.source_id,
1312
- record.country,
1313
- record.person_type.value,
1314
- record.known_for_role,
1315
- record.known_for_org,
1316
- record_json,
1317
- ))
2746
+ if self._is_v2:
2747
+ # v2 schema: use FK IDs instead of TEXT columns
2748
+ source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
2749
+ person_type_id = PEOPLE_TYPE_NAME_TO_ID.get(record.person_type.value, 15) # 15 = unknown
2750
+
2751
+ # Resolve country to location_id if provided
2752
+ country_id = None
2753
+ if record.country:
2754
+ locations_db = get_locations_database(db_path=self._db_path)
2755
+ country_id = locations_db.resolve_region_text(record.country)
2756
+
2757
+ cursor = conn.execute("""
2758
+ INSERT OR REPLACE INTO people
2759
+ (name, name_normalized, source_id, source_identifier, country_id, person_type_id,
2760
+ known_for_org, known_for_org_id, from_date, to_date,
2761
+ birth_date, death_date, record)
2762
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
2763
+ """, (
2764
+ record.name,
2765
+ name_normalized,
2766
+ source_type_id,
2767
+ record.source_id,
2768
+ country_id,
2769
+ person_type_id,
2770
+ record.known_for_org,
2771
+ record.known_for_org_id, # Can be None
2772
+ record.from_date or "",
2773
+ record.to_date or "",
2774
+ record.birth_date or "",
2775
+ record.death_date or "",
2776
+ record_json,
2777
+ ))
2778
+ else:
2779
+ # v1 schema: use TEXT columns
2780
+ cursor = conn.execute("""
2781
+ INSERT OR REPLACE INTO people
2782
+ (name, name_normalized, source, source_id, country, person_type,
2783
+ known_for_role, known_for_org, known_for_org_id, from_date, to_date,
2784
+ birth_date, death_date, record)
2785
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
2786
+ """, (
2787
+ record.name,
2788
+ name_normalized,
2789
+ record.source,
2790
+ record.source_id,
2791
+ record.country,
2792
+ record.person_type.value,
2793
+ record.known_for_role,
2794
+ record.known_for_org,
2795
+ record.known_for_org_id, # Can be None
2796
+ record.from_date or "",
2797
+ record.to_date or "",
2798
+ record.birth_date or "",
2799
+ record.death_date or "",
2800
+ record_json,
2801
+ ))
1318
2802
 
1319
2803
  row_id = cursor.lastrowid
1320
2804
  assert row_id is not None
1321
2805
 
1322
- # Insert embedding
2806
+ # Insert embedding (delete first since sqlite-vec doesn't support REPLACE)
1323
2807
  embedding_blob = embedding.astype(np.float32).tobytes()
2808
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
1324
2809
  conn.execute("""
1325
- INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
2810
+ INSERT INTO person_embeddings (person_id, embedding)
1326
2811
  VALUES (?, ?)
1327
2812
  """, (row_id, embedding_blob))
1328
2813
 
2814
+ # Insert scalar embedding if provided (int8)
2815
+ if scalar_embeddings is not None:
2816
+ scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
2817
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
2818
+ conn.execute("""
2819
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
2820
+ VALUES (?, vec_int8(?))
2821
+ """, (row_id, scalar_blob))
2822
+
1329
2823
  count += 1
1330
2824
 
1331
2825
  if count % batch_size == 0:
@@ -1335,6 +2829,102 @@ class PersonDatabase:
1335
2829
  conn.commit()
1336
2830
  return count
1337
2831
 
2832
+ def update_dates(self, source: str, source_id: str, from_date: Optional[str], to_date: Optional[str]) -> bool:
2833
+ """
2834
+ Update the from_date and to_date for a person record.
2835
+
2836
+ Args:
2837
+ source: Data source (e.g., 'wikidata')
2838
+ source_id: Source identifier (e.g., QID)
2839
+ from_date: Start date in ISO format or None
2840
+ to_date: End date in ISO format or None
2841
+
2842
+ Returns:
2843
+ True if record was updated, False if not found
2844
+ """
2845
+ conn = self._connect()
2846
+
2847
+ if self._is_v2:
2848
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
2849
+ cursor = conn.execute("""
2850
+ UPDATE people SET from_date = ?, to_date = ?
2851
+ WHERE source_id = ? AND source_identifier = ?
2852
+ """, (from_date or "", to_date or "", source_type_id, source_id))
2853
+ else:
2854
+ cursor = conn.execute("""
2855
+ UPDATE people SET from_date = ?, to_date = ?
2856
+ WHERE source = ? AND source_id = ?
2857
+ """, (from_date or "", to_date or "", source, source_id))
2858
+
2859
+ conn.commit()
2860
+ return cursor.rowcount > 0
2861
+
2862
+ def update_role_org(
2863
+ self,
2864
+ source: str,
2865
+ source_id: str,
2866
+ known_for_role: str,
2867
+ known_for_org: str,
2868
+ known_for_org_id: Optional[int],
2869
+ new_embedding: np.ndarray,
2870
+ from_date: Optional[str] = None,
2871
+ to_date: Optional[str] = None,
2872
+ ) -> bool:
2873
+ """
2874
+ Update the role/org/dates data for a person record and re-embed.
2875
+
2876
+ Args:
2877
+ source: Data source (e.g., 'wikidata')
2878
+ source_id: Source identifier (e.g., QID)
2879
+ known_for_role: Role/position title
2880
+ known_for_org: Organization name
2881
+ known_for_org_id: Organization internal ID (FK) or None
2882
+ new_embedding: New embedding vector based on updated data
2883
+ from_date: Start date in ISO format or None
2884
+ to_date: End date in ISO format or None
2885
+
2886
+ Returns:
2887
+ True if record was updated, False if not found
2888
+ """
2889
+ conn = self._connect()
2890
+
2891
+ # First get the person's internal ID
2892
+ if self._is_v2:
2893
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
2894
+ row = conn.execute(
2895
+ "SELECT id FROM people WHERE source_id = ? AND source_identifier = ?",
2896
+ (source_type_id, source_id)
2897
+ ).fetchone()
2898
+ else:
2899
+ row = conn.execute(
2900
+ "SELECT id FROM people WHERE source = ? AND source_id = ?",
2901
+ (source, source_id)
2902
+ ).fetchone()
2903
+
2904
+ if not row:
2905
+ return False
2906
+
2907
+ person_id = row[0]
2908
+
2909
+ # Update the person record (including dates)
2910
+ conn.execute("""
2911
+ UPDATE people SET
2912
+ known_for_role = ?, known_for_org = ?, known_for_org_id = ?,
2913
+ from_date = COALESCE(?, from_date, ''),
2914
+ to_date = COALESCE(?, to_date, '')
2915
+ WHERE id = ?
2916
+ """, (known_for_role, known_for_org, known_for_org_id, from_date, to_date, person_id))
2917
+
2918
+ # Update the embedding
2919
+ embedding_bytes = new_embedding.astype(np.float32).tobytes()
2920
+ conn.execute("""
2921
+ UPDATE person_embeddings SET embedding = ?
2922
+ WHERE person_id = ?
2923
+ """, (embedding_bytes, person_id))
2924
+
2925
+ conn.commit()
2926
+ return True
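A call-site sketch for the method above, assuming an open PersonDatabase db and a stand-in embed_person encoder (both outside this diff); passing None for a date leaves the stored value in place thanks to the COALESCE above:

new_emb = embed_person("Jane Doe, Chief Executive Officer, Example Corp")
db.update_role_org(
    source="wikidata",
    source_id="Q0000000",                  # made-up QID
    known_for_role="Chief Executive Officer",
    known_for_org="Example Corp",
    known_for_org_id=None,
    new_embedding=new_emb,
    from_date="2020-01-01",
    to_date=None,                          # keeps whatever to_date is already stored
)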
2927
+
1338
2928
  def search(
1339
2929
  self,
1340
2930
  query_embedding: np.ndarray,
@@ -1366,7 +2956,13 @@ class PersonDatabase:
1366
2956
  if query_norm == 0:
1367
2957
  return []
1368
2958
  query_normalized = query_embedding / query_norm
1369
- query_blob = query_normalized.astype(np.float32).tobytes()
2959
+
2960
+ # Use int8 quantized query if scalar table is available (75% storage savings)
2961
+ if self._has_scalar_table():
2962
+ query_int8 = self._quantize_query(query_normalized)
2963
+ query_blob = query_int8.tobytes()
2964
+ else:
2965
+ query_blob = query_normalized.astype(np.float32).tobytes()
1370
2966
 
1371
2967
  # Stage 1: Text-based pre-filtering (if query_text provided)
1372
2968
  candidate_ids: Optional[set[int]] = None
@@ -1437,13 +3033,26 @@ class PersonDatabase:
1437
3033
  cursor = conn.execute(query, params)
1438
3034
  return set(row["id"] for row in cursor)
1439
3035
 
3036
+ def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
3037
+ """Quantize query embedding to int8 for scalar search."""
3038
+ return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
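A worked example of that quantization: a unit-length float32 query is scaled by 127, rounded, and clipped into the int8 range so it can be compared against the int8 vectors stored in person_embeddings_scalar.

import numpy as np

q = np.array([0.8, -0.6, 0.02], dtype=np.float32)
q /= np.linalg.norm(q)                              # already close to unit length here
q_int8 = np.clip(np.round(q * 127), -127, 127).astype(np.int8)
print(q_int8)                                       # -> [102 -76   3] (approximately)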
3039
+
3040
+ def _has_scalar_table(self) -> bool:
3041
+ """Check if scalar embedding table exists."""
3042
+ conn = self._conn
3043
+ assert conn is not None
3044
+ cursor = conn.execute(
3045
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='person_embeddings_scalar'"
3046
+ )
3047
+ return cursor.fetchone() is not None
3048
+
1440
3049
  def _vector_search_filtered(
1441
3050
  self,
1442
3051
  query_blob: bytes,
1443
3052
  candidate_ids: set[int],
1444
3053
  top_k: int,
1445
3054
  ) -> list[tuple[PersonRecord, float]]:
1446
- """Vector search within a filtered set of candidates."""
3055
+ """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
1447
3056
  conn = self._conn
1448
3057
  assert conn is not None
1449
3058
 
@@ -1453,15 +3062,27 @@ class PersonDatabase:
1453
3062
  # Build IN clause for candidate IDs
1454
3063
  placeholders = ",".join("?" * len(candidate_ids))
1455
3064
 
1456
- query = f"""
1457
- SELECT
1458
- e.person_id,
1459
- vec_distance_cosine(e.embedding, ?) as distance
1460
- FROM person_embeddings e
1461
- WHERE e.person_id IN ({placeholders})
1462
- ORDER BY distance
1463
- LIMIT ?
1464
- """
3065
+ # Use scalar embedding table if available (75% storage reduction)
3066
+ if self._has_scalar_table():
3067
+ query = f"""
3068
+ SELECT
3069
+ e.person_id,
3070
+ vec_distance_cosine(e.embedding, vec_int8(?)) as distance
3071
+ FROM person_embeddings_scalar e
3072
+ WHERE e.person_id IN ({placeholders})
3073
+ ORDER BY distance
3074
+ LIMIT ?
3075
+ """
3076
+ else:
3077
+ query = f"""
3078
+ SELECT
3079
+ e.person_id,
3080
+ vec_distance_cosine(e.embedding, ?) as distance
3081
+ FROM person_embeddings e
3082
+ WHERE e.person_id IN ({placeholders})
3083
+ ORDER BY distance
3084
+ LIMIT ?
3085
+ """
1465
3086
 
1466
3087
  cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
1467
3088
 
@@ -1484,18 +3105,29 @@ class PersonDatabase:
1484
3105
  query_blob: bytes,
1485
3106
  top_k: int,
1486
3107
  ) -> list[tuple[PersonRecord, float]]:
1487
- """Full vector search without text pre-filtering."""
3108
+ """Full vector search without text pre-filtering using scalar (int8) embeddings."""
1488
3109
  conn = self._conn
1489
3110
  assert conn is not None
1490
3111
 
1491
- query = """
1492
- SELECT
1493
- person_id,
1494
- vec_distance_cosine(embedding, ?) as distance
1495
- FROM person_embeddings
1496
- ORDER BY distance
1497
- LIMIT ?
1498
- """
3112
+ # Use scalar embedding table if available (75% storage reduction)
3113
+ if self._has_scalar_table():
3114
+ query = """
3115
+ SELECT
3116
+ person_id,
3117
+ vec_distance_cosine(embedding, vec_int8(?)) as distance
3118
+ FROM person_embeddings_scalar
3119
+ ORDER BY distance
3120
+ LIMIT ?
3121
+ """
3122
+ else:
3123
+ query = """
3124
+ SELECT
3125
+ person_id,
3126
+ vec_distance_cosine(embedding, ?) as distance
3127
+ FROM person_embeddings
3128
+ ORDER BY distance
3129
+ LIMIT ?
3130
+ """
1499
3131
  cursor = conn.execute(query, (query_blob, top_k))
1500
3132
 
1501
3133
  results = []
@@ -1515,21 +3147,36 @@ class PersonDatabase:
1515
3147
  conn = self._conn
1516
3148
  assert conn is not None
1517
3149
 
1518
- cursor = conn.execute("""
1519
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
1520
- FROM people WHERE id = ?
1521
- """, (person_id,))
3150
+ if self._is_v2:
3151
+ # v2 schema: join view with base table for record
3152
+ cursor = conn.execute("""
3153
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3154
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3155
+ v.birth_date, v.death_date, p.record
3156
+ FROM people_view v
3157
+ JOIN people p ON v.id = p.id
3158
+ WHERE v.id = ?
3159
+ """, (person_id,))
3160
+ else:
3161
+ cursor = conn.execute("""
3162
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3163
+ FROM people WHERE id = ?
3164
+ """, (person_id,))
1522
3165
 
1523
3166
  row = cursor.fetchone()
1524
3167
  if row:
3168
+ source_id_field = "source_identifier" if self._is_v2 else "source_id"
1525
3169
  return PersonRecord(
1526
3170
  name=row["name"],
1527
3171
  source=row["source"],
1528
- source_id=row["source_id"],
3172
+ source_id=row[source_id_field],
1529
3173
  country=row["country"] or "",
1530
3174
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1531
3175
  known_for_role=row["known_for_role"] or "",
1532
3176
  known_for_org=row["known_for_org"] or "",
3177
+ known_for_org_id=row["known_for_org_id"], # Can be None
3178
+ birth_date=row["birth_date"] or "",
3179
+ death_date=row["death_date"] or "",
1533
3180
  record=json.loads(row["record"]),
1534
3181
  )
1535
3182
  return None
@@ -1538,22 +3185,37 @@ class PersonDatabase:
1538
3185
  """Get a person record by source and source_id."""
1539
3186
  conn = self._connect()
1540
3187
 
1541
- cursor = conn.execute("""
1542
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
1543
- FROM people
1544
- WHERE source = ? AND source_id = ?
1545
- """, (source, source_id))
3188
+ if self._is_v2:
3189
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
3190
+ cursor = conn.execute("""
3191
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3192
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3193
+ v.birth_date, v.death_date, p.record
3194
+ FROM people_view v
3195
+ JOIN people p ON v.id = p.id
3196
+ WHERE p.source_id = ? AND p.source_identifier = ?
3197
+ """, (source_type_id, source_id))
3198
+ else:
3199
+ cursor = conn.execute("""
3200
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3201
+ FROM people
3202
+ WHERE source = ? AND source_id = ?
3203
+ """, (source, source_id))
1546
3204
 
1547
3205
  row = cursor.fetchone()
1548
3206
  if row:
3207
+ source_id_field = "source_identifier" if self._is_v2 else "source_id"
1549
3208
  return PersonRecord(
1550
3209
  name=row["name"],
1551
3210
  source=row["source"],
1552
- source_id=row["source_id"],
3211
+ source_id=row[source_id_field],
1553
3212
  country=row["country"] or "",
1554
3213
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1555
3214
  known_for_role=row["known_for_role"] or "",
1556
3215
  known_for_org=row["known_for_org"] or "",
3216
+ known_for_org_id=row["known_for_org_id"], # Can be None
3217
+ birth_date=row["birth_date"] or "",
3218
+ death_date=row["death_date"] or "",
1557
3219
  record=json.loads(row["record"]),
1558
3220
  )
1559
3221
  return None
@@ -1566,12 +3228,32 @@ class PersonDatabase:
1566
3228
  cursor = conn.execute("SELECT COUNT(*) FROM people")
1567
3229
  total = cursor.fetchone()[0]
1568
3230
 
1569
- # Count by person_type
1570
- cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
3231
+ # Count by person_type - handle both v1 and v2 schema
3232
+ if self._is_v2:
3233
+ # v2 schema - join with people_types
3234
+ cursor = conn.execute("""
3235
+ SELECT pt.name as person_type, COUNT(*) as cnt
3236
+ FROM people p
3237
+ JOIN people_types pt ON p.person_type_id = pt.id
3238
+ GROUP BY p.person_type_id
3239
+ """)
3240
+ else:
3241
+ # v1 schema
3242
+ cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
1571
3243
  by_type = {row["person_type"]: row["cnt"] for row in cursor}
1572
3244
 
1573
- # Count by source
1574
- cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
3245
+ # Count by source - handle both v1 and v2 schema
3246
+ if self._is_v2:
3247
+ # v2 schema - join with source_types
3248
+ cursor = conn.execute("""
3249
+ SELECT st.name as source, COUNT(*) as cnt
3250
+ FROM people p
3251
+ JOIN source_types st ON p.source_id = st.id
3252
+ GROUP BY p.source_id
3253
+ """)
3254
+ else:
3255
+ # v1 schema
3256
+ cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
1575
3257
  by_source = {row["source"]: row["cnt"] for row in cursor}
1576
3258
 
1577
3259
  return {
@@ -1580,30 +3262,1564 @@ class PersonDatabase:
1580
3262
  "by_source": by_source,
1581
3263
  }
1582
3264
 
1583
- def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
1584
- """Iterate over all person records, optionally filtered by source."""
3265
+ def ensure_scalar_table_exists(self) -> None:
3266
+ """Create scalar embedding table if it doesn't exist."""
1585
3267
  conn = self._connect()
3268
+ conn.execute(f"""
3269
+ CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
3270
+ person_id INTEGER PRIMARY KEY,
3271
+ embedding int8[{self._embedding_dim}]
3272
+ )
3273
+ """)
3274
+ conn.commit()
3275
+ logger.info("Ensured person_embeddings_scalar table exists")
1586
3276
 
1587
- if source:
1588
- cursor = conn.execute("""
1589
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
1590
- FROM people
1591
- WHERE source = ?
1592
- """, (source,))
1593
- else:
3277
+ def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
3278
+ """
3279
+ Yield batches of person IDs that have float32 but missing scalar embeddings.
3280
+
3281
+ Args:
3282
+ batch_size: Number of IDs per batch
3283
+
3284
+ Yields:
3285
+ Lists of person_ids needing scalar embeddings
3286
+ """
3287
+ conn = self._connect()
3288
+
3289
+ # Ensure scalar table exists before querying
3290
+ self.ensure_scalar_table_exists()
3291
+
3292
+ last_id = 0
3293
+ while True:
1594
3294
  cursor = conn.execute("""
1595
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
1596
- FROM people
1597
- """)
3295
+ SELECT e.person_id FROM person_embeddings e
3296
+ LEFT JOIN person_embeddings_scalar s ON e.person_id = s.person_id
3297
+ WHERE s.person_id IS NULL AND e.person_id > ?
3298
+ ORDER BY e.person_id
3299
+ LIMIT ?
3300
+ """, (last_id, batch_size))
3301
+
3302
+ rows = cursor.fetchall()
3303
+ if not rows:
3304
+ break
3305
+
3306
+ ids = [row["person_id"] for row in rows]
3307
+ yield ids
3308
+ last_id = ids[-1]
3309
+
3310
+ def get_embeddings_by_ids(self, person_ids: list[int]) -> dict[int, np.ndarray]:
3311
+ """
3312
+ Fetch float32 embeddings for given person IDs.
3313
+
3314
+ Args:
3315
+ person_ids: List of person IDs
1598
3316
 
3317
+ Returns:
3318
+ Dict mapping person_id to float32 embedding array
3319
+ """
3320
+ conn = self._connect()
3321
+
3322
+ if not person_ids:
3323
+ return {}
3324
+
3325
+ placeholders = ",".join("?" * len(person_ids))
3326
+ cursor = conn.execute(f"""
3327
+ SELECT person_id, embedding FROM person_embeddings
3328
+ WHERE person_id IN ({placeholders})
3329
+ """, person_ids)
3330
+
3331
+ result = {}
1599
3332
  for row in cursor:
1600
- yield PersonRecord(
1601
- name=row["name"],
1602
- source=row["source"],
1603
- source_id=row["source_id"],
1604
- country=row["country"] or "",
1605
- person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
1606
- known_for_role=row["known_for_role"] or "",
1607
- known_for_org=row["known_for_org"] or "",
1608
- record=json.loads(row["record"]),
1609
- )
3333
+ embedding_blob = row["embedding"]
3334
+ embedding = np.frombuffer(embedding_blob, dtype=np.float32)
3335
+ result[row["person_id"]] = embedding
3336
+ return result
3337
+
3338
+ def insert_scalar_embeddings_batch(self, person_ids: list[int], embeddings: np.ndarray) -> int:
3339
+ """
3340
+ Insert scalar (int8) embeddings for existing people.
3341
+
3342
+ Args:
3343
+ person_ids: List of person IDs
3344
+ embeddings: Matrix of int8 embeddings (N x dim)
3345
+
3346
+ Returns:
3347
+ Number of embeddings inserted
3348
+ """
3349
+ conn = self._connect()
3350
+ count = 0
3351
+
3352
+ for person_id, embedding in zip(person_ids, embeddings):
3353
+ scalar_blob = embedding.astype(np.int8).tobytes()
3354
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
3355
+ conn.execute("""
3356
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
3357
+ VALUES (?, vec_int8(?))
3358
+ """, (person_id, scalar_blob))
3359
+ count += 1
3360
+
3361
+ conn.commit()
3362
+ return count
3363
+
3364
+ def get_scalar_embedding_count(self) -> int:
3365
+ """Get count of scalar embeddings."""
3366
+ conn = self._connect()
3367
+ if not self._has_scalar_table():
3368
+ return 0
3369
+ cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings_scalar")
3370
+ return cursor.fetchone()[0]
3371
+
3372
+ def get_float32_embedding_count(self) -> int:
3373
+ """Get count of float32 embeddings."""
3374
+ conn = self._connect()
3375
+ cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings")
3376
+ return cursor.fetchone()[0]
3377
+
3378
+ def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
3379
+ """
3380
+ Yield batches of (person_id, name) tuples for records missing both float32 and scalar embeddings.
3381
+
3382
+ Args:
3383
+ batch_size: Number of IDs per batch
3384
+
3385
+ Yields:
3386
+ Lists of (person_id, name) tuples needing embeddings generated from scratch
3387
+ """
3388
+ conn = self._connect()
3389
+
3390
+ # Ensure scalar table exists
3391
+ self.ensure_scalar_table_exists()
3392
+
3393
+ last_id = 0
3394
+ while True:
3395
+ cursor = conn.execute("""
3396
+ SELECT p.id, p.name FROM people p
3397
+ LEFT JOIN person_embeddings e ON p.id = e.person_id
3398
+ WHERE e.person_id IS NULL AND p.id > ?
3399
+ ORDER BY p.id
3400
+ LIMIT ?
3401
+ """, (last_id, batch_size))
3402
+
3403
+ rows = cursor.fetchall()
3404
+ if not rows:
3405
+ break
3406
+
3407
+ results = [(row["id"], row["name"]) for row in rows]
3408
+ yield results
3409
+ last_id = results[-1][0]
3410
+
3411
+ def insert_both_embeddings_batch(
3412
+ self,
3413
+ person_ids: list[int],
3414
+ fp32_embeddings: np.ndarray,
3415
+ int8_embeddings: np.ndarray,
3416
+ ) -> int:
3417
+ """
3418
+ Insert both float32 and int8 embeddings for existing people.
3419
+
3420
+ Args:
3421
+ person_ids: List of person IDs
3422
+ fp32_embeddings: Matrix of float32 embeddings (N x dim)
3423
+ int8_embeddings: Matrix of int8 embeddings (N x dim)
3424
+
3425
+ Returns:
3426
+ Number of embeddings inserted
3427
+ """
3428
+ conn = self._connect()
3429
+ count = 0
3430
+
3431
+ for person_id, fp32, int8 in zip(person_ids, fp32_embeddings, int8_embeddings):
3432
+ # Insert float32
3433
+ fp32_blob = fp32.astype(np.float32).tobytes()
3434
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
3435
+ conn.execute("""
3436
+ INSERT INTO person_embeddings (person_id, embedding)
3437
+ VALUES (?, ?)
3438
+ """, (person_id, fp32_blob))
3439
+
3440
+ # Insert int8
3441
+ int8_blob = int8.astype(np.int8).tobytes()
3442
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
3443
+ conn.execute("""
3444
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
3445
+ VALUES (?, vec_int8(?))
3446
+ """, (person_id, int8_blob))
3447
+
3448
+ count += 1
3449
+
3450
+ conn.commit()
3451
+ return count
3452
+
3453
+ def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
3454
+ """
3455
+ Get all source_ids from the people table.
3456
+
3457
+ Useful for resume operations to skip already-imported records.
3458
+
3459
+ Args:
3460
+ source: Optional source filter (e.g., "wikidata")
3461
+
3462
+ Returns:
3463
+ Set of source_id strings (e.g., Q codes for Wikidata)
3464
+ """
3465
+ conn = self._connect()
3466
+
3467
+ if self._is_v2:
3468
+ id_col = "source_identifier"
3469
+ if source:
3470
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
3471
+ cursor = conn.execute(
3472
+ f"SELECT DISTINCT {id_col} FROM people WHERE source_id = ?",
3473
+ (source_type_id,)
3474
+ )
3475
+ else:
3476
+ cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM people")
3477
+ else:
3478
+ if source:
3479
+ cursor = conn.execute(
3480
+ "SELECT DISTINCT source_id FROM people WHERE source = ?",
3481
+ (source,)
3482
+ )
3483
+ else:
3484
+ cursor = conn.execute("SELECT DISTINCT source_id FROM people")
3485
+
3486
+ return {row[0] for row in cursor}
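The intended resume pattern is to snapshot the already-imported identifiers once and skip them while streaming new candidates, using the module-level get_person_database() helper defined above; candidate_people below is a hypothetical iterable of (qid, payload) pairs from a Wikidata dump:

db = get_person_database()
already_imported = db.get_all_source_ids(source="wikidata")

for qid, payload in candidate_people:
    if qid in already_imported:
        continue
    # ... build a PersonRecord from payload and insert it ...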
3487
+
3488
+ def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
3489
+ """Iterate over all person records, optionally filtered by source."""
3490
+ conn = self._connect()
3491
+
3492
+ if self._is_v2:
3493
+ if source:
3494
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
3495
+ cursor = conn.execute("""
3496
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3497
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3498
+ v.birth_date, v.death_date, p.record
3499
+ FROM people_view v
3500
+ JOIN people p ON v.id = p.id
3501
+ WHERE p.source_id = ?
3502
+ """, (source_type_id,))
3503
+ else:
3504
+ cursor = conn.execute("""
3505
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3506
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3507
+ v.birth_date, v.death_date, p.record
3508
+ FROM people_view v
3509
+ JOIN people p ON v.id = p.id
3510
+ """)
3511
+ for row in cursor:
3512
+ yield PersonRecord(
3513
+ name=row["name"],
3514
+ source=row["source"],
3515
+ source_id=row["source_identifier"],
3516
+ country=row["country"] or "",
3517
+ person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
3518
+ known_for_role=row["known_for_role"] or "",
3519
+ known_for_org=row["known_for_org"] or "",
3520
+ known_for_org_id=row["known_for_org_id"], # Can be None
3521
+ birth_date=row["birth_date"] or "",
3522
+ death_date=row["death_date"] or "",
3523
+ record=json.loads(row["record"]),
3524
+ )
3525
+ else:
3526
+ if source:
3527
+ cursor = conn.execute("""
3528
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3529
+ FROM people
3530
+ WHERE source = ?
3531
+ """, (source,))
3532
+ else:
3533
+ cursor = conn.execute("""
3534
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3535
+ FROM people
3536
+ """)
3537
+
3538
+ for row in cursor:
3539
+ yield PersonRecord(
3540
+ name=row["name"],
3541
+ source=row["source"],
3542
+ source_id=row["source_id"],
3543
+ country=row["country"] or "",
3544
+ person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
3545
+ known_for_role=row["known_for_role"] or "",
3546
+ known_for_org=row["known_for_org"] or "",
3547
+ known_for_org_id=row["known_for_org_id"], # Can be None
3548
+ birth_date=row["birth_date"] or "",
3549
+ death_date=row["death_date"] or "",
3550
+ record=json.loads(row["record"]),
3551
+ )
3552
+
3553
+ def resolve_qid_labels(
3554
+ self,
3555
+ label_map: dict[str, str],
3556
+ batch_size: int = 1000,
3557
+ ) -> tuple[int, int]:
3558
+ """
3559
+ Update records that have QIDs instead of labels.
3560
+
3561
+ This is called after dump import to resolve any QIDs that were
3562
+ stored because labels weren't available in the cache at import time.
3563
+
3564
+ If resolving would create a duplicate of an existing record with
3565
+ resolved labels, the QID version is deleted instead.
3566
+
3567
+ Args:
3568
+ label_map: Mapping of QID -> label for resolution
3569
+ batch_size: Commit batch size
3570
+
3571
+ Returns:
3572
+ Tuple of (updates, deletes)
3573
+ """
3574
+ conn = self._connect()
3575
+
3576
+ # v2 schema stores QIDs as integers, not text - this method doesn't apply
3577
+ if self._is_v2:
3578
+ logger.debug("Skipping resolve_qid_labels for v2 schema (QIDs stored as integers)")
3579
+ return 0, 0
3580
+
3581
+ # Find all records with QIDs in any field (role or org - these are in unique constraint)
3582
+ # Country is not part of unique constraint so can be updated directly
3583
+ cursor = conn.execute("""
3584
+ SELECT id, source, source_id, country, known_for_role, known_for_org
3585
+ FROM people
3586
+ WHERE (country LIKE 'Q%' AND country GLOB 'Q[0-9]*')
3587
+ OR (known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*')
3588
+ OR (known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*')
3589
+ """)
3590
+ rows = cursor.fetchall()
3591
+
3592
+ updates = 0
3593
+ deletes = 0
3594
+
3595
+ for row in rows:
3596
+ person_id = row["id"]
3597
+ source = row["source"]
3598
+ source_id = row["source_id"]
3599
+ country = row["country"]
3600
+ role = row["known_for_role"]
3601
+ org = row["known_for_org"]
3602
+
3603
+ # Resolve QIDs to labels
3604
+ new_country = label_map.get(country, country) if country.startswith("Q") and country[1:].isdigit() else country
3605
+ new_role = label_map.get(role, role) if role.startswith("Q") and role[1:].isdigit() else role
3606
+ new_org = label_map.get(org, org) if org.startswith("Q") and org[1:].isdigit() else org
3607
+
3608
+ # Skip if nothing changed
3609
+ if new_country == country and new_role == role and new_org == org:
3610
+ continue
3611
+
3612
+ # Check if resolved values would duplicate an existing record
3613
+ # (unique constraint is on source, source_id, known_for_role, known_for_org)
3614
+ if new_role != role or new_org != org:
3615
+ cursor2 = conn.execute("""
3616
+ SELECT id FROM people
3617
+ WHERE source = ? AND source_id = ? AND known_for_role = ? AND known_for_org = ?
3618
+ AND id != ?
3619
+ """, (source, source_id, new_role, new_org, person_id))
3620
+ existing = cursor2.fetchone()
3621
+
3622
+ if existing:
3623
+ # Duplicate would exist - delete the QID version
3624
+ conn.execute("DELETE FROM people WHERE id = ?", (person_id,))
3625
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
3626
+ deletes += 1
3627
+ logger.debug(f"Deleted duplicate QID record {person_id} (source_id={source_id})")
3628
+ continue
3629
+
3630
+ # No duplicate - update in place
3631
+ conn.execute("""
3632
+ UPDATE people SET country = ?, known_for_role = ?, known_for_org = ?
3633
+ WHERE id = ?
3634
+ """, (new_country, new_role, new_org, person_id))
3635
+ updates += 1
3636
+
3637
+ if (updates + deletes) % batch_size == 0:
3638
+ conn.commit()
3639
+ logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes...")
3640
+
3641
+ conn.commit()
3642
+ logger.info(f"Resolved QID labels: {updates} updates, {deletes} deletes")
3643
+ return updates, deletes
3644
+
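A small, hedged example of calling resolve_qid_labels after a dump import; the owning database object is written as `db` because its class name is outside this hunk, and the QID/label pairs are illustrative values.

```python
# Hedged sketch: resolve leftover QIDs stored in place of labels (v1 schema).
label_map = {
    "Q30": "United States of America",    # illustrative QID -> label pairs
    "Q484876": "chief executive officer",
}
updates, deletes = db.resolve_qid_labels(label_map, batch_size=500)
print(f"{updates} rows updated, {deletes} duplicate QID rows removed")
```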
3645
+ def get_unresolved_qids(self) -> set[str]:
3646
+ """
3647
+ Get all QIDs that still need resolution in the database.
3648
+
3649
+ Returns:
3650
+ Set of QIDs (starting with 'Q') found in country, role, or org fields
3651
+ """
3652
+ conn = self._connect()
3653
+
3654
+ # v2 schema stores QIDs as integers, not text - this method doesn't apply
3655
+ if self._is_v2:
3656
+ return set()
3657
+
3658
+ qids: set[str] = set()
3659
+
3660
+ # Get QIDs from country field
3661
+ cursor = conn.execute("""
3662
+ SELECT DISTINCT country FROM people
3663
+ WHERE country LIKE 'Q%' AND country GLOB 'Q[0-9]*'
3664
+ """)
3665
+ for row in cursor:
3666
+ qids.add(row["country"])
3667
+
3668
+ # Get QIDs from known_for_role field
3669
+ cursor = conn.execute("""
3670
+ SELECT DISTINCT known_for_role FROM people
3671
+ WHERE known_for_role LIKE 'Q%' AND known_for_role GLOB 'Q[0-9]*'
3672
+ """)
3673
+ for row in cursor:
3674
+ qids.add(row["known_for_role"])
3675
+
3676
+ # Get QIDs from known_for_org field
3677
+ cursor = conn.execute("""
3678
+ SELECT DISTINCT known_for_org FROM people
3679
+ WHERE known_for_org LIKE 'Q%' AND known_for_org GLOB 'Q[0-9]*'
3680
+ """)
3681
+ for row in cursor:
3682
+ qids.add(row["known_for_org"])
3683
+
3684
+ return qids
3685
+
3686
+ def insert_qid_labels(
3687
+ self,
3688
+ label_map: dict[str, str],
3689
+ batch_size: int = 1000,
3690
+ ) -> int:
3691
+ """
3692
+ Insert QID -> label mappings into the lookup table.
3693
+
3694
+ Args:
3695
+ label_map: Mapping of QID -> label
3696
+ batch_size: Commit batch size
3697
+
3698
+ Returns:
3699
+ Number of labels inserted/updated
3700
+ """
3701
+ conn = self._connect()
3702
+ count = 0
3703
+ skipped = 0
3704
+
3705
+ for qid, label in label_map.items():
3706
+ # Skip non-Q IDs (e.g., property IDs like P19)
3707
+ if not qid.startswith("Q"):
3708
+ skipped += 1
3709
+ continue
3710
+
3711
+ # v2 schema stores QID as integer without Q prefix
3712
+ if self._is_v2:
3713
+ try:
3714
+ qid_val: str | int = int(qid[1:])
3715
+ except ValueError:
3716
+ skipped += 1
3717
+ continue
3718
+ else:
3719
+ qid_val = qid
3720
+
3721
+ conn.execute(
3722
+ "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
3723
+ (qid_val, label)
3724
+ )
3725
+ count += 1
3726
+
3727
+ if count % batch_size == 0:
3728
+ conn.commit()
3729
+ logger.debug(f"Inserted {count} QID labels...")
3730
+
3731
+ conn.commit()
3732
+ logger.info(f"Inserted {count} QID labels into lookup table")
3733
+ return count
3734
+
3735
+ def get_qid_label(self, qid: str) -> Optional[str]:
3736
+ """
3737
+ Get the label for a QID from the lookup table.
3738
+
3739
+ Args:
3740
+ qid: Wikidata QID (e.g., 'Q30')
3741
+
3742
+ Returns:
3743
+ Label string or None if not found
3744
+ """
3745
+ conn = self._connect()
3746
+
3747
+ # v2 schema stores QID as integer without Q prefix
3748
+ if self._is_v2:
3749
+ qid_val: str | int = int(qid[1:]) if qid.startswith("Q") else int(qid)
3750
+ else:
3751
+ qid_val = qid
3752
+
3753
+ cursor = conn.execute(
3754
+ "SELECT label FROM qid_labels WHERE qid = ?",
3755
+ (qid_val,)
3756
+ )
3757
+ row = cursor.fetchone()
3758
+ return row["label"] if row else None
3759
+
3760
+ def get_all_qid_labels(self) -> dict[str, str]:
3761
+ """
3762
+ Get all QID -> label mappings from the lookup table.
3763
+
3764
+ Returns:
3765
+ Dict mapping QID -> label
3766
+ """
3767
+ conn = self._connect()
3768
+ cursor = conn.execute("SELECT qid, label FROM qid_labels")
3769
+ return {row["qid"]: row["label"] for row in cursor}
3770
+
3771
+ def get_qid_labels_count(self) -> int:
3772
+ """Get the number of QID labels in the lookup table."""
3773
+ conn = self._connect()
3774
+ cursor = conn.execute("SELECT COUNT(*) FROM qid_labels")
3775
+ return cursor.fetchone()[0]
3776
+
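Taken together, the QID helpers above suggest a round-trip like the sketch below; `fetch_labels_from_wikidata` is a hypothetical external lookup step, not part of this module, and `db` again stands for the owning database object.

```python
# Hedged end-to-end sketch of the QID label round-trip (v1 schema only).
pending = db.get_unresolved_qids()                # e.g. {"Q30", "Q484876"}
label_map = fetch_labels_from_wikidata(pending)   # hypothetical external lookup
db.insert_qid_labels(label_map)                   # cache labels for later runs
db.resolve_qid_labels(label_map)                  # rewrite people rows in place
print(db.get_qid_labels_count(), "labels cached")
```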
3777
+ def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
3778
+ """
3779
+ Canonicalize person records by linking equivalent entries across sources.
3780
+
3781
+ Uses a multi-phase approach:
3782
+ 1. Match by normalized name + same organization (org canonical group)
3783
+ 2. Match by normalized name + overlapping date ranges
3784
+
3785
+ Source priority (lower = more authoritative):
3786
+ - wikidata: 1 (curated, has Q codes)
3787
+ - sec_edgar: 2 (US insider filings)
3788
+ - companies_house: 3 (UK officers)
3789
+
3790
+ Args:
3791
+ batch_size: Number of records to process before committing
3792
+
3793
+ Returns:
3794
+ Stats dict with counts for each matching type
3795
+ """
3796
+ conn = self._connect()
3797
+ stats = {
3798
+ "total_records": 0,
3799
+ "matched_by_org": 0,
3800
+ "matched_by_date": 0,
3801
+ "canonical_groups": 0,
3802
+ "records_in_groups": 0,
3803
+ }
3804
+
3805
+ logger.info("Phase 1: Building person index...")
3806
+
3807
+ # Load all people with their normalized names and org info
3808
+ if self._is_v2:
3809
+ cursor = conn.execute("""
3810
+ SELECT p.id, p.name, p.name_normalized, s.name as source, p.source_identifier as source_id,
3811
+ p.known_for_org, p.known_for_org_id, p.from_date, p.to_date
3812
+ FROM people p
3813
+ JOIN source_types s ON p.source_id = s.id
3814
+ """)
3815
+ else:
3816
+ cursor = conn.execute("""
3817
+ SELECT id, name, name_normalized, source, source_id,
3818
+ known_for_org, known_for_org_id, from_date, to_date
3819
+ FROM people
3820
+ """)
3821
+
3822
+ people: list[dict] = []
3823
+ for row in cursor:
3824
+ people.append({
3825
+ "id": row["id"],
3826
+ "name": row["name"],
3827
+ "name_normalized": row["name_normalized"],
3828
+ "source": row["source"],
3829
+ "source_id": row["source_id"],
3830
+ "known_for_org": row["known_for_org"],
3831
+ "known_for_org_id": row["known_for_org_id"],
3832
+ "from_date": row["from_date"],
3833
+ "to_date": row["to_date"],
3834
+ })
3835
+
3836
+ stats["total_records"] = len(people)
3837
+ logger.info(f"Loaded {len(people)} person records")
3838
+
3839
+ if len(people) == 0:
3840
+ return stats
3841
+
3842
+ # Initialize Union-Find
3843
+ person_ids = [p["id"] for p in people]
3844
+ uf = UnionFind(person_ids)
3845
+
3846
+ # Build indexes for efficient matching
3847
+ # Index by normalized name
3848
+ name_to_people: dict[str, list[dict]] = {}
3849
+ for p in people:
3850
+ name_norm = p["name_normalized"]
3851
+ name_to_people.setdefault(name_norm, []).append(p)
3852
+
3853
+ logger.info("Phase 2: Matching by normalized name + organization...")
3854
+
3855
+ # Match people with same normalized name and same organization
3856
+ for name_norm, same_name in name_to_people.items():
3857
+ if len(same_name) < 2:
3858
+ continue
3859
+
3860
+ # Group by organization (using known_for_org_id if available, else known_for_org)
3861
+ org_groups: dict[str, list[dict]] = {}
3862
+ for p in same_name:
3863
+ org_key = str(p["known_for_org_id"]) if p["known_for_org_id"] else p["known_for_org"]
3864
+ if org_key: # Only group if they have an org
3865
+ org_groups.setdefault(org_key, []).append(p)
3866
+
3867
+ # Union people with same name + same org
3868
+ for org_key, org_people in org_groups.items():
3869
+ if len(org_people) >= 2:
3870
+ first_id = org_people[0]["id"]
3871
+ for p in org_people[1:]:
3872
+ uf.union(first_id, p["id"])
3873
+ stats["matched_by_org"] += 1
3874
+
3875
+ logger.info(f"Phase 2 complete: {stats['matched_by_org']} matches by org")
3876
+
3877
+ logger.info("Phase 3: Matching by normalized name + overlapping dates...")
3878
+
3879
+ # Match people with same normalized name and overlapping date ranges
3880
+ for name_norm, same_name in name_to_people.items():
3881
+ if len(same_name) < 2:
3882
+ continue
3883
+
3884
+ # Skip if already all unified
3885
+ roots = set(uf.find(p["id"]) for p in same_name)
3886
+ if len(roots) == 1:
3887
+ continue
3888
+
3889
+ # Check for overlapping date ranges
3890
+ for i, p1 in enumerate(same_name):
3891
+ for p2 in same_name[i+1:]:
3892
+ # Skip if already in same group
3893
+ if uf.find(p1["id"]) == uf.find(p2["id"]):
3894
+ continue
3895
+
3896
+ # Check date overlap (if both have dates)
3897
+ if p1["from_date"] and p2["from_date"]:
3898
+ # Overlap check: each from_date must be on or before the other's to_date
3899
+ p1_from = p1["from_date"]
3900
+ p1_to = p1["to_date"] or "9999-12-31"
3901
+ p2_from = p2["from_date"]
3902
+ p2_to = p2["to_date"] or "9999-12-31"
3903
+
3904
+ # Overlap if: p1_from <= p2_to AND p2_from <= p1_to
3905
+ if p1_from <= p2_to and p2_from <= p1_to:
3906
+ uf.union(p1["id"], p2["id"])
3907
+ stats["matched_by_date"] += 1
3908
+
3909
+ logger.info(f"Phase 3 complete: {stats['matched_by_date']} matches by date")
3910
+
3911
+ logger.info("Phase 4: Applying canonical updates...")
3912
+
3913
+ # Get all groups and select canonical record for each
3914
+ groups = uf.groups()
3915
+
3916
+ # Build id -> source mapping
3917
+ id_to_source = {p["id"]: p["source"] for p in people}
3918
+
3919
+ batch_updates: list[tuple[int, int, int]] = [] # (person_id, canon_id, canon_size)
3920
+
3921
+ for _root, group_ids in groups.items():
3922
+ group_size = len(group_ids)
3923
+
3924
+ if group_size == 1:
3925
+ # Single record is its own canonical
3926
+ person_id = group_ids[0]
3927
+ batch_updates.append((person_id, person_id, 1))
3928
+ else:
3929
+ # Multiple records - pick highest priority source as canonical
3930
+ # Sort by source priority, then by id (for stability)
3931
+ sorted_ids = sorted(
3932
+ group_ids,
3933
+ key=lambda pid: (PERSON_SOURCE_PRIORITY.get(id_to_source[pid], 99), pid)
3934
+ )
3935
+ canon_id = sorted_ids[0]
3936
+ stats["canonical_groups"] += 1
3937
+ stats["records_in_groups"] += group_size
3938
+
3939
+ for person_id in group_ids:
3940
+ batch_updates.append((person_id, canon_id, group_size if person_id == canon_id else 1))
3941
+
3942
+ # Commit in batches
3943
+ if len(batch_updates) >= batch_size:
3944
+ self._apply_person_canon_updates(batch_updates)
3945
+ conn.commit()
3946
+ logger.info(f"Applied {len(batch_updates)} canon updates...")
3947
+ batch_updates = []
3948
+
3949
+ # Final batch
3950
+ if batch_updates:
3951
+ self._apply_person_canon_updates(batch_updates)
3952
+ conn.commit()
3953
+
3954
+ logger.info(f"Canonicalization complete: {stats['canonical_groups']} groups, "
3955
+ f"{stats['records_in_groups']} records in multi-record groups")
3956
+
3957
+ return stats
3958
+
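canonicalize() leans on a UnionFind class defined elsewhere in this module and on a closed-interval date test over ISO strings; the sketch below is an illustrative stand-in for both, not the shipped implementation.

```python
# Illustrative-only sketch of the two building blocks canonicalize() relies on.
class MiniUnionFind:
    def __init__(self, ids):
        self.parent = {i: i for i in ids}

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra

def dates_overlap(a_from, a_to, b_from, b_to):
    # ISO date strings compare lexicographically; a missing to_date is open-ended.
    return a_from <= (b_to or "9999-12-31") and b_from <= (a_to or "9999-12-31")

uf = MiniUnionFind([1, 2, 3])
uf.union(1, 2)
assert uf.find(1) == uf.find(2) != uf.find(3)
assert dates_overlap("2010-01-01", "2015-06-30", "2014-01-01", None)
```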
3959
+ def _apply_person_canon_updates(self, updates: list[tuple[int, int, int]]) -> None:
3960
+ """Apply batch of canon updates: (person_id, canon_id, canon_size)."""
3961
+ conn = self._conn
3962
+ assert conn is not None
3963
+
3964
+ for person_id, canon_id, canon_size in updates:
3965
+ conn.execute(
3966
+ "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
3967
+ (canon_id, canon_size, person_id)
3968
+ )
3969
+
3970
+
3971
+ # =============================================================================
3972
+ # Module-level singletons for new v2 databases
3973
+ # =============================================================================
3974
+
3975
+ _roles_database_instances: dict[str, "RolesDatabase"] = {}
3976
+ _locations_database_instances: dict[str, "LocationsDatabase"] = {}
3977
+
3978
+
3979
+ def get_roles_database(db_path: Optional[str | Path] = None) -> "RolesDatabase":
3980
+ """
3981
+ Get a singleton RolesDatabase instance for the given path.
3982
+
3983
+ Args:
3984
+ db_path: Path to database file
3985
+
3986
+ Returns:
3987
+ Shared RolesDatabase instance
3988
+ """
3989
+ path_key = str(db_path or DEFAULT_DB_PATH)
3990
+ if path_key not in _roles_database_instances:
3991
+ logger.debug(f"Creating new RolesDatabase instance for {path_key}")
3992
+ _roles_database_instances[path_key] = RolesDatabase(db_path=db_path)
3993
+ return _roles_database_instances[path_key]
3994
+
3995
+
3996
+ def get_locations_database(db_path: Optional[str | Path] = None) -> "LocationsDatabase":
3997
+ """
3998
+ Get a singleton LocationsDatabase instance for the given path.
3999
+
4000
+ Args:
4001
+ db_path: Path to database file
4002
+
4003
+ Returns:
4004
+ Shared LocationsDatabase instance
4005
+ """
4006
+ path_key = str(db_path or DEFAULT_DB_PATH)
4007
+ if path_key not in _locations_database_instances:
4008
+ logger.debug(f"Creating new LocationsDatabase instance for {path_key}")
4009
+ _locations_database_instances[path_key] = LocationsDatabase(db_path=db_path)
4010
+ return _locations_database_instances[path_key]
4011
+
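A hedged usage sketch for the two singleton accessors defined above; the database path is an example value only.

```python
# Hedged sketch: the accessors return one shared instance per database path.
from statement_extractor.database.store import (
    get_locations_database,
    get_roles_database,
)

roles = get_roles_database()                       # falls back to DEFAULT_DB_PATH
locations = get_locations_database("entities.db")  # example path
assert roles is get_roles_database()               # same object on repeat calls
```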
4012
+
4013
+ # =============================================================================
4014
+ # ROLES DATABASE (v2)
4015
+ # =============================================================================
4016
+
4017
+
4018
+ class RolesDatabase:
4019
+ """
4020
+ SQLite database for job titles/roles.
4021
+
4022
+ Stores normalized role records with source tracking and supports
4023
+ canonicalization to group equivalent roles (e.g., CEO, Chief Executive).
4024
+ """
4025
+
4026
+ def __init__(self, db_path: Optional[str | Path] = None):
4027
+ """
4028
+ Initialize the roles database.
4029
+
4030
+ Args:
4031
+ db_path: Path to database file (creates if not exists)
4032
+ """
4033
+ self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
4034
+ self._conn: Optional[sqlite3.Connection] = None
4035
+ self._role_cache: dict[str, int] = {}  # "name_normalized:source_id" -> role_id
4036
+
4037
+ def _connect(self) -> sqlite3.Connection:
4038
+ """Get or create database connection using shared connection pool."""
4039
+ if self._conn is not None:
4040
+ return self._conn
4041
+
4042
+ self._conn = _get_shared_connection(self._db_path)
4043
+ self._create_tables()
4044
+ return self._conn
4045
+
4046
+ def _create_tables(self) -> None:
4047
+ """Create roles table and indexes."""
4048
+ conn = self._conn
4049
+ assert conn is not None
4050
+
4051
+ # Check if enum tables exist, create and seed if not
4052
+ cursor = conn.execute(
4053
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
4054
+ )
4055
+ if not cursor.fetchone():
4056
+ logger.info("Creating enum tables for v2 schema...")
4057
+ from .schema_v2 import (
4058
+ CREATE_SOURCE_TYPES,
4059
+ CREATE_PEOPLE_TYPES,
4060
+ CREATE_ORGANIZATION_TYPES,
4061
+ CREATE_SIMPLIFIED_LOCATION_TYPES,
4062
+ CREATE_LOCATION_TYPES,
4063
+ )
4064
+ conn.execute(CREATE_SOURCE_TYPES)
4065
+ conn.execute(CREATE_PEOPLE_TYPES)
4066
+ conn.execute(CREATE_ORGANIZATION_TYPES)
4067
+ conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
4068
+ conn.execute(CREATE_LOCATION_TYPES)
4069
+ seed_all_enums(conn)
4070
+
4071
+ # Create roles table
4072
+ conn.execute("""
4073
+ CREATE TABLE IF NOT EXISTS roles (
4074
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
4075
+ qid INTEGER,
4076
+ name TEXT NOT NULL,
4077
+ name_normalized TEXT NOT NULL,
4078
+ source_id INTEGER NOT NULL DEFAULT 4,
4079
+ source_identifier TEXT,
4080
+ record TEXT NOT NULL DEFAULT '{}',
4081
+ canon_id INTEGER DEFAULT NULL,
4082
+ canon_size INTEGER DEFAULT 1,
4083
+ UNIQUE(name_normalized, source_id)
4084
+ )
4085
+ """)
4086
+
4087
+ # Create indexes
4088
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name ON roles(name)")
4089
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name_normalized ON roles(name_normalized)")
4090
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_qid ON roles(qid)")
4091
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_source_id ON roles(source_id)")
4092
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_canon_id ON roles(canon_id)")
4093
+
4094
+ conn.commit()
4095
+
4096
+ def close(self) -> None:
4097
+ """Clear connection reference."""
4098
+ self._conn = None
4099
+
4100
+ def get_or_create(
4101
+ self,
4102
+ name: str,
4103
+ source_id: int = 4, # wikidata
4104
+ qid: Optional[int] = None,
4105
+ source_identifier: Optional[str] = None,
4106
+ ) -> int:
4107
+ """
4108
+ Get or create a role record.
4109
+
4110
+ Args:
4111
+ name: Role/title name
4112
+ source_id: FK to source_types table
4113
+ qid: Optional Wikidata QID as integer
4114
+ source_identifier: Optional source-specific identifier
4115
+
4116
+ Returns:
4117
+ Role ID
4118
+ """
4119
+ if not name:
4120
+ raise ValueError("Role name cannot be empty")
4121
+
4122
+ conn = self._connect()
4123
+ name_normalized = name.lower().strip()
4124
+
4125
+ # Check cache
4126
+ cache_key = f"{name_normalized}:{source_id}"
4127
+ if cache_key in self._role_cache:
4128
+ return self._role_cache[cache_key]
4129
+
4130
+ # Check database
4131
+ cursor = conn.execute(
4132
+ "SELECT id FROM roles WHERE name_normalized = ? AND source_id = ?",
4133
+ (name_normalized, source_id)
4134
+ )
4135
+ row = cursor.fetchone()
4136
+ if row:
4137
+ role_id = row["id"]
4138
+ self._role_cache[cache_key] = role_id
4139
+ return role_id
4140
+
4141
+ # Create new role
4142
+ cursor = conn.execute(
4143
+ """
4144
+ INSERT INTO roles (name, name_normalized, source_id, qid, source_identifier)
4145
+ VALUES (?, ?, ?, ?, ?)
4146
+ """,
4147
+ (name, name_normalized, source_id, qid, source_identifier)
4148
+ )
4149
+ role_id = cursor.lastrowid
4150
+ assert role_id is not None
4151
+ conn.commit()
4152
+
4153
+ self._role_cache[cache_key] = role_id
4154
+ return role_id
4155
+
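A hedged sketch for RolesDatabase.get_or_create; the database path and QID value are illustrative, and SOURCE_NAME_TO_ID is the seed-data mapping imported at the top of this module.

```python
# Hedged sketch: idempotent role creation keyed on normalized name + source.
from statement_extractor.database.seed_data import SOURCE_NAME_TO_ID
from statement_extractor.database.store import RolesDatabase

roles = RolesDatabase("entities.db")               # example path
ceo_id = roles.get_or_create(
    "Chief Executive Officer",
    source_id=SOURCE_NAME_TO_ID.get("wikidata", 4),
    qid=484876,                                    # illustrative QID, no "Q" prefix
)
assert ceo_id == roles.get_or_create("chief executive officer")
```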
4156
+ def get_by_id(self, role_id: int) -> Optional[RoleRecord]:
4157
+ """Get a role record by ID."""
4158
+ conn = self._connect()
4159
+
4160
+ cursor = conn.execute(
4161
+ "SELECT id, qid, name, source_id, source_identifier, record FROM roles WHERE id = ?",
4162
+ (role_id,)
4163
+ )
4164
+ row = cursor.fetchone()
4165
+ if row:
4166
+ source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
4167
+ return RoleRecord(
4168
+ name=row["name"],
4169
+ source=source_name,
4170
+ source_id=row["source_identifier"],
4171
+ qid=row["qid"],
4172
+ record=json.loads(row["record"]) if row["record"] else {},
4173
+ )
4174
+ return None
4175
+
4176
+ def search(
4177
+ self,
4178
+ query: str,
4179
+ top_k: int = 10,
4180
+ ) -> list[tuple[int, str, float]]:
4181
+ """
4182
+ Search for roles by name.
4183
+
4184
+ Args:
4185
+ query: Search query
4186
+ top_k: Maximum results to return
4187
+
4188
+ Returns:
4189
+ List of (role_id, role_name, score) tuples
4190
+ """
4191
+ conn = self._connect()
4192
+ query_normalized = query.lower().strip()
4193
+
4194
+ # Exact match first
4195
+ cursor = conn.execute(
4196
+ "SELECT id, name FROM roles WHERE name_normalized = ? LIMIT 1",
4197
+ (query_normalized,)
4198
+ )
4199
+ row = cursor.fetchone()
4200
+ if row:
4201
+ return [(row["id"], row["name"], 1.0)]
4202
+
4203
+ # LIKE match
4204
+ cursor = conn.execute(
4205
+ """
4206
+ SELECT id, name FROM roles
4207
+ WHERE name_normalized LIKE ?
4208
+ ORDER BY length(name)
4209
+ LIMIT ?
4210
+ """,
4211
+ (f"%{query_normalized}%", top_k)
4212
+ )
4213
+
4214
+ results = []
4215
+ for row in cursor:
4216
+ # Simple score based on match quality
4217
+ name_normalized = row["name"].lower()
4218
+ if query_normalized == name_normalized:
4219
+ score = 1.0
4220
+ elif name_normalized.startswith(query_normalized):
4221
+ score = 0.9
4222
+ else:
4223
+ score = 0.7
4224
+ results.append((row["id"], row["name"], score))
4225
+
4226
+ return results
4227
+
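A short, hedged sketch of the tiered scoring in RolesDatabase.search (1.0 exact, 0.9 prefix, 0.7 substring), continuing the `roles` object from the sketch above.

```python
# Hedged sketch: exact matches short-circuit, otherwise a LIKE scan is scored.
for role_id, name, score in roles.search("chief exec", top_k=5):
    print(f"{score:.1f}  {name}  (id={role_id})")
```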
4228
+ def get_stats(self) -> dict[str, int]:
4229
+ """Get statistics about the roles table."""
4230
+ conn = self._connect()
4231
+
4232
+ cursor = conn.execute("SELECT COUNT(*) FROM roles")
4233
+ total = cursor.fetchone()[0]
4234
+
4235
+ cursor = conn.execute("SELECT COUNT(*) FROM roles WHERE canon_id IS NOT NULL")
4236
+ canonicalized = cursor.fetchone()[0]
4237
+
4238
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM roles WHERE canon_id IS NOT NULL")
4239
+ groups = cursor.fetchone()[0]
4240
+
4241
+ return {
4242
+ "total_roles": total,
4243
+ "canonicalized": canonicalized,
4244
+ "canonical_groups": groups,
4245
+ }
4246
+
4247
+
4248
+ # =============================================================================
4249
+ # LOCATIONS DATABASE (v2)
4250
+ # =============================================================================
4251
+
4252
+
4253
+ class LocationsDatabase:
4254
+ """
4255
+ SQLite database for geopolitical locations.
4256
+
4257
+ Stores countries, states, cities with hierarchical relationships
4258
+ and type classification. Supports pycountry integration.
4259
+ """
4260
+
4261
+ def __init__(self, db_path: Optional[str | Path] = None):
4262
+ """
4263
+ Initialize the locations database.
4264
+
4265
+ Args:
4266
+ db_path: Path to database file (creates if not exists)
4267
+ """
4268
+ self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
4269
+ self._conn: Optional[sqlite3.Connection] = None
4270
+ self._location_cache: dict[str, int] = {} # lookup_key -> location_id
4271
+ self._location_type_cache: dict[str, int] = {} # type_name -> type_id
4272
+ self._location_type_qid_cache: dict[int, int] = {} # qid -> type_id
4273
+
4274
+ def _connect(self) -> sqlite3.Connection:
4275
+ """Get or create database connection using shared connection pool."""
4276
+ if self._conn is not None:
4277
+ return self._conn
4278
+
4279
+ self._conn = _get_shared_connection(self._db_path)
4280
+ self._create_tables()
4281
+ self._build_caches()
4282
+ return self._conn
4283
+
4284
+ def _create_tables(self) -> None:
4285
+ """Create locations table and indexes."""
4286
+ conn = self._conn
4287
+ assert conn is not None
4288
+
4289
+ # Check if enum tables exist, create and seed if not
4290
+ cursor = conn.execute(
4291
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
4292
+ )
4293
+ if not cursor.fetchone():
4294
+ logger.info("Creating enum tables for v2 schema...")
4295
+ from .schema_v2 import (
4296
+ CREATE_SOURCE_TYPES,
4297
+ CREATE_PEOPLE_TYPES,
4298
+ CREATE_ORGANIZATION_TYPES,
4299
+ CREATE_SIMPLIFIED_LOCATION_TYPES,
4300
+ CREATE_LOCATION_TYPES,
4301
+ )
4302
+ conn.execute(CREATE_SOURCE_TYPES)
4303
+ conn.execute(CREATE_PEOPLE_TYPES)
4304
+ conn.execute(CREATE_ORGANIZATION_TYPES)
4305
+ conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
4306
+ conn.execute(CREATE_LOCATION_TYPES)
4307
+ seed_all_enums(conn)
4308
+
4309
+ # Create locations table
4310
+ conn.execute("""
4311
+ CREATE TABLE IF NOT EXISTS locations (
4312
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
4313
+ qid INTEGER,
4314
+ name TEXT NOT NULL,
4315
+ name_normalized TEXT NOT NULL,
4316
+ source_id INTEGER NOT NULL DEFAULT 4,
4317
+ source_identifier TEXT,
4318
+ parent_ids TEXT,
4319
+ location_type_id INTEGER NOT NULL DEFAULT 2,
4320
+ record TEXT NOT NULL DEFAULT '{}',
4321
+ from_date TEXT DEFAULT NULL,
4322
+ to_date TEXT DEFAULT NULL,
4323
+ canon_id INTEGER DEFAULT NULL,
4324
+ canon_size INTEGER DEFAULT 1,
4325
+ UNIQUE(source_identifier, source_id)
4326
+ )
4327
+ """)
4328
+
4329
+ # Create indexes
4330
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name ON locations(name)")
4331
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name_normalized ON locations(name_normalized)")
4332
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_qid ON locations(qid)")
4333
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_source_id ON locations(source_id)")
4334
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_location_type_id ON locations(location_type_id)")
4335
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_canon_id ON locations(canon_id)")
4336
+
4337
+ conn.commit()
4338
+
4339
+ def _build_caches(self) -> None:
4340
+ """Build lookup caches from database and seed data."""
4341
+ # Load location type caches from seed data
4342
+ self._location_type_cache = dict(LOCATION_TYPE_NAME_TO_ID)
4343
+ self._location_type_qid_cache = dict(LOCATION_TYPE_QID_TO_ID)
4344
+
4345
+ # Load existing locations into cache
4346
+ conn = self._conn
4347
+ if conn:
4348
+ cursor = conn.execute(
4349
+ "SELECT id, name_normalized, source_identifier FROM locations"
4350
+ )
4351
+ for row in cursor:
4352
+ # Cache by normalized name
4353
+ self._location_cache[row["name_normalized"]] = row["id"]
4354
+ # Also cache by source_identifier
4355
+ if row["source_identifier"]:
4356
+ self._location_cache[row["source_identifier"].lower()] = row["id"]
4357
+
4358
+ def close(self) -> None:
4359
+ """Clear connection reference."""
4360
+ self._conn = None
4361
+
4362
+ def get_or_create(
4363
+ self,
4364
+ name: str,
4365
+ location_type_id: int,
4366
+ source_id: int = 4, # wikidata
4367
+ qid: Optional[int] = None,
4368
+ source_identifier: Optional[str] = None,
4369
+ parent_ids: Optional[list[int]] = None,
4370
+ ) -> int:
4371
+ """
4372
+ Get or create a location record.
4373
+
4374
+ Args:
4375
+ name: Location name
4376
+ location_type_id: FK to location_types table
4377
+ source_id: FK to source_types table
4378
+ qid: Optional Wikidata QID as integer
4379
+ source_identifier: Optional source-specific identifier (e.g., "US", "CA")
4380
+ parent_ids: Optional list of parent location IDs
4381
+
4382
+ Returns:
4383
+ Location ID
4384
+ """
4385
+ if not name:
4386
+ raise ValueError("Location name cannot be empty")
4387
+
4388
+ conn = self._connect()
4389
+ name_normalized = name.lower().strip()
4390
+
4391
+ # Check cache by source_identifier first (more specific)
4392
+ if source_identifier:
4393
+ cache_key = source_identifier.lower()
4394
+ if cache_key in self._location_cache:
4395
+ return self._location_cache[cache_key]
4396
+
4397
+ # Check cache by normalized name
4398
+ if name_normalized in self._location_cache:
4399
+ return self._location_cache[name_normalized]
4400
+
4401
+ # Check database
4402
+ if source_identifier:
4403
+ cursor = conn.execute(
4404
+ "SELECT id FROM locations WHERE source_identifier = ? AND source_id = ?",
4405
+ (source_identifier, source_id)
4406
+ )
4407
+ else:
4408
+ cursor = conn.execute(
4409
+ "SELECT id FROM locations WHERE name_normalized = ? AND source_id = ?",
4410
+ (name_normalized, source_id)
4411
+ )
4412
+
4413
+ row = cursor.fetchone()
4414
+ if row:
4415
+ location_id = row["id"]
4416
+ self._location_cache[name_normalized] = location_id
4417
+ if source_identifier:
4418
+ self._location_cache[source_identifier.lower()] = location_id
4419
+ return location_id
4420
+
4421
+ # Create new location
4422
+ parent_ids_json = json.dumps(parent_ids) if parent_ids else None
4423
+ cursor = conn.execute(
4424
+ """
4425
+ INSERT INTO locations
4426
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids)
4427
+ VALUES (?, ?, ?, ?, ?, ?, ?)
4428
+ """,
4429
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids_json)
4430
+ )
4431
+ location_id = cursor.lastrowid
4432
+ assert location_id is not None
4433
+ conn.commit()
4434
+
4435
+ self._location_cache[name_normalized] = location_id
4436
+ if source_identifier:
4437
+ self._location_cache[source_identifier.lower()] = location_id
4438
+ return location_id
4439
+
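A hedged sketch of building a small parent hierarchy with LocationsDatabase.get_or_create; the type name "state" is an assumption about the seeded location_types (unknown names fall back to the "other" type), and get_location_type_id is defined a little further down in this class.

```python
# Hedged sketch: country -> state hierarchy via parent_ids.
from statement_extractor.database.store import LocationsDatabase

locs = LocationsDatabase("entities.db")            # example path
us_id = locs.get_or_create(
    "United States",
    location_type_id=locs.get_location_type_id("country"),
    source_identifier="US",
)
ca_id = locs.get_or_create(
    "California",
    location_type_id=locs.get_location_type_id("state"),  # "state" is assumed
    parent_ids=[us_id],
)
```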
4440
+ def get_or_create_by_qid(
4441
+ self,
4442
+ name: str,
4443
+ wikidata_type_qid: int,
4444
+ source_id: int = 4,
4445
+ entity_qid: Optional[int] = None,
4446
+ source_identifier: Optional[str] = None,
4447
+ parent_ids: Optional[list[int]] = None,
4448
+ ) -> int:
4449
+ """
4450
+ Get or create a location using Wikidata P31 type QID.
4451
+
4452
+ Args:
4453
+ name: Location name
4454
+ wikidata_type_qid: Wikidata instance-of QID (e.g., 515 for city)
4455
+ source_id: FK to source_types table
4456
+ entity_qid: Wikidata QID of the entity itself
4457
+ source_identifier: Optional source-specific identifier
4458
+ parent_ids: Optional list of parent location IDs
4459
+
4460
+ Returns:
4461
+ Location ID
4462
+ """
4463
+ location_type_id = self.get_location_type_id_from_qid(wikidata_type_qid)
4464
+ return self.get_or_create(
4465
+ name=name,
4466
+ location_type_id=location_type_id,
4467
+ source_id=source_id,
4468
+ qid=entity_qid,
4469
+ source_identifier=source_identifier,
4470
+ parent_ids=parent_ids,
4471
+ )
4472
+
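A hedged follow-up using the Wikidata P31 convenience wrapper; 515 (city) matches the docstring's own example and 90 is an illustrative entity QID.

```python
# Hedged sketch: map a Wikidata instance-of class straight to a location row.
paris_id = locs.get_or_create_by_qid(
    "Paris",
    wikidata_type_qid=515,    # Q515 = city, per the docstring above
    entity_qid=90,            # illustrative entity QID (Q90)
)
```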
4473
+ def get_by_id(self, location_id: int) -> Optional[LocationRecord]:
4474
+ """Get a location record by ID."""
4475
+ conn = self._connect()
4476
+
4477
+ cursor = conn.execute(
4478
+ """
4479
+ SELECT id, qid, name, source_id, source_identifier, location_type_id,
4480
+ parent_ids, from_date, to_date, record
4481
+ FROM locations WHERE id = ?
4482
+ """,
4483
+ (location_id,)
4484
+ )
4485
+ row = cursor.fetchone()
4486
+ if row:
4487
+ source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
4488
+ location_type_id = row["location_type_id"]
4489
+ location_type_name = self._get_location_type_name(location_type_id)
4490
+ simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
4491
+ simplified_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
4492
+
4493
+ parent_ids = json.loads(row["parent_ids"]) if row["parent_ids"] else []
4494
+
4495
+ return LocationRecord(
4496
+ name=row["name"],
4497
+ source=source_name,
4498
+ source_id=row["source_identifier"],
4499
+ qid=row["qid"],
4500
+ location_type=location_type_name,
4501
+ simplified_type=SimplifiedLocationType(simplified_name),
4502
+ parent_ids=parent_ids,
4503
+ from_date=row["from_date"],
4504
+ to_date=row["to_date"],
4505
+ record=json.loads(row["record"]) if row["record"] else {},
4506
+ )
4507
+ return None
4508
+
4509
+ def _get_location_type_name(self, type_id: int) -> str:
4510
+ """Get location type name from ID."""
4511
+ # Reverse lookup in cache
4512
+ for name, id_ in self._location_type_cache.items():
4513
+ if id_ == type_id:
4514
+ return name
4515
+ return "other"
4516
+
4517
+ def get_location_type_id(self, type_name: str) -> int:
4518
+ """Get location_type_id for a type name."""
4519
+ return self._location_type_cache.get(type_name, 36) # default to "other"
4520
+
4521
+ def get_location_type_id_from_qid(self, wikidata_qid: int) -> int:
4522
+ """Get location_type_id from Wikidata P31 QID."""
4523
+ return self._location_type_qid_cache.get(wikidata_qid, 36) # default to "other"
4524
+
4525
+ def get_simplified_type(self, location_type_id: int) -> str:
4526
+ """Get simplified type name for a location_type_id."""
4527
+ simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
4528
+ return SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
4529
+
4530
+ def resolve_region_text(self, text: str) -> Optional[int]:
4531
+ """
4532
+ Resolve a region/country text to a location ID.
4533
+
4534
+ Checks the local cache first, then resolves the text via pycountry and imports the country into the locations table if it is not already present.
4535
+
4536
+ Args:
4537
+ text: Region text (country code, name, or QID)
4538
+
4539
+ Returns:
4540
+ Location ID or None if not resolved
4541
+ """
4542
+ if not text:
4543
+ return None
4544
+
4545
+ text_lower = text.lower().strip()
4546
+
4547
+ # Check cache first
4548
+ if text_lower in self._location_cache:
4549
+ return self._location_cache[text_lower]
4550
+
4551
+ # Try pycountry resolution
4552
+ alpha_2 = self._resolve_via_pycountry(text)
4553
+ if alpha_2:
4554
+ alpha_2_lower = alpha_2.lower()
4555
+ if alpha_2_lower in self._location_cache:
4556
+ location_id = self._location_cache[alpha_2_lower]
4557
+ self._location_cache[text_lower] = location_id # Cache the input too
4558
+ return location_id
4559
+
4560
+ # Country not in database yet, import it
4561
+ try:
4562
+ country = pycountry.countries.get(alpha_2=alpha_2)
4563
+ if country:
4564
+ country_type_id = self._location_type_cache.get("country", 2)
4565
+ location_id = self.get_or_create(
4566
+ name=country.name,
4567
+ location_type_id=country_type_id,
4568
+ source_id=4, # wikidata
4569
+ source_identifier=alpha_2,
4570
+ )
4571
+ self._location_cache[text_lower] = location_id
4572
+ return location_id
4573
+ except Exception:
4574
+ pass
4575
+
4576
+ return None
4577
+
4578
+ def _resolve_via_pycountry(self, region: str) -> Optional[str]:
4579
+ """Try to resolve region via pycountry."""
4580
+ region_clean = region.strip()
4581
+ if not region_clean:
4582
+ return None
4583
+
4584
+ # Try as 2-letter code
4585
+ if len(region_clean) == 2:
4586
+ country = pycountry.countries.get(alpha_2=region_clean.upper())
4587
+ if country:
4588
+ return country.alpha_2
4589
+
4590
+ # Try as 3-letter code
4591
+ if len(region_clean) == 3:
4592
+ country = pycountry.countries.get(alpha_3=region_clean.upper())
4593
+ if country:
4594
+ return country.alpha_2
4595
+
4596
+ # Try fuzzy search
4597
+ try:
4598
+ matches = pycountry.countries.search_fuzzy(region_clean)
4599
+ if matches:
4600
+ return matches[0].alpha_2
4601
+ except LookupError:
4602
+ pass
4603
+
4604
+ return None
4605
+
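The pycountry calls chained in _resolve_via_pycountry are ordinary library lookups; the sketch below shows them in isolation (results depend on the installed pycountry data).

```python
# Hedged sketch of the pycountry lookups used above.
import pycountry

print(pycountry.countries.get(alpha_2="US").name)      # "United States"
print(pycountry.countries.get(alpha_3="GBR").alpha_2)  # "GB"
try:
    match = pycountry.countries.search_fuzzy("England")[0]
    print(match.alpha_2)   # fuzzy search also walks subdivision names
except LookupError:
    pass                   # raised when nothing matches the query
```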
4606
+ def import_from_pycountry(self) -> int:
4607
+ """
4608
+ Import all countries from pycountry.
4609
+
4610
+ Returns:
4611
+ Number of locations imported
4612
+ """
4613
+ conn = self._connect()
4614
+ country_type_id = self._location_type_cache.get("country", 2)
4615
+ count = 0
4616
+
4617
+ for country in pycountry.countries:
4618
+ name = country.name
4619
+ alpha_2 = country.alpha_2
4620
+ name_normalized = name.lower()
4621
+
4622
+ # Check if already exists
4623
+ if alpha_2.lower() in self._location_cache:
4624
+ continue
4625
+
4626
+ cursor = conn.execute(
4627
+ """
4628
+ INSERT OR IGNORE INTO locations
4629
+ (name, name_normalized, source_id, source_identifier, location_type_id)
4630
+ VALUES (?, ?, 4, ?, ?)
4631
+ """,
4632
+ (name, name_normalized, alpha_2, country_type_id)
4633
+ )
4634
+
4635
+ if cursor.lastrowid:
4636
+ self._location_cache[name_normalized] = cursor.lastrowid
4637
+ self._location_cache[alpha_2.lower()] = cursor.lastrowid
4638
+ count += 1
4639
+
4640
+ conn.commit()
4641
+ logger.info(f"Imported {count} countries from pycountry")
4642
+ return count
4643
+
4644
+ def search(
4645
+ self,
4646
+ query: str,
4647
+ top_k: int = 10,
4648
+ simplified_type: Optional[str] = None,
4649
+ ) -> list[tuple[int, str, float]]:
4650
+ """
4651
+ Search for locations by name.
4652
+
4653
+ Args:
4654
+ query: Search query
4655
+ top_k: Maximum results to return
4656
+ simplified_type: Optional filter by simplified type (e.g., "country", "city")
4657
+
4658
+ Returns:
4659
+ List of (location_id, location_name, score) tuples
4660
+ """
4661
+ conn = self._connect()
4662
+ query_normalized = query.lower().strip()
4663
+
4664
+ # Build query with optional type filter
4665
+ if simplified_type:
4666
+ # Get all location_type_ids for this simplified type
4667
+ simplified_id = {
4668
+ name: id_ for id_, name in SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.items()
4669
+ }.get(simplified_type)
4670
+ if simplified_id:
4671
+ type_ids = [
4672
+ type_id for type_id, simp_id in LOCATION_TYPE_TO_SIMPLIFIED.items()
4673
+ if simp_id == simplified_id
4674
+ ]
4675
+ if type_ids:
4676
+ placeholders = ",".join("?" * len(type_ids))
4677
+ cursor = conn.execute(
4678
+ f"""
4679
+ SELECT id, name FROM locations
4680
+ WHERE name_normalized LIKE ? AND location_type_id IN ({placeholders})
4681
+ ORDER BY length(name)
4682
+ LIMIT ?
4683
+ """,
4684
+ [f"%{query_normalized}%"] + type_ids + [top_k]
4685
+ )
4686
+ else:
4687
+ return []
4688
+ else:
4689
+ return []
4690
+ else:
4691
+ cursor = conn.execute(
4692
+ """
4693
+ SELECT id, name FROM locations
4694
+ WHERE name_normalized LIKE ?
4695
+ ORDER BY length(name)
4696
+ LIMIT ?
4697
+ """,
4698
+ (f"%{query_normalized}%", top_k)
4699
+ )
4700
+
4701
+ results = []
4702
+ for row in cursor:
4703
+ name_normalized = row["name"].lower()
4704
+ if query_normalized == name_normalized:
4705
+ score = 1.0
4706
+ elif name_normalized.startswith(query_normalized):
4707
+ score = 0.9
4708
+ else:
4709
+ score = 0.7
4710
+ results.append((row["id"], row["name"], score))
4711
+
4712
+ return results
4713
+
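A brief, hedged sketch of the simplified-type filter in LocationsDatabase.search; "country" is one of the simplified type names the docstring itself cites.

```python
# Hedged sketch: restrict a name search to countries only.
for loc_id, name, score in locs.search("georgia", top_k=5, simplified_type="country"):
    print(loc_id, name, score)
```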
4714
+ def get_stats(self) -> dict[str, Any]:
4715
+ """Get statistics about the locations table."""
4716
+ conn = self._connect()
4717
+
4718
+ cursor = conn.execute("SELECT COUNT(*) FROM locations")
4719
+ total = cursor.fetchone()[0]
4720
+
4721
+ cursor = conn.execute("SELECT COUNT(*) FROM locations WHERE canon_id IS NOT NULL")
4722
+ canonicalized = cursor.fetchone()[0]
4723
+
4724
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM locations WHERE canon_id IS NOT NULL")
4725
+ groups = cursor.fetchone()[0]
4726
+
4727
+ # Count by simplified type
4728
+ by_type: dict[str, int] = {}
4729
+ cursor = conn.execute("""
4730
+ SELECT lt.simplified_id, COUNT(*) as cnt
4731
+ FROM locations l
4732
+ JOIN location_types lt ON l.location_type_id = lt.id
4733
+ GROUP BY lt.simplified_id
4734
+ """)
4735
+ for row in cursor:
4736
+ type_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(row["simplified_id"], "other")
4737
+ by_type[type_name] = row["cnt"]
4738
+
4739
+ return {
4740
+ "total_locations": total,
4741
+ "canonicalized": canonicalized,
4742
+ "canonical_groups": groups,
4743
+ "by_type": by_type,
4744
+ }
4745
+
4746
+ def insert_batch(self, records: list[LocationRecord]) -> int:
4747
+ """
4748
+ Insert a batch of location records.
4749
+
4750
+ Args:
4751
+ records: List of LocationRecord objects to insert
4752
+
4753
+ Returns:
4754
+ Number of records inserted
4755
+ """
4756
+ if not records:
4757
+ return 0
4758
+
4759
+ conn = self._connect()
4760
+ inserted = 0
4761
+
4762
+ for record in records:
4763
+ name_normalized = record.name.lower().strip()
4764
+ source_identifier = record.source_id # Q code in source_id field
4765
+
4766
+ # Check cache first
4767
+ cache_key = source_identifier.lower() if source_identifier else name_normalized
4768
+ if cache_key in self._location_cache:
4769
+ continue
4770
+
4771
+ # Get location_type_id from type name
4772
+ location_type_id = self._location_type_cache.get(record.location_type, 36) # default "other"
4773
+ source_id = SOURCE_NAME_TO_ID.get(record.source, 4) # default wikidata
4774
+
4775
+ parent_ids_json = json.dumps(record.parent_ids) if record.parent_ids else None
4776
+ record_json = json.dumps(record.record) if record.record else "{}"
4777
+
4778
+ try:
4779
+ cursor = conn.execute(
4780
+ """
4781
+ INSERT OR IGNORE INTO locations
4782
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids, record, from_date, to_date)
4783
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
4784
+ """,
4785
+ (
4786
+ record.name,
4787
+ name_normalized,
4788
+ source_id,
4789
+ source_identifier,
4790
+ record.qid,
4791
+ location_type_id,
4792
+ parent_ids_json,
4793
+ record_json,
4794
+ record.from_date,
4795
+ record.to_date,
4796
+ )
4797
+ )
4798
+ if cursor.lastrowid:
4799
+ self._location_cache[name_normalized] = cursor.lastrowid
4800
+ if source_identifier:
4801
+ self._location_cache[source_identifier.lower()] = cursor.lastrowid
4802
+ inserted += 1
4803
+ except Exception as e:
4804
+ logger.warning(f"Failed to insert location {record.name}: {e}")
4805
+
4806
+ conn.commit()
4807
+ return inserted
4808
+
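A hedged sketch for insert_batch; the LocationRecord keyword arguments are inferred from get_by_id above and from the fields insert_batch reads, so treat the exact constructor signature and defaults as assumptions.

```python
# Hedged sketch: batch-insert a single illustrative country record.
from statement_extractor.database.models import LocationRecord, SimplifiedLocationType

batch = [
    LocationRecord(
        name="Germany",
        source="wikidata",
        source_id="Q183",                            # Q code kept in source_id
        qid=183,
        location_type="country",
        simplified_type=SimplifiedLocationType("country"),
        parent_ids=[],
        from_date=None,
        to_date=None,
        record={},
    )
]
print(locs.insert_batch(batch), "new locations inserted")
```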
4809
+ def get_all_source_ids(self, source: str = "wikidata") -> set[str]:
4810
+ """
4811
+ Get all source_identifiers for a given source.
4812
+
4813
+ Args:
4814
+ source: Source name (e.g., "wikidata")
4815
+
4816
+ Returns:
4817
+ Set of source_identifiers
4818
+ """
4819
+ conn = self._connect()
4820
+ source_id = SOURCE_NAME_TO_ID.get(source, 4)
4821
+ cursor = conn.execute(
4822
+ "SELECT source_identifier FROM locations WHERE source_id = ? AND source_identifier IS NOT NULL",
4823
+ (source_id,)
4824
+ )
4825
+ return {row["source_identifier"] for row in cursor}