corp-extractor 0.9.3__py3-none-any.whl → 0.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,18 +12,41 @@ import re
  import sqlite3
  import time
  from pathlib import Path
- from typing import Iterator, Optional
+ from typing import Any, Iterator, Optional

  import numpy as np
  import pycountry
  import sqlite_vec

- from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType
+ from .models import (
+     CompanyRecord,
+     DatabaseStats,
+     EntityType,
+     LocationRecord,
+     PersonRecord,
+     PersonType,
+     RoleRecord,
+     SimplifiedLocationType,
+ )
+ from .seed_data import (
+     LOCATION_TYPE_NAME_TO_ID,
+     LOCATION_TYPE_QID_TO_ID,
+     LOCATION_TYPE_TO_SIMPLIFIED,
+     ORG_TYPE_ID_TO_NAME,
+     ORG_TYPE_NAME_TO_ID,
+     PEOPLE_TYPE_ID_TO_NAME,
+     PEOPLE_TYPE_NAME_TO_ID,
+     SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME,
+     SOURCE_ID_TO_NAME,
+     SOURCE_NAME_TO_ID,
+     seed_all_enums,
+     seed_pycountry_locations,
+ )

  logger = logging.getLogger(__name__)

  # Default database location
- DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
+ DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities-v2.db"

  # Module-level shared connections by path (both databases share the same connection)
  _shared_connections: dict[str, sqlite3.Connection] = {}
@@ -500,6 +523,7 @@ class OrganizationDatabase:
          self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
          self._embedding_dim = embedding_dim
          self._conn: Optional[sqlite3.Connection] = None
+         self._is_v2: Optional[bool] = None  # Detected on first connect

      def _ensure_dir(self) -> None:
          """Ensure database directory exists."""
@@ -512,11 +536,27 @@ class OrganizationDatabase:

          self._conn = _get_shared_connection(self._db_path, self._embedding_dim)

-         # Create tables (idempotent)
-         self._create_tables()
+         # Detect schema version BEFORE creating tables
+         # v2 has entity_type_id (FK) instead of entity_type (TEXT)
+         if self._is_v2 is None:
+             cursor = self._conn.execute("PRAGMA table_info(organizations)")
+             columns = {row["name"] for row in cursor}
+             self._is_v2 = "entity_type_id" in columns
+             if self._is_v2:
+                 logger.debug("Detected v2 schema for organizations")
+
+         # Create tables (idempotent) - only for v1 schema or fresh databases
+         # v2 databases already have their schema from migration
+         if not self._is_v2:
+             self._create_tables()

          return self._conn

+     @property
+     def _org_table(self) -> str:
+         """Return table/view name for organization queries needing text fields."""
+         return "organizations_view" if self._is_v2 else "organizations"
+
      def _create_tables(self) -> None:
          """Create database tables including sqlite-vec virtual table."""
          conn = self._conn
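The schema probe in `_connect()` is the hinge of the whole v1/v2 split: `PRAGMA table_info` returns one row per column, so the presence of `entity_type_id` by itself identifies a migrated database. A minimal standalone sketch of the same check (an in-memory table stands in for a real migrated file):

    import sqlite3

    def is_v2_schema(conn: sqlite3.Connection) -> bool:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk);
        # index 1 is the column name, so no row_factory is required.
        columns = {row[1] for row in conn.execute("PRAGMA table_info(organizations)")}
        return "entity_type_id" in columns

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE organizations (id INTEGER PRIMARY KEY, entity_type_id INTEGER)")
    print(is_v2_schema(conn))  # True

Caching the result in `self._is_v2` means the PRAGMA runs once per instance rather than once per query.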
@@ -591,7 +631,7 @@ class OrganizationDatabase:
          conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
          conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")

-         # Create sqlite-vec virtual table for embeddings
+         # Create sqlite-vec virtual table for embeddings (float32)
          # vec0 is the recommended virtual table type
          conn.execute(f"""
              CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
@@ -600,19 +640,34 @@ class OrganizationDatabase:
              )
          """)

+         # Create sqlite-vec virtual table for scalar embeddings (int8)
+         # Provides 75% storage reduction with ~92% recall at top-100
+         conn.execute(f"""
+             CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
+                 org_id INTEGER PRIMARY KEY,
+                 embedding int8[{self._embedding_dim}]
+             )
+         """)
+
          conn.commit()

      def close(self) -> None:
          """Clear connection reference (shared connection remains open)."""
          self._conn = None

-     def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
+     def insert(
+         self,
+         record: CompanyRecord,
+         embedding: np.ndarray,
+         scalar_embedding: Optional[np.ndarray] = None,
+     ) -> int:
          """
          Insert an organization record with its embedding.

          Args:
              record: Organization record to insert
-             embedding: Embedding vector for the organization name
+             embedding: Embedding vector for the organization name (float32)
+             scalar_embedding: Optional int8 scalar embedding for compact storage

          Returns:
              Row ID of inserted record
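The 75% in the new table's comment is just the ratio of element widths, not an empirical measurement; for the 768-dimensional default used by this module:

    import numpy as np

    dim = 768  # embedding_dim default (embeddinggemma-300m)
    fp32_bytes = dim * np.dtype(np.float32).itemsize  # 3072 bytes per vector
    int8_bytes = dim * np.dtype(np.int8).itemsize     # 768 bytes per vector
    print(1 - int8_bytes / fp32_bytes)                # 0.75 -> "75% storage reduction"

The ~92% recall at top-100 is a separate, presumably measured trade-off and does not follow from the arithmetic.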
@@ -642,7 +697,7 @@ class OrganizationDatabase:
          row_id = cursor.lastrowid
          assert row_id is not None

-         # Insert embedding into vec table
+         # Insert embedding into vec table (float32)
          # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
          embedding_blob = embedding.astype(np.float32).tobytes()
          conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
@@ -651,6 +706,15 @@ class OrganizationDatabase:
              VALUES (?, ?)
          """, (row_id, embedding_blob))

+         # Insert scalar embedding if provided (int8)
+         if scalar_embedding is not None:
+             scalar_blob = scalar_embedding.astype(np.int8).tobytes()
+             conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+             conn.execute("""
+                 INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                 VALUES (?, vec_int8(?))
+             """, (row_id, scalar_blob))
+
          conn.commit()
          return row_id
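A hedged usage sketch for the extended `insert()` signature; the `db`, `record`, and `embedding` names are hypothetical stand-ins, and the quantization deliberately mirrors `_quantize_query` defined later in this diff:

    import numpy as np

    vec = embedding / np.linalg.norm(embedding)  # unit-normalize, as search() does
    int8_vec = np.clip(np.round(vec * 127), -127, 127).astype(np.int8)
    row_id = db.insert(record, embedding=vec, scalar_embedding=int8_vec)

Passing `scalar_embedding=None` keeps the old float32-only behavior, so existing callers remain valid.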
 
@@ -659,14 +723,16 @@ class OrganizationDatabase:
          records: list[CompanyRecord],
          embeddings: np.ndarray,
          batch_size: int = 1000,
+         scalar_embeddings: Optional[np.ndarray] = None,
      ) -> int:
          """
          Insert multiple organization records with embeddings.

          Args:
              records: List of organization records
-             embeddings: Matrix of embeddings (N x dim)
+             embeddings: Matrix of embeddings (N x dim) - float32
              batch_size: Commit batch size
+             scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)

          Returns:
              Number of records inserted
@@ -674,25 +740,54 @@ class OrganizationDatabase:
          conn = self._connect()
          count = 0

-         for record, embedding in zip(records, embeddings):
+         for i, (record, embedding) in enumerate(zip(records, embeddings)):
              record_json = json.dumps(record.record)
              name_normalized = _normalize_name(record.name)

-             cursor = conn.execute("""
-                 INSERT OR REPLACE INTO organizations
-                 (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
-                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
-             """, (
-                 record.name,
-                 name_normalized,
-                 record.source,
-                 record.source_id,
-                 record.region,
-                 record.entity_type.value,
-                 record.from_date or "",
-                 record.to_date or "",
-                 record_json,
-             ))
+             if self._is_v2:
+                 # v2 schema: use FK IDs instead of TEXT columns
+                 source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
+                 entity_type_id = ORG_TYPE_NAME_TO_ID.get(record.entity_type.value, 17)  # 17 = unknown
+
+                 # Resolve region to location_id if provided
+                 region_id = None
+                 if record.region:
+                     # Use locations database to resolve region
+                     locations_db = get_locations_database(db_path=self._db_path)
+                     region_id = locations_db.resolve_region_text(record.region)
+
+                 cursor = conn.execute("""
+                     INSERT OR REPLACE INTO organizations
+                     (name, name_normalized, source_id, source_identifier, region_id, entity_type_id, from_date, to_date, record)
+                     VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                 """, (
+                     record.name,
+                     name_normalized,
+                     source_type_id,
+                     record.source_id,
+                     region_id,
+                     entity_type_id,
+                     record.from_date or "",
+                     record.to_date or "",
+                     record_json,
+                 ))
+             else:
+                 # v1 schema: use TEXT columns
+                 cursor = conn.execute("""
+                     INSERT OR REPLACE INTO organizations
+                     (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+                     VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                 """, (
+                     record.name,
+                     name_normalized,
+                     record.source,
+                     record.source_id,
+                     record.region,
+                     record.entity_type.value,
+                     record.from_date or "",
+                     record.to_date or "",
+                     record_json,
+                 ))

              row_id = cursor.lastrowid
              assert row_id is not None
@@ -705,6 +800,15 @@ class OrganizationDatabase:
                  VALUES (?, ?)
              """, (row_id, embedding_blob))

+             # Insert scalar embedding if provided (int8)
+             if scalar_embeddings is not None:
+                 scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
+                 conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+                 conn.execute("""
+                     INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                     VALUES (?, vec_int8(?))
+                 """, (row_id, scalar_blob))
+
              count += 1

              if count % batch_size == 0:
@@ -750,7 +854,13 @@ class OrganizationDatabase:
          if query_norm == 0:
              return []
          query_normalized = query_embedding / query_norm
-         query_blob = query_normalized.astype(np.float32).tobytes()
+
+         # Use int8 quantized query if scalar table is available (75% storage savings)
+         if self._has_scalar_table():
+             query_int8 = self._quantize_query(query_normalized)
+             query_blob = query_int8.tobytes()
+         else:
+             query_blob = query_normalized.astype(np.float32).tobytes()

          # Stage 1: Text-based pre-filtering (if query_text provided)
          candidate_ids: Optional[set[int]] = None
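`_quantize_query` (added in the next hunk) is plain symmetric int8 quantization: because the query is unit-normalized first, every component lies in [-1, 1] and maps onto [-127, 127] with at most 0.5/127 absolute error per component. A self-contained sketch:

    import numpy as np

    def quantize_query(embedding: np.ndarray) -> np.ndarray:
        # Scale [-1, 1] onto [-127, 127] and round to the nearest int8.
        return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)

    v = np.array([0.9, -0.3, 0.0], dtype=np.float32)
    q = quantize_query(v)                    # [114, -38, 0] as int8
    err = np.abs(q.astype(np.float32) / 127 - v).max()
    print(err <= 0.5 / 127)                  # True: rounding error bound holds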
@@ -980,6 +1090,19 @@ class OrganizationDatabase:
          cursor = conn.execute(query, params)
          return set(row["id"] for row in cursor)

+     def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
+         """Quantize query embedding to int8 for scalar search."""
+         return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
+
+     def _has_scalar_table(self) -> bool:
+         """Check if scalar embedding table exists."""
+         conn = self._conn
+         assert conn is not None
+         cursor = conn.execute(
+             "SELECT name FROM sqlite_master WHERE type='table' AND name='organization_embeddings_scalar'"
+         )
+         return cursor.fetchone() is not None
+
      def _vector_search_filtered(
          self,
          query_blob: bytes,
@@ -987,7 +1110,7 @@ class OrganizationDatabase:
          top_k: int,
          source_filter: Optional[str],
      ) -> list[tuple[CompanyRecord, float]]:
-         """Vector search within a filtered set of candidates."""
+         """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
          conn = self._conn
          assert conn is not None

@@ -997,18 +1120,29 @@ class OrganizationDatabase:
          # Build IN clause for candidate IDs
          placeholders = ",".join("?" * len(candidate_ids))

-         # Query sqlite-vec with KNN search, filtered by candidate IDs
-         # Using distance function - lower is more similar for L2
-         # We'll use cosine distance
-         query = f"""
-             SELECT
-                 e.org_id,
-                 vec_distance_cosine(e.embedding, ?) as distance
-             FROM organization_embeddings e
-             WHERE e.org_id IN ({placeholders})
-             ORDER BY distance
-             LIMIT ?
-         """
+         # Use scalar embedding table if available (75% storage reduction)
+         if self._has_scalar_table():
+             # Query uses int8 embeddings with vec_int8() wrapper
+             query = f"""
+                 SELECT
+                     e.org_id,
+                     vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                 FROM organization_embeddings_scalar e
+                 WHERE e.org_id IN ({placeholders})
+                 ORDER BY distance
+                 LIMIT ?
+             """
+         else:
+             # Fall back to float32 embeddings
+             query = f"""
+                 SELECT
+                     e.org_id,
+                     vec_distance_cosine(e.embedding, ?) as distance
+                 FROM organization_embeddings e
+                 WHERE e.org_id IN ({placeholders})
+                 ORDER BY distance
+                 LIMIT ?
+             """

          cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
@@ -1035,33 +1169,58 @@ class OrganizationDatabase:
          top_k: int,
          source_filter: Optional[str],
      ) -> list[tuple[CompanyRecord, float]]:
-         """Full vector search without text pre-filtering."""
+         """Full vector search without text pre-filtering using scalar (int8) embeddings."""
          conn = self._conn
          assert conn is not None

+         # Use scalar embedding table if available (75% storage reduction)
+         use_scalar = self._has_scalar_table()
+
          # KNN search with sqlite-vec
          if source_filter:
              # Need to join with organizations table for source filter
-             query = """
-                 SELECT
-                     e.org_id,
-                     vec_distance_cosine(e.embedding, ?) as distance
-                 FROM organization_embeddings e
-                 JOIN organizations c ON e.org_id = c.id
-                 WHERE c.source = ?
-                 ORDER BY distance
-                 LIMIT ?
-             """
+             if use_scalar:
+                 query = """
+                     SELECT
+                         e.org_id,
+                         vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                     FROM organization_embeddings_scalar e
+                     JOIN organizations c ON e.org_id = c.id
+                     WHERE c.source = ?
+                     ORDER BY distance
+                     LIMIT ?
+                 """
+             else:
+                 query = """
+                     SELECT
+                         e.org_id,
+                         vec_distance_cosine(e.embedding, ?) as distance
+                     FROM organization_embeddings e
+                     JOIN organizations c ON e.org_id = c.id
+                     WHERE c.source = ?
+                     ORDER BY distance
+                     LIMIT ?
+                 """
              cursor = conn.execute(query, (query_blob, source_filter, top_k))
          else:
-             query = """
-                 SELECT
-                     org_id,
-                     vec_distance_cosine(embedding, ?) as distance
-                 FROM organization_embeddings
-                 ORDER BY distance
-                 LIMIT ?
-             """
+             if use_scalar:
+                 query = """
+                     SELECT
+                         org_id,
+                         vec_distance_cosine(embedding, vec_int8(?)) as distance
+                     FROM organization_embeddings_scalar
+                     ORDER BY distance
+                     LIMIT ?
+                 """
+             else:
+                 query = """
+                     SELECT
+                         org_id,
+                         vec_distance_cosine(embedding, ?) as distance
+                     FROM organization_embeddings
+                     ORDER BY distance
+                     LIMIT ?
+                 """
              cursor = conn.execute(query, (query_blob, top_k))

          results = []
@@ -1081,10 +1240,19 @@ class OrganizationDatabase:
          conn = self._conn
          assert conn is not None

-         cursor = conn.execute("""
-             SELECT id, name, source, source_id, region, entity_type, record, canon_id
-             FROM organizations WHERE id = ?
-         """, (org_id,))
+         if self._is_v2:
+             # v2 schema: use view for text fields, but need record from base table
+             cursor = conn.execute("""
+                 SELECT v.id, v.name, v.source, v.source_identifier, v.region, v.entity_type, v.canon_id, o.record
+                 FROM organizations_view v
+                 JOIN organizations o ON v.id = o.id
+                 WHERE v.id = ?
+             """, (org_id,))
+         else:
+             cursor = conn.execute("""
+                 SELECT id, name, source, source_id, region, entity_type, record, canon_id
+                 FROM organizations WHERE id = ?
+             """, (org_id,))

          row = cursor.fetchone()
          if row:
@@ -1092,10 +1260,11 @@ class OrganizationDatabase:
              # Add db_id and canon_id to record dict for canon-aware search
              record_data["db_id"] = row["id"]
              record_data["canon_id"] = row["canon_id"]
+             source_id_field = "source_identifier" if self._is_v2 else "source_id"
              return CompanyRecord(
                  name=row["name"],
                  source=row["source"],
-                 source_id=row["source_id"],
+                 source_id=row[source_id_field],
                  region=row["region"] or "",
                  entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                  record=record_data,
@@ -1106,18 +1275,29 @@ class OrganizationDatabase:
          """Get an organization record by source and source_id."""
          conn = self._connect()

-         cursor = conn.execute("""
-             SELECT name, source, source_id, region, entity_type, record
-             FROM organizations
-             WHERE source = ? AND source_id = ?
-         """, (source, source_id))
+         if self._is_v2:
+             # v2 schema: join view with base table for record
+             source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+             cursor = conn.execute("""
+                 SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                 FROM organizations_view v
+                 JOIN organizations o ON v.id = o.id
+                 WHERE o.source_id = ? AND o.source_identifier = ?
+             """, (source_type_id, source_id))
+         else:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, region, entity_type, record
+                 FROM organizations
+                 WHERE source = ? AND source_id = ?
+             """, (source, source_id))

          row = cursor.fetchone()
          if row:
+             source_id_field = "source_identifier" if self._is_v2 else "source_id"
              return CompanyRecord(
                  name=row["name"],
                  source=row["source"],
-                 source_id=row["source_id"],
+                 source_id=row[source_id_field],
                  region=row["region"] or "",
                  entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                  record=json.loads(row["record"]),
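These branches lean on an `organizations_view` that maps the v2 FK columns back into the v1 TEXT shape; its DDL is created by the migration and never appears in this diff. A plausible reconstruction from the columns referenced here — an assumption, not the shipped definition, and `organization_types` in particular is a guessed table name:

    # Hypothetical reconstruction; source_types and locations do appear
    # elsewhere in this diff, organization_types is inferred.
    conn.execute("""
        CREATE VIEW IF NOT EXISTS organizations_view AS
        SELECT
            o.id,
            o.name,
            st.name AS source,
            o.source_identifier,
            l.name  AS region,
            ot.name AS entity_type,
            o.canon_id
        FROM organizations o
        JOIN source_types st ON o.source_id = st.id
        LEFT JOIN locations l ON o.region_id = l.id
        LEFT JOIN organization_types ot ON o.entity_type_id = ot.id
    """)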
@@ -1128,10 +1308,17 @@ class OrganizationDatabase:
          """Get the internal database ID for an organization by source and source_id."""
          conn = self._connect()

-         cursor = conn.execute("""
-             SELECT id FROM organizations
-             WHERE source = ? AND source_id = ?
-         """, (source, source_id))
+         if self._is_v2:
+             source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+             cursor = conn.execute("""
+                 SELECT id FROM organizations
+                 WHERE source_id = ? AND source_identifier = ?
+             """, (source_type_id, source_id))
+         else:
+             cursor = conn.execute("""
+                 SELECT id FROM organizations
+                 WHERE source = ? AND source_id = ?
+             """, (source, source_id))

          row = cursor.fetchone()
          if row:
@@ -1146,8 +1333,18 @@ class OrganizationDatabase:
          cursor = conn.execute("SELECT COUNT(*) FROM organizations")
          total = cursor.fetchone()[0]

-         # Count by source
-         cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
+         # Count by source - handle both v1 and v2 schema
+         if self._is_v2:
+             # v2 schema - join with source_types
+             cursor = conn.execute("""
+                 SELECT st.name as source, COUNT(*) as cnt
+                 FROM organizations o
+                 JOIN source_types st ON o.source_id = st.id
+                 GROUP BY o.source_id
+             """)
+         else:
+             # v1 schema
+             cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
          by_source = {row["source"]: row["cnt"] for row in cursor}

          # Database file size
@@ -1167,20 +1364,31 @@ class OrganizationDatabase:
          Useful for resume operations to skip already-imported records.

          Args:
-             source: Optional source filter (e.g., "wikipedia" for Wikidata orgs)
+             source: Optional source filter (e.g., "wikidata" for Wikidata orgs)

          Returns:
              Set of source_id strings (e.g., Q codes for Wikidata)
          """
          conn = self._connect()

-         if source:
-             cursor = conn.execute(
-                 "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
-                 (source,)
-             )
+         if self._is_v2:
+             id_col = "source_identifier"
+             if source:
+                 source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                 cursor = conn.execute(
+                     f"SELECT DISTINCT {id_col} FROM organizations WHERE source_id = ?",
+                     (source_type_id,)
+                 )
+             else:
+                 cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM organizations")
          else:
-             cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")
+             if source:
+                 cursor = conn.execute(
+                     "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
+                     (source,)
+                 )
+             else:
+                 cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")

          return {row[0] for row in cursor}
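The corrected docstring also points at the intended workflow: fetch the already-imported IDs once, then skip them on a re-run. A hypothetical resume loop (the batch producer is invented for illustration):

    seen = db.get_source_ids(source="wikidata")   # e.g. {"Q95", "Q380", ...}
    for record, embedding in incoming_batches():  # hypothetical producer
        if record.source_id in seen:
            continue  # already imported
        db.insert(record, embedding)
        seen.add(record.source_id)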
 
@@ -1188,27 +1396,51 @@ class OrganizationDatabase:
          """Iterate over all records, optionally filtered by source."""
          conn = self._connect()

-         if source:
-             cursor = conn.execute("""
-                 SELECT name, source, source_id, region, entity_type, record
-                 FROM organizations
-                 WHERE source = ?
-             """, (source,))
+         if self._is_v2:
+             if source:
+                 source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                 cursor = conn.execute("""
+                     SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                     FROM organizations_view v
+                     JOIN organizations o ON v.id = o.id
+                     WHERE o.source_id = ?
+                 """, (source_type_id,))
+             else:
+                 cursor = conn.execute("""
+                     SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                     FROM organizations_view v
+                     JOIN organizations o ON v.id = o.id
+                 """)
+             for row in cursor:
+                 yield CompanyRecord(
+                     name=row["name"],
+                     source=row["source"],
+                     source_id=row["source_identifier"],
+                     region=row["region"] or "",
+                     entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                     record=json.loads(row["record"]),
+                 )
          else:
-             cursor = conn.execute("""
-                 SELECT name, source, source_id, region, entity_type, record
-                 FROM organizations
-             """)
-
-         for row in cursor:
-             yield CompanyRecord(
-                 name=row["name"],
-                 source=row["source"],
-                 source_id=row["source_id"],
-                 region=row["region"] or "",
-                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
-                 record=json.loads(row["record"]),
-             )
+             if source:
+                 cursor = conn.execute("""
+                     SELECT name, source, source_id, region, entity_type, record
+                     FROM organizations
+                     WHERE source = ?
+                 """, (source,))
+             else:
+                 cursor = conn.execute("""
+                     SELECT name, source, source_id, region, entity_type, record
+                     FROM organizations
+                 """)
+             for row in cursor:
+                 yield CompanyRecord(
+                     name=row["name"],
+                     source=row["source"],
+                     source_id=row["source_id"],
+                     region=row["region"] or "",
+                     entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                     record=json.loads(row["record"]),
+                 )

      def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
          """
@@ -1252,10 +1484,18 @@ class OrganizationDatabase:
          sources: dict[int, str] = {}  # org_id -> source
          all_org_ids: list[int] = []

-         cursor = conn.execute("""
-             SELECT id, source, source_id, name, region, record
-             FROM organizations
-         """)
+         if self._is_v2:
+             cursor = conn.execute("""
+                 SELECT o.id, s.name as source, o.source_identifier as source_id, o.name, l.name as region, o.record
+                 FROM organizations o
+                 JOIN source_types s ON o.source_id = s.id
+                 LEFT JOIN locations l ON o.region_id = l.id
+             """)
+         else:
+             cursor = conn.execute("""
+                 SELECT id, source, source_id, name, region, record
+                 FROM organizations
+             """)

          count = 0
          for row in cursor:
@@ -1586,17 +1826,32 @@ class OrganizationDatabase:
          """Delete all records from a specific source."""
          conn = self._connect()

-         # First get IDs to delete from vec table
-         cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
-         ids_to_delete = [row["id"] for row in cursor]
+         if self._is_v2:
+             source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+             # First get IDs to delete from vec table
+             cursor = conn.execute("SELECT id FROM organizations WHERE source_id = ?", (source_type_id,))
+             ids_to_delete = [row["id"] for row in cursor]
+
+             # Delete from vec table
+             if ids_to_delete:
+                 placeholders = ",".join("?" * len(ids_to_delete))
+                 conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+
+             # Delete from main table
+             cursor = conn.execute("DELETE FROM organizations WHERE source_id = ?", (source_type_id,))
+         else:
+             # First get IDs to delete from vec table
+             cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
+             ids_to_delete = [row["id"] for row in cursor]
+
+             # Delete from vec table
+             if ids_to_delete:
+                 placeholders = ",".join("?" * len(ids_to_delete))
+                 conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)

-         # Delete from vec table
-         if ids_to_delete:
-             placeholders = ",".join("?" * len(ids_to_delete))
-             conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+             # Delete from main table
+             cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))

-         # Delete from main table
-         cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
          deleted = cursor.rowcount

          conn.commit()
@@ -1825,130 +2080,366 @@ class OrganizationDatabase:
          conn.commit()
          return count

-     def resolve_qid_labels(
-         self,
-         label_map: dict[str, str],
-         batch_size: int = 1000,
-     ) -> int:
+     def ensure_scalar_table_exists(self) -> None:
+         """Create scalar embedding table if it doesn't exist."""
+         conn = self._connect()
+         conn.execute(f"""
+             CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
+                 org_id INTEGER PRIMARY KEY,
+                 embedding int8[{self._embedding_dim}]
+             )
+         """)
+         conn.commit()
+         logger.info("Ensured organization_embeddings_scalar table exists")
+
+     def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
          """
-         Update organization records that have QIDs instead of labels in region field.
+         Yield batches of org IDs that have float32 but missing scalar embeddings.

          Args:
-             label_map: Mapping of QID -> label for resolution
-             batch_size: Commit batch size
+             batch_size: Number of IDs per batch

-         Returns:
-             Number of records updated
+         Yields:
+             Lists of org_ids needing scalar embeddings
          """
          conn = self._connect()

-         # Find records with QIDs in region field (starts with 'Q' followed by digits)
-         region_updates = 0
-         cursor = conn.execute("""
-             SELECT id, region FROM organizations
-             WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
-         """)
-         rows = cursor.fetchall()
+         # Ensure scalar table exists before querying
+         self.ensure_scalar_table_exists()

-         for row in rows:
-             org_id = row["id"]
-             qid = row["region"]
-             if qid in label_map:
-                 conn.execute(
-                     "UPDATE organizations SET region = ? WHERE id = ?",
-                     (label_map[qid], org_id)
-                 )
-                 region_updates += 1
+         last_id = 0
+         while True:
+             cursor = conn.execute("""
+                 SELECT e.org_id FROM organization_embeddings e
+                 LEFT JOIN organization_embeddings_scalar s ON e.org_id = s.org_id
+                 WHERE s.org_id IS NULL AND e.org_id > ?
+                 ORDER BY e.org_id
+                 LIMIT ?
+             """, (last_id, batch_size))

-                 if region_updates % batch_size == 0:
-                     conn.commit()
-                     logger.info(f"Updated {region_updates} organization region labels...")
+             rows = cursor.fetchall()
+             if not rows:
+                 break

-         conn.commit()
-         logger.info(f"Resolved QID labels: {region_updates} organization regions")
-         return region_updates
+             ids = [row["org_id"] for row in rows]
+             yield ids
+             last_id = ids[-1]
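The generator above pages with a keyset (`WHERE e.org_id > ? ORDER BY e.org_id LIMIT ?`) rather than OFFSET, so each batch costs the same however deep into the table it is, and an interrupted run can resume from the last ID seen. The same pattern in a self-contained sketch:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE items (id INTEGER PRIMARY KEY);
        INSERT INTO items (id) VALUES (1), (2), (5), (9), (12);
    """)

    last_id, batch_size = 0, 2
    while True:
        rows = conn.execute(
            "SELECT id FROM items WHERE id > ? ORDER BY id LIMIT ?",
            (last_id, batch_size),
        ).fetchall()
        if not rows:
            break
        print([r[0] for r in rows])  # [1, 2] then [5, 9] then [12]
        last_id = rows[-1][0]        # resume strictly after the last key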
 
-     def get_unresolved_qids(self) -> set[str]:
+     def get_embeddings_by_ids(self, org_ids: list[int]) -> dict[int, np.ndarray]:
          """
-         Get all QIDs that still need resolution in the organizations table.
+         Fetch float32 embeddings for given org IDs.
+
+         Args:
+             org_ids: List of organization IDs

          Returns:
-             Set of QIDs (starting with 'Q') found in region field
+             Dict mapping org_id to float32 embedding array
          """
          conn = self._connect()
-         qids: set[str] = set()

-         cursor = conn.execute("""
-             SELECT DISTINCT region FROM organizations
-             WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
-         """)
-         for row in cursor:
-             qids.add(row["region"])
+         if not org_ids:
+             return {}

-         return qids
+         placeholders = ",".join("?" * len(org_ids))
+         cursor = conn.execute(f"""
+             SELECT org_id, embedding FROM organization_embeddings
+             WHERE org_id IN ({placeholders})
+         """, org_ids)

+         result = {}
+         for row in cursor:
+             embedding_blob = row["embedding"]
+             embedding = np.frombuffer(embedding_blob, dtype=np.float32)
+             result[row["org_id"]] = embedding
+         return result

- def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
-     """
-     Get a singleton PersonDatabase instance for the given path.
+     def insert_scalar_embeddings_batch(self, org_ids: list[int], embeddings: np.ndarray) -> int:
+         """
+         Insert scalar (int8) embeddings for existing orgs.

-     Args:
-         db_path: Path to database file
-         embedding_dim: Dimension of embeddings
+         Args:
+             org_ids: List of organization IDs
+             embeddings: Matrix of int8 embeddings (N x dim)

-     Returns:
-         Shared PersonDatabase instance
-     """
-     path_key = str(db_path or DEFAULT_DB_PATH)
-     if path_key not in _person_database_instances:
-         logger.debug(f"Creating new PersonDatabase instance for {path_key}")
-         _person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
-     return _person_database_instances[path_key]
+         Returns:
+             Number of embeddings inserted
+         """
+         conn = self._connect()
+         count = 0

+         for org_id, embedding in zip(org_ids, embeddings):
+             scalar_blob = embedding.astype(np.int8).tobytes()
+             conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
+             conn.execute("""
+                 INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                 VALUES (?, vec_int8(?))
+             """, (org_id, scalar_blob))
+             count += 1

- class PersonDatabase:
-     """
-     SQLite database with sqlite-vec for person vector search.
+         conn.commit()
+         return count

-     Uses hybrid text + vector search:
-     1. Text filtering with LIKE to reduce candidates
-     2. sqlite-vec for semantic similarity ranking
+     def get_scalar_embedding_count(self) -> int:
+         """Get count of scalar embeddings."""
+         conn = self._connect()
+         if not self._has_scalar_table():
+             return 0
+         cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings_scalar")
+         return cursor.fetchone()[0]

-     Stores people from sources like Wikidata with role/org context.
-     """
+     def get_float32_embedding_count(self) -> int:
+         """Get count of float32 embeddings."""
+         conn = self._connect()
+         cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
+         return cursor.fetchone()[0]

-     def __init__(
-         self,
-         db_path: Optional[str | Path] = None,
-         embedding_dim: int = 768,  # Default for embeddinggemma-300m
-     ):
+     def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
          """
-         Initialize the person database.
+         Yield batches of (org_id, name) tuples for records missing both float32 and scalar embeddings.

          Args:
-             db_path: Path to database file (creates if not exists)
-             embedding_dim: Dimension of embeddings to store
-         """
-         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
-         self._embedding_dim = embedding_dim
-         self._conn: Optional[sqlite3.Connection] = None
+             batch_size: Number of IDs per batch

-     def _ensure_dir(self) -> None:
-         """Ensure database directory exists."""
-         self._db_path.parent.mkdir(parents=True, exist_ok=True)
+         Yields:
+             Lists of (org_id, name) tuples needing embeddings generated from scratch
+         """
+         conn = self._connect()

-     def _connect(self) -> sqlite3.Connection:
-         """Get or create database connection using shared connection pool."""
-         if self._conn is not None:
-             return self._conn
+         # Ensure scalar table exists
+         self.ensure_scalar_table_exists()

-         self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
+         last_id = 0
+         while True:
+             cursor = conn.execute("""
+                 SELECT o.id, o.name FROM organizations o
+                 LEFT JOIN organization_embeddings e ON o.id = e.org_id
+                 WHERE e.org_id IS NULL AND o.id > ?
+                 ORDER BY o.id
+                 LIMIT ?
+             """, (last_id, batch_size))

-         # Create tables (idempotent)
-         self._create_tables()
+             rows = cursor.fetchall()
+             if not rows:
+                 break
+
+             results = [(row["id"], row["name"]) for row in rows]
+             yield results
+             last_id = results[-1][0]
+
+     def insert_both_embeddings_batch(
+         self,
+         org_ids: list[int],
+         fp32_embeddings: np.ndarray,
+         int8_embeddings: np.ndarray,
+     ) -> int:
+         """
+         Insert both float32 and int8 embeddings for existing orgs.
+
+         Args:
+             org_ids: List of organization IDs
+             fp32_embeddings: Matrix of float32 embeddings (N x dim)
+             int8_embeddings: Matrix of int8 embeddings (N x dim)
+
+         Returns:
+             Number of embeddings inserted
+         """
+         conn = self._connect()
+         count = 0
+
+         for org_id, fp32, int8 in zip(org_ids, fp32_embeddings, int8_embeddings):
+             # Insert float32
+             fp32_blob = fp32.astype(np.float32).tobytes()
+             conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
+             conn.execute("""
+                 INSERT INTO organization_embeddings (org_id, embedding)
+                 VALUES (?, ?)
+             """, (org_id, fp32_blob))
+
+             # Insert int8
+             int8_blob = int8.astype(np.int8).tobytes()
+             conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
+             conn.execute("""
+                 INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                 VALUES (?, vec_int8(?))
+             """, (org_id, int8_blob))
+
+             count += 1
+
+         conn.commit()
+         return count
+
+     def resolve_qid_labels(
+         self,
+         label_map: dict[str, str],
+         batch_size: int = 1000,
+     ) -> tuple[int, int]:
+         """
+         Update organization records that have QIDs instead of labels in region field.
+
+         If resolving would create a duplicate of an existing record with
+         resolved labels, the QID version is deleted instead.
+
+         Args:
+             label_map: Mapping of QID -> label for resolution
+             batch_size: Commit batch size
+
+         Returns:
+             Tuple of (records updated, duplicates deleted)
+         """
+         conn = self._connect()
+
+         # Find records with QIDs in region field (starts with 'Q' followed by digits)
+         region_updates = 0
+         cursor = conn.execute("""
+             SELECT id, region FROM organizations
+             WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+         """)
+         rows = cursor.fetchall()
+
+         duplicates_deleted = 0
+         for row in rows:
+             org_id = row["id"]
+             qid = row["region"]
+             if qid in label_map:
+                 resolved_region = label_map[qid]
+                 # Check if this update would create a duplicate
+                 # Get the name and source of the current record
+                 org_cursor = conn.execute(
+                     "SELECT name, source FROM organizations WHERE id = ?",
+                     (org_id,)
+                 )
+                 org_row = org_cursor.fetchone()
+                 if org_row is None:
+                     continue
+
+                 org_name = org_row["name"]
+                 org_source = org_row["source"]
+
+                 # Check if a record with the resolved region already exists
+                 existing_cursor = conn.execute(
+                     "SELECT id FROM organizations WHERE name = ? AND region = ? AND source = ? AND id != ?",
+                     (org_name, resolved_region, org_source, org_id)
+                 )
+                 existing = existing_cursor.fetchone()
+
+                 if existing is not None:
+                     # Duplicate would be created - delete the QID-based record
+                     conn.execute("DELETE FROM organizations WHERE id = ?", (org_id,))
+                     duplicates_deleted += 1
+                 else:
+                     # Safe to update
+                     conn.execute(
+                         "UPDATE organizations SET region = ? WHERE id = ?",
+                         (resolved_region, org_id)
+                     )
+                     region_updates += 1
+
+                 if (region_updates + duplicates_deleted) % batch_size == 0:
+                     conn.commit()
+                     logger.info(f"Resolved QID labels: {region_updates} updates, {duplicates_deleted} deletes...")
+
+         conn.commit()
+         logger.info(f"Resolved QID labels: {region_updates} organization regions, {duplicates_deleted} duplicates deleted")
+         return region_updates, duplicates_deleted
+
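Note the changed contract: `resolve_qid_labels` now returns an `(updated, deleted)` tuple instead of the single int it returned in 0.9.3, so callers that used the old count need a small adjustment. A hypothetical call site:

    label_map = {"Q30": "United States", "Q183": "Germany"}
    updated, deleted = db.resolve_qid_labels(label_map)
    print(f"regions relabeled: {updated}, duplicate QID rows dropped: {deleted}")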
+     def get_unresolved_qids(self) -> set[str]:
+         """
+         Get all QIDs that still need resolution in the organizations table.
+
+         Returns:
+             Set of QIDs (starting with 'Q') found in region field
+         """
+         conn = self._connect()
+         qids: set[str] = set()
+
+         cursor = conn.execute("""
+             SELECT DISTINCT region FROM organizations
+             WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+         """)
+         for row in cursor:
+             qids.add(row["region"])
+
+         return qids
+
+
+ def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
+     """
+     Get a singleton PersonDatabase instance for the given path.
+
+     Args:
+         db_path: Path to database file
+         embedding_dim: Dimension of embeddings
+
+     Returns:
+         Shared PersonDatabase instance
+     """
+     path_key = str(db_path or DEFAULT_DB_PATH)
+     if path_key not in _person_database_instances:
+         logger.debug(f"Creating new PersonDatabase instance for {path_key}")
+         _person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
+     return _person_database_instances[path_key]
+
+
+ class PersonDatabase:
+     """
+     SQLite database with sqlite-vec for person vector search.
+
+     Uses hybrid text + vector search:
+     1. Text filtering with LIKE to reduce candidates
+     2. sqlite-vec for semantic similarity ranking
+
+     Stores people from sources like Wikidata with role/org context.
+     """
+
+     def __init__(
+         self,
+         db_path: Optional[str | Path] = None,
+         embedding_dim: int = 768,  # Default for embeddinggemma-300m
+     ):
+         """
+         Initialize the person database.
+
+         Args:
+             db_path: Path to database file (creates if not exists)
+             embedding_dim: Dimension of embeddings to store
+         """
+         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+         self._embedding_dim = embedding_dim
+         self._conn: Optional[sqlite3.Connection] = None
+         self._is_v2: Optional[bool] = None  # Detected on first connect
+
+     def _ensure_dir(self) -> None:
+         """Ensure database directory exists."""
+         self._db_path.parent.mkdir(parents=True, exist_ok=True)
+
+     def _connect(self) -> sqlite3.Connection:
+         """Get or create database connection using shared connection pool."""
+         if self._conn is not None:
+             return self._conn
+
+         self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
+
+         # Detect schema version BEFORE creating tables
+         # v2 has person_type_id (FK) instead of person_type (TEXT)
+         if self._is_v2 is None:
+             cursor = self._conn.execute("PRAGMA table_info(people)")
+             columns = {row["name"] for row in cursor}
+             self._is_v2 = "person_type_id" in columns
+             if self._is_v2:
+                 logger.debug("Detected v2 schema for people")
+
+         # Create tables (idempotent) - only for v1 schema or fresh databases
+         # v2 databases already have their schema from migration
+         if not self._is_v2:
+             self._create_tables()

          return self._conn

+     @property
+     def _people_table(self) -> str:
+         """Return table/view name for people queries needing text fields."""
+         return "people_view" if self._is_v2 else "people"
+
      def _create_tables(self) -> None:
          """Create database tables including sqlite-vec virtual table."""
          conn = self._conn
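Taken together, the new helpers allow a database created before 0.9.4 to backfill its scalar table without re-embedding anything. A hedged sketch of that workflow — the loop is assumed rather than shipped CLI behavior, and it presumes the stored float32 vectors are already unit-normalized, as the search path expects:

    import numpy as np

    db.ensure_scalar_table_exists()
    for org_ids in db.get_missing_scalar_embedding_ids(batch_size=1000):
        fp32 = db.get_embeddings_by_ids(org_ids)  # {org_id: float32 vector}
        ids = list(fp32)
        int8 = np.stack([
            np.clip(np.round(fp32[i] * 127), -127, 127).astype(np.int8)
            for i in ids
        ])
        db.insert_scalar_embeddings_batch(ids, int8)
    print(db.get_scalar_embedding_count(), "/", db.get_float32_embedding_count())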
@@ -2051,7 +2542,7 @@ class PersonDatabase:
          except sqlite3.OperationalError:
              pass  # Column doesn't exist yet

-         # Create sqlite-vec virtual table for embeddings
+         # Create sqlite-vec virtual table for embeddings (float32)
          conn.execute(f"""
              CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
                  person_id INTEGER PRIMARY KEY,
@@ -2059,6 +2550,15 @@ class PersonDatabase:
              )
          """)

+         # Create sqlite-vec virtual table for scalar embeddings (int8)
+         # Provides 75% storage reduction with ~92% recall at top-100
+         conn.execute(f"""
+             CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
+                 person_id INTEGER PRIMARY KEY,
+                 embedding int8[{self._embedding_dim}]
+             )
+         """)
+
          # Create QID labels lookup table for Wikidata QID -> label mappings
          conn.execute("""
              CREATE TABLE IF NOT EXISTS qid_labels (
@@ -2148,13 +2648,19 @@ class PersonDatabase:
          """Clear connection reference (shared connection remains open)."""
          self._conn = None

-     def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
+     def insert(
+         self,
+         record: PersonRecord,
+         embedding: np.ndarray,
+         scalar_embedding: Optional[np.ndarray] = None,
+     ) -> int:
          """
          Insert a person record with its embedding.

          Args:
              record: Person record to insert
-             embedding: Embedding vector for the person name
+             embedding: Embedding vector for the person name (float32)
+             scalar_embedding: Optional int8 scalar embedding for compact storage

          Returns:
              Row ID of inserted record
@@ -2191,7 +2697,7 @@ class PersonDatabase:
          row_id = cursor.lastrowid
          assert row_id is not None

-         # Insert embedding into vec table (delete first since sqlite-vec doesn't support REPLACE)
+         # Insert embedding into vec table (float32)
          embedding_blob = embedding.astype(np.float32).tobytes()
          conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
          conn.execute("""
@@ -2199,6 +2705,15 @@ class PersonDatabase:
              VALUES (?, ?)
          """, (row_id, embedding_blob))

+         # Insert scalar embedding if provided (int8)
+         if scalar_embedding is not None:
+             scalar_blob = scalar_embedding.astype(np.int8).tobytes()
+             conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
+             conn.execute("""
+                 INSERT INTO person_embeddings_scalar (person_id, embedding)
+                 VALUES (?, vec_int8(?))
+             """, (row_id, scalar_blob))
+
          conn.commit()
          return row_id
@@ -2207,14 +2722,16 @@ class PersonDatabase:
          records: list[PersonRecord],
          embeddings: np.ndarray,
          batch_size: int = 1000,
+         scalar_embeddings: Optional[np.ndarray] = None,
      ) -> int:
          """
          Insert multiple person records with embeddings.

          Args:
              records: List of person records
-             embeddings: Matrix of embeddings (N x dim)
+             embeddings: Matrix of embeddings (N x dim) - float32
              batch_size: Commit batch size
+             scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)

          Returns:
              Number of records inserted
@@ -2222,32 +2739,66 @@ class PersonDatabase:
          conn = self._connect()
          count = 0

-         for record, embedding in zip(records, embeddings):
+         for i, (record, embedding) in enumerate(zip(records, embeddings)):
              record_json = json.dumps(record.record)
              name_normalized = _normalize_person_name(record.name)

-             cursor = conn.execute("""
-                 INSERT OR REPLACE INTO people
-                 (name, name_normalized, source, source_id, country, person_type,
-                  known_for_role, known_for_org, known_for_org_id, from_date, to_date,
-                  birth_date, death_date, record)
-                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-             """, (
-                 record.name,
-                 name_normalized,
-                 record.source,
-                 record.source_id,
-                 record.country,
-                 record.person_type.value,
-                 record.known_for_role,
-                 record.known_for_org,
-                 record.known_for_org_id,  # Can be None
-                 record.from_date or "",
-                 record.to_date or "",
-                 record.birth_date or "",
-                 record.death_date or "",
-                 record_json,
-             ))
+             if self._is_v2:
+                 # v2 schema: use FK IDs instead of TEXT columns
+                 source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
+                 person_type_id = PEOPLE_TYPE_NAME_TO_ID.get(record.person_type.value, 15)  # 15 = unknown
+
+                 # Resolve country to location_id if provided
+                 country_id = None
+                 if record.country:
+                     locations_db = get_locations_database(db_path=self._db_path)
+                     country_id = locations_db.resolve_region_text(record.country)
+
+                 cursor = conn.execute("""
+                     INSERT OR REPLACE INTO people
+                     (name, name_normalized, source_id, source_identifier, country_id, person_type_id,
+                      known_for_org, known_for_org_id, from_date, to_date,
+                      birth_date, death_date, record)
+                     VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                 """, (
+                     record.name,
+                     name_normalized,
+                     source_type_id,
+                     record.source_id,
+                     country_id,
+                     person_type_id,
+                     record.known_for_org,
+                     record.known_for_org_id,  # Can be None
+                     record.from_date or "",
+                     record.to_date or "",
+                     record.birth_date or "",
+                     record.death_date or "",
+                     record_json,
+                 ))
+             else:
+                 # v1 schema: use TEXT columns
+                 cursor = conn.execute("""
+                     INSERT OR REPLACE INTO people
+                     (name, name_normalized, source, source_id, country, person_type,
+                      known_for_role, known_for_org, known_for_org_id, from_date, to_date,
+                      birth_date, death_date, record)
+                     VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                 """, (
+                     record.name,
+                     name_normalized,
+                     record.source,
+                     record.source_id,
+                     record.country,
+                     record.person_type.value,
+                     record.known_for_role,
+                     record.known_for_org,
+                     record.known_for_org_id,  # Can be None
+                     record.from_date or "",
+                     record.to_date or "",
+                     record.birth_date or "",
+                     record.death_date or "",
+                     record_json,
+                 ))

              row_id = cursor.lastrowid
              assert row_id is not None
@@ -2260,6 +2811,15 @@ class PersonDatabase:
                  VALUES (?, ?)
              """, (row_id, embedding_blob))

+             # Insert scalar embedding if provided (int8)
+             if scalar_embeddings is not None:
+                 scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
+                 conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
+                 conn.execute("""
+                     INSERT INTO person_embeddings_scalar (person_id, embedding)
+                     VALUES (?, vec_int8(?))
+                 """, (row_id, scalar_blob))
+
              count += 1

              if count % batch_size == 0:
@@ -2284,10 +2844,17 @@ class PersonDatabase:
          """
          conn = self._connect()

-         cursor = conn.execute("""
-             UPDATE people SET from_date = ?, to_date = ?
-             WHERE source = ? AND source_id = ?
-         """, (from_date or "", to_date or "", source, source_id))
+         if self._is_v2:
+             source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+             cursor = conn.execute("""
+                 UPDATE people SET from_date = ?, to_date = ?
+                 WHERE source_id = ? AND source_identifier = ?
+             """, (from_date or "", to_date or "", source_type_id, source_id))
+         else:
+             cursor = conn.execute("""
+                 UPDATE people SET from_date = ?, to_date = ?
+                 WHERE source = ? AND source_id = ?
+             """, (from_date or "", to_date or "", source, source_id))

          conn.commit()
          return cursor.rowcount > 0
@@ -2322,10 +2889,17 @@ class PersonDatabase:
          conn = self._connect()

          # First get the person's internal ID
-         row = conn.execute(
-             "SELECT id FROM people WHERE source = ? AND source_id = ?",
-             (source, source_id)
-         ).fetchone()
+         if self._is_v2:
+             source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+             row = conn.execute(
+                 "SELECT id FROM people WHERE source_id = ? AND source_identifier = ?",
+                 (source_type_id, source_id)
+             ).fetchone()
+         else:
+             row = conn.execute(
+                 "SELECT id FROM people WHERE source = ? AND source_id = ?",
+                 (source, source_id)
+             ).fetchone()

          if not row:
              return False
@@ -2382,7 +2956,13 @@ class PersonDatabase:
          if query_norm == 0:
              return []
          query_normalized = query_embedding / query_norm
-         query_blob = query_normalized.astype(np.float32).tobytes()
+
+         # Use int8 quantized query if scalar table is available (75% storage savings)
+         if self._has_scalar_table():
+             query_int8 = self._quantize_query(query_normalized)
+             query_blob = query_int8.tobytes()
+         else:
+             query_blob = query_normalized.astype(np.float32).tobytes()

          # Stage 1: Text-based pre-filtering (if query_text provided)
          candidate_ids: Optional[set[int]] = None
@@ -2453,13 +3033,26 @@ class PersonDatabase:
          cursor = conn.execute(query, params)
          return set(row["id"] for row in cursor)

+     def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
+         """Quantize query embedding to int8 for scalar search."""
+         return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
+
+     def _has_scalar_table(self) -> bool:
+         """Check if scalar embedding table exists."""
+         conn = self._conn
+         assert conn is not None
+         cursor = conn.execute(
+             "SELECT name FROM sqlite_master WHERE type='table' AND name='person_embeddings_scalar'"
+         )
+         return cursor.fetchone() is not None
+
      def _vector_search_filtered(
          self,
          query_blob: bytes,
          candidate_ids: set[int],
          top_k: int,
      ) -> list[tuple[PersonRecord, float]]:
-         """Vector search within a filtered set of candidates."""
+         """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
          conn = self._conn
          assert conn is not None
@@ -2469,15 +3062,27 @@ class PersonDatabase:
          # Build IN clause for candidate IDs
          placeholders = ",".join("?" * len(candidate_ids))

-         query = f"""
-             SELECT
-                 e.person_id,
-                 vec_distance_cosine(e.embedding, ?) as distance
-             FROM person_embeddings e
-             WHERE e.person_id IN ({placeholders})
-             ORDER BY distance
-             LIMIT ?
-         """
+         # Use scalar embedding table if available (75% storage reduction)
+         if self._has_scalar_table():
+             query = f"""
+                 SELECT
+                     e.person_id,
+                     vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                 FROM person_embeddings_scalar e
+                 WHERE e.person_id IN ({placeholders})
+                 ORDER BY distance
+                 LIMIT ?
+             """
+         else:
+             query = f"""
+                 SELECT
+                     e.person_id,
+                     vec_distance_cosine(e.embedding, ?) as distance
+                 FROM person_embeddings e
+                 WHERE e.person_id IN ({placeholders})
+                 ORDER BY distance
+                 LIMIT ?
+             """

          cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])

@@ -2500,18 +3105,29 @@ class PersonDatabase:
          query_blob: bytes,
          top_k: int,
      ) -> list[tuple[PersonRecord, float]]:
-         """Full vector search without text pre-filtering."""
+         """Full vector search without text pre-filtering using scalar (int8) embeddings."""
          conn = self._conn
          assert conn is not None

-         query = """
-             SELECT
-                 person_id,
-                 vec_distance_cosine(embedding, ?) as distance
-             FROM person_embeddings
-             ORDER BY distance
-             LIMIT ?
-         """
+         # Use scalar embedding table if available (75% storage reduction)
+         if self._has_scalar_table():
+             query = """
+                 SELECT
+                     person_id,
+                     vec_distance_cosine(embedding, vec_int8(?)) as distance
+                 FROM person_embeddings_scalar
+                 ORDER BY distance
+                 LIMIT ?
+             """
+         else:
+             query = """
+                 SELECT
+                     person_id,
+                     vec_distance_cosine(embedding, ?) as distance
+                 FROM person_embeddings
+                 ORDER BY distance
+                 LIMIT ?
+             """
          cursor = conn.execute(query, (query_blob, top_k))

          results = []
@@ -2531,17 +3147,29 @@ class PersonDatabase:
          conn = self._conn
          assert conn is not None

-         cursor = conn.execute("""
-             SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
-             FROM people WHERE id = ?
-         """, (person_id,))
+         if self._is_v2:
+             # v2 schema: join view with base table for record
+             cursor = conn.execute("""
+                 SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                        v.known_for_role, v.known_for_org, v.known_for_org_id,
+                        v.birth_date, v.death_date, p.record
+                 FROM people_view v
+                 JOIN people p ON v.id = p.id
+                 WHERE v.id = ?
+             """, (person_id,))
+         else:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                 FROM people WHERE id = ?
+             """, (person_id,))

          row = cursor.fetchone()
          if row:
+             source_id_field = "source_identifier" if self._is_v2 else "source_id"
              return PersonRecord(
                  name=row["name"],
                  source=row["source"],
-                 source_id=row["source_id"],
+                 source_id=row[source_id_field],
                  country=row["country"] or "",
                  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                  known_for_role=row["known_for_role"] or "",
@@ -2557,18 +3185,30 @@ class PersonDatabase:
2557
3185
  """Get a person record by source and source_id."""
2558
3186
  conn = self._connect()
2559
3187
 
2560
- cursor = conn.execute("""
2561
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
2562
- FROM people
2563
- WHERE source = ? AND source_id = ?
2564
- """, (source, source_id))
3188
+ if self._is_v2:
3189
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)  # default wikidata
3190
+ cursor = conn.execute("""
3191
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3192
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3193
+ v.birth_date, v.death_date, p.record
3194
+ FROM people_view v
3195
+ JOIN people p ON v.id = p.id
3196
+ WHERE p.source_id = ? AND p.source_identifier = ?
3197
+ """, (source_type_id, source_id))
3198
+ else:
3199
+ cursor = conn.execute("""
3200
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3201
+ FROM people
3202
+ WHERE source = ? AND source_id = ?
3203
+ """, (source, source_id))
2565
3204
 
2566
3205
  row = cursor.fetchone()
2567
3206
  if row:
3207
+ source_id_field = "source_identifier" if self._is_v2 else "source_id"
2568
3208
  return PersonRecord(
2569
3209
  name=row["name"],
2570
3210
  source=row["source"],
2571
- source_id=row["source_id"],
3211
+ source_id=row[source_id_field],
2572
3212
  country=row["country"] or "",
2573
3213
  person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
2574
3214
  known_for_role=row["known_for_role"] or "",
@@ -2588,12 +3228,32 @@ class PersonDatabase:
2588
3228
  cursor = conn.execute("SELECT COUNT(*) FROM people")
2589
3229
  total = cursor.fetchone()[0]
2590
3230
 
2591
- # Count by person_type
2592
- cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
3231
+ # Count by person_type - handle both v1 and v2 schema
3232
+ if self._is_v2:
3233
+ # v2 schema - join with people_types
3234
+ cursor = conn.execute("""
3235
+ SELECT pt.name as person_type, COUNT(*) as cnt
3236
+ FROM people p
3237
+ JOIN people_types pt ON p.person_type_id = pt.id
3238
+ GROUP BY p.person_type_id
3239
+ """)
3240
+ else:
3241
+ # v1 schema
3242
+ cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
2593
3243
  by_type = {row["person_type"]: row["cnt"] for row in cursor}
2594
3244
 
2595
- # Count by source
2596
- cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
3245
+ # Count by source - handle both v1 and v2 schema
3246
+ if self._is_v2:
3247
+ # v2 schema - join with source_types
3248
+ cursor = conn.execute("""
3249
+ SELECT st.name as source, COUNT(*) as cnt
3250
+ FROM people p
3251
+ JOIN source_types st ON p.source_id = st.id
3252
+ GROUP BY p.source_id
3253
+ """)
3254
+ else:
3255
+ # v1 schema
3256
+ cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
2597
3257
  by_source = {row["source"]: row["cnt"] for row in cursor}
2598
3258
 
2599
3259
  return {
@@ -2602,60 +3262,293 @@ class PersonDatabase:
2602
3262
  "by_source": by_source,
2603
3263
  }
2604
3264
 
2605
- def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
2606
- """
2607
- Get all source_ids from the people table.
3265
+ def ensure_scalar_table_exists(self) -> None:
3266
+ """Create scalar embedding table if it doesn't exist."""
3267
+ conn = self._connect()
3268
+ conn.execute(f"""
3269
+ CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
3270
+ person_id INTEGER PRIMARY KEY,
3271
+ embedding int8[{self._embedding_dim}]
3272
+ )
3273
+ """)
3274
+ conn.commit()
3275
+ logger.info("Ensured person_embeddings_scalar table exists")
2608
3276
 
2609
- Useful for resume operations to skip already-imported records.
3277
+ def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
3278
+ """
3279
+ Yield batches of person IDs that have a float32 embedding but are missing the scalar copy.
2610
3280
 
2611
3281
  Args:
2612
- source: Optional source filter (e.g., "wikidata")
3282
+ batch_size: Number of IDs per batch
2613
3283
 
2614
- Returns:
2615
- Set of source_id strings (e.g., Q codes for Wikidata)
3284
+ Yields:
3285
+ Lists of person_ids needing scalar embeddings
2616
3286
  """
2617
3287
  conn = self._connect()
2618
3288
 
2619
- if source:
2620
- cursor = conn.execute(
2621
- "SELECT DISTINCT source_id FROM people WHERE source = ?",
2622
- (source,)
2623
- )
2624
- else:
2625
- cursor = conn.execute("SELECT DISTINCT source_id FROM people")
2626
-
2627
- return {row[0] for row in cursor}
3289
+ # Ensure scalar table exists before querying
3290
+ self.ensure_scalar_table_exists()
2628
3291
 
2629
- def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
2630
- """Iterate over all person records, optionally filtered by source."""
3292
+ last_id = 0
3293
+ while True:
3294
+ cursor = conn.execute("""
3295
+ SELECT e.person_id FROM person_embeddings e
3296
+ LEFT JOIN person_embeddings_scalar s ON e.person_id = s.person_id
3297
+ WHERE s.person_id IS NULL AND e.person_id > ?
3298
+ ORDER BY e.person_id
3299
+ LIMIT ?
3300
+ """, (last_id, batch_size))
3301
+
3302
+ rows = cursor.fetchall()
3303
+ if not rows:
3304
+ break
3305
+
3306
+ ids = [row["person_id"] for row in rows]
3307
+ yield ids
3308
+ last_id = ids[-1]
3309
+
3310
+ def get_embeddings_by_ids(self, person_ids: list[int]) -> dict[int, np.ndarray]:
3311
+ """
3312
+ Fetch float32 embeddings for given person IDs.
3313
+
3314
+ Args:
3315
+ person_ids: List of person IDs
3316
+
3317
+ Returns:
3318
+ Dict mapping person_id to float32 embedding array
3319
+ """
2631
3320
  conn = self._connect()
2632
3321
 
2633
- if source:
3322
+ if not person_ids:
3323
+ return {}
3324
+
3325
+ placeholders = ",".join("?" * len(person_ids))
3326
+ cursor = conn.execute(f"""
3327
+ SELECT person_id, embedding FROM person_embeddings
3328
+ WHERE person_id IN ({placeholders})
3329
+ """, person_ids)
3330
+
3331
+ result = {}
3332
+ for row in cursor:
3333
+ embedding_blob = row["embedding"]
3334
+ embedding = np.frombuffer(embedding_blob, dtype=np.float32)
3335
+ result[row["person_id"]] = embedding
3336
+ return result
3337
+
3338
+ def insert_scalar_embeddings_batch(self, person_ids: list[int], embeddings: np.ndarray) -> int:
3339
+ """
3340
+ Insert scalar (int8) embeddings for existing people.
3341
+
3342
+ Args:
3343
+ person_ids: List of person IDs
3344
+ embeddings: Matrix of int8 embeddings (N x dim)
3345
+
3346
+ Returns:
3347
+ Number of embeddings inserted
3348
+ """
3349
+ conn = self._connect()
3350
+ count = 0
3351
+
3352
+ for person_id, embedding in zip(person_ids, embeddings):
3353
+ scalar_blob = embedding.astype(np.int8).tobytes()
3354
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
3355
+ conn.execute("""
3356
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
3357
+ VALUES (?, vec_int8(?))
3358
+ """, (person_id, scalar_blob))
3359
+ count += 1
3360
+
3361
+ conn.commit()
3362
+ return count
3363
+
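Together with get_missing_scalar_embedding_ids and get_embeddings_by_ids, this supports a batched backfill of the int8 table. A sketch of the driving loop, reusing the assumed clip-and-scale quantization from earlier (the helper logic is not part of this module):

    import numpy as np

    def backfill_scalar_embeddings(db) -> int:
        # db: a PersonDatabase. Walk batches of IDs that have a float32
        # embedding but no scalar copy, quantize them, insert the copies.
        total = 0
        for ids in db.get_missing_scalar_embedding_ids(batch_size=1000):
            fp32 = db.get_embeddings_by_ids(ids)
            present = [pid for pid in ids if pid in fp32]
            if not present:
                continue
            int8 = np.stack(
                [np.clip(fp32[pid], -1.0, 1.0) * 127.0 for pid in present]
            ).round().astype(np.int8)
            total += db.insert_scalar_embeddings_batch(present, int8)
        return total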
3364
+ def get_scalar_embedding_count(self) -> int:
3365
+ """Get count of scalar embeddings."""
3366
+ conn = self._connect()
3367
+ if not self._has_scalar_table():
3368
+ return 0
3369
+ cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings_scalar")
3370
+ return cursor.fetchone()[0]
3371
+
3372
+ def get_float32_embedding_count(self) -> int:
3373
+ """Get count of float32 embeddings."""
3374
+ conn = self._connect()
3375
+ cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings")
3376
+ return cursor.fetchone()[0]
3377
+
3378
+ def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
3379
+ """
3380
+ Yield batches of (person_id, name) tuples for records with no float32 embedding at all, so both representations must be generated.
3381
+
3382
+ Args:
3383
+ batch_size: Number of IDs per batch
3384
+
3385
+ Yields:
3386
+ Lists of (person_id, name) tuples needing embeddings generated from scratch
3387
+ """
3388
+ conn = self._connect()
3389
+
3390
+ # Ensure scalar table exists
3391
+ self.ensure_scalar_table_exists()
3392
+
3393
+ last_id = 0
3394
+ while True:
2634
3395
  cursor = conn.execute("""
2635
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
2636
- FROM people
2637
- WHERE source = ?
2638
- """, (source,))
3396
+ SELECT p.id, p.name FROM people p
3397
+ LEFT JOIN person_embeddings e ON p.id = e.person_id
3398
+ WHERE e.person_id IS NULL AND p.id > ?
3399
+ ORDER BY p.id
3400
+ LIMIT ?
3401
+ """, (last_id, batch_size))
3402
+
3403
+ rows = cursor.fetchall()
3404
+ if not rows:
3405
+ break
3406
+
3407
+ results = [(row["id"], row["name"]) for row in rows]
3408
+ yield results
3409
+ last_id = results[-1][0]
3410
+
3411
+ def insert_both_embeddings_batch(
3412
+ self,
3413
+ person_ids: list[int],
3414
+ fp32_embeddings: np.ndarray,
3415
+ int8_embeddings: np.ndarray,
3416
+ ) -> int:
3417
+ """
3418
+ Insert both float32 and int8 embeddings for existing people.
3419
+
3420
+ Args:
3421
+ person_ids: List of person IDs
3422
+ fp32_embeddings: Matrix of float32 embeddings (N x dim)
3423
+ int8_embeddings: Matrix of int8 embeddings (N x dim)
3424
+
3425
+ Returns:
3426
+ Number of embeddings inserted
3427
+ """
3428
+ conn = self._connect()
3429
+ count = 0
3430
+
3431
+ for person_id, fp32, int8 in zip(person_ids, fp32_embeddings, int8_embeddings):
3432
+ # Insert float32
3433
+ fp32_blob = fp32.astype(np.float32).tobytes()
3434
+ conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
3435
+ conn.execute("""
3436
+ INSERT INTO person_embeddings (person_id, embedding)
3437
+ VALUES (?, ?)
3438
+ """, (person_id, fp32_blob))
3439
+
3440
+ # Insert int8
3441
+ int8_blob = int8.astype(np.int8).tobytes()
3442
+ conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
3443
+ conn.execute("""
3444
+ INSERT INTO person_embeddings_scalar (person_id, embedding)
3445
+ VALUES (?, vec_int8(?))
3446
+ """, (person_id, int8_blob))
3447
+
3448
+ count += 1
3449
+
3450
+ conn.commit()
3451
+ return count
3452
+
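For records with no embedding at all, get_missing_all_embedding_ids pairs with this method. A sketch of that pipeline, where embed_names stands in for whatever sentence-embedding model the caller supplies (hypothetical, not part of this package):

    import numpy as np

    def generate_missing_embeddings(db, embed_names) -> int:
        # embed_names: a callable taking list[str] and returning a
        # float32 matrix of shape (N, dim), supplied by the caller.
        total = 0
        for batch in db.get_missing_all_embedding_ids(batch_size=1000):
            ids = [pid for pid, _name in batch]
            fp32 = embed_names([name for _pid, name in batch])
            int8 = (np.clip(fp32, -1.0, 1.0) * 127.0).round().astype(np.int8)
            total += db.insert_both_embeddings_batch(ids, fp32, int8)
        return total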
3453
+ def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
3454
+ """
3455
+ Get all source_ids from the people table.
3456
+
3457
+ Useful for resume operations to skip already-imported records.
3458
+
3459
+ Args:
3460
+ source: Optional source filter (e.g., "wikidata")
3461
+
3462
+ Returns:
3463
+ Set of source_id strings (e.g., Q codes for Wikidata)
3464
+ """
3465
+ conn = self._connect()
3466
+
3467
+ if self._is_v2:
3468
+ id_col = "source_identifier"
3469
+ if source:
3470
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)  # default wikidata
3471
+ cursor = conn.execute(
3472
+ f"SELECT DISTINCT {id_col} FROM people WHERE source_id = ?",
3473
+ (source_type_id,)
3474
+ )
3475
+ else:
3476
+ cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM people")
2639
3477
  else:
2640
- cursor = conn.execute("""
2641
- SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
2642
- FROM people
2643
- """)
3478
+ if source:
3479
+ cursor = conn.execute(
3480
+ "SELECT DISTINCT source_id FROM people WHERE source = ?",
3481
+ (source,)
3482
+ )
3483
+ else:
3484
+ cursor = conn.execute("SELECT DISTINCT source_id FROM people")
2644
3485
 
2645
- for row in cursor:
2646
- yield PersonRecord(
2647
- name=row["name"],
2648
- source=row["source"],
2649
- source_id=row["source_id"],
2650
- country=row["country"] or "",
2651
- person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
2652
- known_for_role=row["known_for_role"] or "",
2653
- known_for_org=row["known_for_org"] or "",
2654
- known_for_org_id=row["known_for_org_id"], # Can be None
2655
- birth_date=row["birth_date"] or "",
2656
- death_date=row["death_date"] or "",
2657
- record=json.loads(row["record"]),
2658
- )
3486
+ return {row[0] for row in cursor}
3487
+
3488
+ def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
3489
+ """Iterate over all person records, optionally filtered by source."""
3490
+ conn = self._connect()
3491
+
3492
+ if self._is_v2:
3493
+ if source:
3494
+ source_type_id = SOURCE_NAME_TO_ID.get(source, 4)  # default wikidata
3495
+ cursor = conn.execute("""
3496
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3497
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3498
+ v.birth_date, v.death_date, p.record
3499
+ FROM people_view v
3500
+ JOIN people p ON v.id = p.id
3501
+ WHERE p.source_id = ?
3502
+ """, (source_type_id,))
3503
+ else:
3504
+ cursor = conn.execute("""
3505
+ SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
3506
+ v.known_for_role, v.known_for_org, v.known_for_org_id,
3507
+ v.birth_date, v.death_date, p.record
3508
+ FROM people_view v
3509
+ JOIN people p ON v.id = p.id
3510
+ """)
3511
+ for row in cursor:
3512
+ yield PersonRecord(
3513
+ name=row["name"],
3514
+ source=row["source"],
3515
+ source_id=row["source_identifier"],
3516
+ country=row["country"] or "",
3517
+ person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
3518
+ known_for_role=row["known_for_role"] or "",
3519
+ known_for_org=row["known_for_org"] or "",
3520
+ known_for_org_id=row["known_for_org_id"], # Can be None
3521
+ birth_date=row["birth_date"] or "",
3522
+ death_date=row["death_date"] or "",
3523
+ record=json.loads(row["record"]),
3524
+ )
3525
+ else:
3526
+ if source:
3527
+ cursor = conn.execute("""
3528
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3529
+ FROM people
3530
+ WHERE source = ?
3531
+ """, (source,))
3532
+ else:
3533
+ cursor = conn.execute("""
3534
+ SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
3535
+ FROM people
3536
+ """)
3537
+
3538
+ for row in cursor:
3539
+ yield PersonRecord(
3540
+ name=row["name"],
3541
+ source=row["source"],
3542
+ source_id=row["source_id"],
3543
+ country=row["country"] or "",
3544
+ person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
3545
+ known_for_role=row["known_for_role"] or "",
3546
+ known_for_org=row["known_for_org"] or "",
3547
+ known_for_org_id=row["known_for_org_id"], # Can be None
3548
+ birth_date=row["birth_date"] or "",
3549
+ death_date=row["death_date"] or "",
3550
+ record=json.loads(row["record"]),
3551
+ )
2659
3552
 
2660
3553
  def resolve_qid_labels(
2661
3554
  self,
@@ -2680,6 +3573,11 @@ class PersonDatabase:
2680
3573
  """
2681
3574
  conn = self._connect()
2682
3575
 
3576
+ # v2 schema stores QIDs as integers, not text - this method doesn't apply
3577
+ if self._is_v2:
3578
+ logger.debug("Skipping resolve_qid_labels for v2 schema (QIDs stored as integers)")
3579
+ return 0, 0
3580
+
2683
3581
  # Find all records with QIDs in any field (role or org - these are in unique constraint)
2684
3582
  # Country is not part of unique constraint so can be updated directly
2685
3583
  cursor = conn.execute("""
@@ -2752,6 +3650,11 @@ class PersonDatabase:
2752
3650
  Set of QIDs (starting with 'Q') found in country, role, or org fields
2753
3651
  """
2754
3652
  conn = self._connect()
3653
+
3654
+ # v2 schema stores QIDs as integers, not text - this method doesn't apply
3655
+ if self._is_v2:
3656
+ return set()
3657
+
2755
3658
  qids: set[str] = set()
2756
3659
 
2757
3660
  # Get QIDs from country field
@@ -2797,11 +3700,27 @@ class PersonDatabase:
2797
3700
  """
2798
3701
  conn = self._connect()
2799
3702
  count = 0
3703
+ skipped = 0
2800
3704
 
2801
3705
  for qid, label in label_map.items():
3706
+ # Skip non-Q IDs (e.g., property IDs like P19)
3707
+ if not qid.startswith("Q"):
3708
+ skipped += 1
3709
+ continue
3710
+
3711
+ # v2 schema stores QID as integer without Q prefix
3712
+ if self._is_v2:
3713
+ try:
3714
+ qid_val: str | int = int(qid[1:])
3715
+ except ValueError:
3716
+ skipped += 1
3717
+ continue
3718
+ else:
3719
+ qid_val = qid
3720
+
2802
3721
  conn.execute(
2803
3722
  "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
2804
- (qid, label)
3723
+ (qid_val, label)
2805
3724
  )
2806
3725
  count += 1
2807
3726
 
@@ -2824,9 +3743,16 @@ class PersonDatabase:
2824
3743
  Label string or None if not found
2825
3744
  """
2826
3745
  conn = self._connect()
3746
+
3747
+ # v2 schema stores QID as integer without Q prefix
3748
+ if self._is_v2:
3749
+ qid_val: str | int = int(qid[1:]) if qid.startswith("Q") else int(qid)
3750
+ else:
3751
+ qid_val = qid
3752
+
2827
3753
  cursor = conn.execute(
2828
3754
  "SELECT label FROM qid_labels WHERE qid = ?",
2829
- (qid,)
3755
+ (qid_val,)
2830
3756
  )
2831
3757
  row = cursor.fetchone()
2832
3758
  return row["label"] if row else None
@@ -2879,11 +3805,19 @@ class PersonDatabase:
2879
3805
  logger.info("Phase 1: Building person index...")
2880
3806
 
2881
3807
  # Load all people with their normalized names and org info
2882
- cursor = conn.execute("""
2883
- SELECT id, name, name_normalized, source, source_id,
2884
- known_for_org, known_for_org_id, from_date, to_date
2885
- FROM people
2886
- """)
3808
+ if self._is_v2:
3809
+ cursor = conn.execute("""
3810
+ SELECT p.id, p.name, p.name_normalized, s.name as source, p.source_identifier as source_id,
3811
+ p.known_for_org, p.known_for_org_id, p.from_date, p.to_date
3812
+ FROM people p
3813
+ JOIN source_types s ON p.source_id = s.id
3814
+ """)
3815
+ else:
3816
+ cursor = conn.execute("""
3817
+ SELECT id, name, name_normalized, source, source_id,
3818
+ known_for_org, known_for_org_id, from_date, to_date
3819
+ FROM people
3820
+ """)
2887
3821
 
2888
3822
  people: list[dict] = []
2889
3823
  for row in cursor:
@@ -3032,3 +3966,860 @@ class PersonDatabase:
3032
3966
  "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
3033
3967
  (canon_id, canon_size, person_id)
3034
3968
  )
3969
+
3970
+
3971
+ # =============================================================================
3972
+ # Module-level singletons for new v2 databases
3973
+ # =============================================================================
3974
+
3975
+ _roles_database_instances: dict[str, "RolesDatabase"] = {}
3976
+ _locations_database_instances: dict[str, "LocationsDatabase"] = {}
3977
+
3978
+
3979
+ def get_roles_database(db_path: Optional[str | Path] = None) -> "RolesDatabase":
3980
+ """
3981
+ Get a singleton RolesDatabase instance for the given path.
3982
+
3983
+ Args:
3984
+ db_path: Path to database file
3985
+
3986
+ Returns:
3987
+ Shared RolesDatabase instance
3988
+ """
3989
+ path_key = str(db_path or DEFAULT_DB_PATH)
3990
+ if path_key not in _roles_database_instances:
3991
+ logger.debug(f"Creating new RolesDatabase instance for {path_key}")
3992
+ _roles_database_instances[path_key] = RolesDatabase(db_path=db_path)
3993
+ return _roles_database_instances[path_key]
3994
+
3995
+
3996
+ def get_locations_database(db_path: Optional[str | Path] = None) -> "LocationsDatabase":
3997
+ """
3998
+ Get a singleton LocationsDatabase instance for the given path.
3999
+
4000
+ Args:
4001
+ db_path: Path to database file
4002
+
4003
+ Returns:
4004
+ Shared LocationsDatabase instance
4005
+ """
4006
+ path_key = str(db_path or DEFAULT_DB_PATH)
4007
+ if path_key not in _locations_database_instances:
4008
+ logger.debug(f"Creating new LocationsDatabase instance for {path_key}")
4009
+ _locations_database_instances[path_key] = LocationsDatabase(db_path=db_path)
4010
+ return _locations_database_instances[path_key]
4011
+
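Both accessors follow the same pattern as the existing shared-connection helpers: one lazily created instance per database path. A usage sketch:

    roles = get_roles_database()          # default path under ~/.cache
    locations = get_locations_database()  # shares the same underlying connection

    # Repeated calls with the same path return the same instance.
    assert get_roles_database() is roles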
4012
+
4013
+ # =============================================================================
4014
+ # ROLES DATABASE (v2)
4015
+ # =============================================================================
4016
+
4017
+
4018
+ class RolesDatabase:
4019
+ """
4020
+ SQLite database for job titles/roles.
4021
+
4022
+ Stores normalized role records with source tracking and supports
4023
+ canonicalization to group equivalent roles (e.g., CEO, Chief Executive).
4024
+ """
4025
+
4026
+ def __init__(self, db_path: Optional[str | Path] = None):
4027
+ """
4028
+ Initialize the roles database.
4029
+
4030
+ Args:
4031
+ db_path: Path to database file (creates if not exists)
4032
+ """
4033
+ self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
4034
+ self._conn: Optional[sqlite3.Connection] = None
4035
+ self._role_cache: dict[str, int] = {} # name_normalized -> role_id
4036
+
4037
+ def _connect(self) -> sqlite3.Connection:
4038
+ """Get or create database connection using shared connection pool."""
4039
+ if self._conn is not None:
4040
+ return self._conn
4041
+
4042
+ self._conn = _get_shared_connection(self._db_path)
4043
+ self._create_tables()
4044
+ return self._conn
4045
+
4046
+ def _create_tables(self) -> None:
4047
+ """Create roles table and indexes."""
4048
+ conn = self._conn
4049
+ assert conn is not None
4050
+
4051
+ # Check if enum tables exist, create and seed if not
4052
+ cursor = conn.execute(
4053
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
4054
+ )
4055
+ if not cursor.fetchone():
4056
+ logger.info("Creating enum tables for v2 schema...")
4057
+ from .schema_v2 import (
4058
+ CREATE_SOURCE_TYPES,
4059
+ CREATE_PEOPLE_TYPES,
4060
+ CREATE_ORGANIZATION_TYPES,
4061
+ CREATE_SIMPLIFIED_LOCATION_TYPES,
4062
+ CREATE_LOCATION_TYPES,
4063
+ )
4064
+ conn.execute(CREATE_SOURCE_TYPES)
4065
+ conn.execute(CREATE_PEOPLE_TYPES)
4066
+ conn.execute(CREATE_ORGANIZATION_TYPES)
4067
+ conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
4068
+ conn.execute(CREATE_LOCATION_TYPES)
4069
+ seed_all_enums(conn)
4070
+
4071
+ # Create roles table
4072
+ conn.execute("""
4073
+ CREATE TABLE IF NOT EXISTS roles (
4074
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
4075
+ qid INTEGER,
4076
+ name TEXT NOT NULL,
4077
+ name_normalized TEXT NOT NULL,
4078
+ source_id INTEGER NOT NULL DEFAULT 4,
4079
+ source_identifier TEXT,
4080
+ record TEXT NOT NULL DEFAULT '{}',
4081
+ canon_id INTEGER DEFAULT NULL,
4082
+ canon_size INTEGER DEFAULT 1,
4083
+ UNIQUE(name_normalized, source_id)
4084
+ )
4085
+ """)
4086
+
4087
+ # Create indexes
4088
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name ON roles(name)")
4089
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name_normalized ON roles(name_normalized)")
4090
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_qid ON roles(qid)")
4091
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_source_id ON roles(source_id)")
4092
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_canon_id ON roles(canon_id)")
4093
+
4094
+ conn.commit()
4095
+
4096
+ def close(self) -> None:
4097
+ """Clear connection reference."""
4098
+ self._conn = None
4099
+
4100
+ def get_or_create(
4101
+ self,
4102
+ name: str,
4103
+ source_id: int = 4, # wikidata
4104
+ qid: Optional[int] = None,
4105
+ source_identifier: Optional[str] = None,
4106
+ ) -> int:
4107
+ """
4108
+ Get or create a role record.
4109
+
4110
+ Args:
4111
+ name: Role/title name
4112
+ source_id: FK to source_types table
4113
+ qid: Optional Wikidata QID as integer
4114
+ source_identifier: Optional source-specific identifier
4115
+
4116
+ Returns:
4117
+ Role ID
4118
+ """
4119
+ if not name:
4120
+ raise ValueError("Role name cannot be empty")
4121
+
4122
+ conn = self._connect()
4123
+ name_normalized = name.lower().strip()
4124
+
4125
+ # Check cache
4126
+ cache_key = f"{name_normalized}:{source_id}"
4127
+ if cache_key in self._role_cache:
4128
+ return self._role_cache[cache_key]
4129
+
4130
+ # Check database
4131
+ cursor = conn.execute(
4132
+ "SELECT id FROM roles WHERE name_normalized = ? AND source_id = ?",
4133
+ (name_normalized, source_id)
4134
+ )
4135
+ row = cursor.fetchone()
4136
+ if row:
4137
+ role_id = row["id"]
4138
+ self._role_cache[cache_key] = role_id
4139
+ return role_id
4140
+
4141
+ # Create new role
4142
+ cursor = conn.execute(
4143
+ """
4144
+ INSERT INTO roles (name, name_normalized, source_id, qid, source_identifier)
4145
+ VALUES (?, ?, ?, ?, ?)
4146
+ """,
4147
+ (name, name_normalized, source_id, qid, source_identifier)
4148
+ )
4149
+ role_id = cursor.lastrowid
4150
+ assert role_id is not None
4151
+ conn.commit()
4152
+
4153
+ self._role_cache[cache_key] = role_id
4154
+ return role_id
4155
+
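Because lookups key on (name_normalized, source_id) and are memoized in _role_cache, case and whitespace variants of one title resolve to a single row. A sketch (the QID is illustrative):

    db = get_roles_database()
    ceo_id = db.get_or_create("Chief Executive Officer", qid=484876)
    # Normalization makes this a cache hit rather than a second row:
    assert db.get_or_create("  chief executive officer ") == ceo_id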
4156
+ def get_by_id(self, role_id: int) -> Optional[RoleRecord]:
4157
+ """Get a role record by ID."""
4158
+ conn = self._connect()
4159
+
4160
+ cursor = conn.execute(
4161
+ "SELECT id, qid, name, source_id, source_identifier, record FROM roles WHERE id = ?",
4162
+ (role_id,)
4163
+ )
4164
+ row = cursor.fetchone()
4165
+ if row:
4166
+ source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
4167
+ return RoleRecord(
4168
+ name=row["name"],
4169
+ source=source_name,
4170
+ source_id=row["source_identifier"],
4171
+ qid=row["qid"],
4172
+ record=json.loads(row["record"]) if row["record"] else {},
4173
+ )
4174
+ return None
4175
+
4176
+ def search(
4177
+ self,
4178
+ query: str,
4179
+ top_k: int = 10,
4180
+ ) -> list[tuple[int, str, float]]:
4181
+ """
4182
+ Search for roles by name.
4183
+
4184
+ Args:
4185
+ query: Search query
4186
+ top_k: Maximum results to return
4187
+
4188
+ Returns:
4189
+ List of (role_id, role_name, score) tuples
4190
+ """
4191
+ conn = self._connect()
4192
+ query_normalized = query.lower().strip()
4193
+
4194
+ # Exact match first
4195
+ cursor = conn.execute(
4196
+ "SELECT id, name FROM roles WHERE name_normalized = ? LIMIT 1",
4197
+ (query_normalized,)
4198
+ )
4199
+ row = cursor.fetchone()
4200
+ if row:
4201
+ return [(row["id"], row["name"], 1.0)]
4202
+
4203
+ # LIKE match
4204
+ cursor = conn.execute(
4205
+ """
4206
+ SELECT id, name FROM roles
4207
+ WHERE name_normalized LIKE ?
4208
+ ORDER BY length(name)
4209
+ LIMIT ?
4210
+ """,
4211
+ (f"%{query_normalized}%", top_k)
4212
+ )
4213
+
4214
+ results = []
4215
+ for row in cursor:
4216
+ # Simple score based on match quality
4217
+ name_normalized = row["name"].lower()
4218
+ if query_normalized == name_normalized:
4219
+ score = 1.0
4220
+ elif name_normalized.startswith(query_normalized):
4221
+ score = 0.9
4222
+ else:
4223
+ score = 0.7
4224
+ results.append((row["id"], row["name"], score))
4225
+
4226
+ return results
4227
+
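Scoring is a fixed three-tier heuristic: 1.0 for an exact normalized match, 0.9 for a prefix match, and 0.7 for any other substring hit, with shorter names ranked first by the ORDER BY. A sketch (output depends on what has been inserted):

    db = get_roles_database()
    db.get_or_create("Chief Financial Officer")
    for role_id, name, score in db.search("chief", top_k=5):
        print(role_id, name, score)  # prefix matches score 0.9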
4228
+ def get_stats(self) -> dict[str, int]:
4229
+ """Get statistics about the roles table."""
4230
+ conn = self._connect()
4231
+
4232
+ cursor = conn.execute("SELECT COUNT(*) FROM roles")
4233
+ total = cursor.fetchone()[0]
4234
+
4235
+ cursor = conn.execute("SELECT COUNT(*) FROM roles WHERE canon_id IS NOT NULL")
4236
+ canonicalized = cursor.fetchone()[0]
4237
+
4238
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM roles WHERE canon_id IS NOT NULL")
4239
+ groups = cursor.fetchone()[0]
4240
+
4241
+ return {
4242
+ "total_roles": total,
4243
+ "canonicalized": canonicalized,
4244
+ "canonical_groups": groups,
4245
+ }
4246
+
4247
+
4248
+ # =============================================================================
4249
+ # LOCATIONS DATABASE (v2)
4250
+ # =============================================================================
4251
+
4252
+
4253
+ class LocationsDatabase:
4254
+ """
4255
+ SQLite database for geopolitical locations.
4256
+
4257
+ Stores countries, states, and cities with hierarchical relationships
4258
+ and type classification. Supports pycountry integration.
4259
+ """
4260
+
4261
+ def __init__(self, db_path: Optional[str | Path] = None):
4262
+ """
4263
+ Initialize the locations database.
4264
+
4265
+ Args:
4266
+ db_path: Path to database file (creates if not exists)
4267
+ """
4268
+ self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
4269
+ self._conn: Optional[sqlite3.Connection] = None
4270
+ self._location_cache: dict[str, int] = {} # lookup_key -> location_id
4271
+ self._location_type_cache: dict[str, int] = {} # type_name -> type_id
4272
+ self._location_type_qid_cache: dict[int, int] = {} # qid -> type_id
4273
+
4274
+ def _connect(self) -> sqlite3.Connection:
4275
+ """Get or create database connection using shared connection pool."""
4276
+ if self._conn is not None:
4277
+ return self._conn
4278
+
4279
+ self._conn = _get_shared_connection(self._db_path)
4280
+ self._create_tables()
4281
+ self._build_caches()
4282
+ return self._conn
4283
+
4284
+ def _create_tables(self) -> None:
4285
+ """Create locations table and indexes."""
4286
+ conn = self._conn
4287
+ assert conn is not None
4288
+
4289
+ # Check if enum tables exist, create and seed if not
4290
+ cursor = conn.execute(
4291
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
4292
+ )
4293
+ if not cursor.fetchone():
4294
+ logger.info("Creating enum tables for v2 schema...")
4295
+ from .schema_v2 import (
4296
+ CREATE_SOURCE_TYPES,
4297
+ CREATE_PEOPLE_TYPES,
4298
+ CREATE_ORGANIZATION_TYPES,
4299
+ CREATE_SIMPLIFIED_LOCATION_TYPES,
4300
+ CREATE_LOCATION_TYPES,
4301
+ )
4302
+ conn.execute(CREATE_SOURCE_TYPES)
4303
+ conn.execute(CREATE_PEOPLE_TYPES)
4304
+ conn.execute(CREATE_ORGANIZATION_TYPES)
4305
+ conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
4306
+ conn.execute(CREATE_LOCATION_TYPES)
4307
+ seed_all_enums(conn)
4308
+
4309
+ # Create locations table
4310
+ conn.execute("""
4311
+ CREATE TABLE IF NOT EXISTS locations (
4312
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
4313
+ qid INTEGER,
4314
+ name TEXT NOT NULL,
4315
+ name_normalized TEXT NOT NULL,
4316
+ source_id INTEGER NOT NULL DEFAULT 4,
4317
+ source_identifier TEXT,
4318
+ parent_ids TEXT,
4319
+ location_type_id INTEGER NOT NULL DEFAULT 2,
4320
+ record TEXT NOT NULL DEFAULT '{}',
4321
+ from_date TEXT DEFAULT NULL,
4322
+ to_date TEXT DEFAULT NULL,
4323
+ canon_id INTEGER DEFAULT NULL,
4324
+ canon_size INTEGER DEFAULT 1,
4325
+ UNIQUE(source_identifier, source_id)
4326
+ )
4327
+ """)
4328
+
4329
+ # Create indexes
4330
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name ON locations(name)")
4331
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name_normalized ON locations(name_normalized)")
4332
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_qid ON locations(qid)")
4333
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_source_id ON locations(source_id)")
4334
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_location_type_id ON locations(location_type_id)")
4335
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_canon_id ON locations(canon_id)")
4336
+
4337
+ conn.commit()
4338
+
4339
+ def _build_caches(self) -> None:
4340
+ """Build lookup caches from database and seed data."""
4341
+ # Load location type caches from seed data
4342
+ self._location_type_cache = dict(LOCATION_TYPE_NAME_TO_ID)
4343
+ self._location_type_qid_cache = dict(LOCATION_TYPE_QID_TO_ID)
4344
+
4345
+ # Load existing locations into cache
4346
+ conn = self._conn
4347
+ if conn:
4348
+ cursor = conn.execute(
4349
+ "SELECT id, name_normalized, source_identifier FROM locations"
4350
+ )
4351
+ for row in cursor:
4352
+ # Cache by normalized name
4353
+ self._location_cache[row["name_normalized"]] = row["id"]
4354
+ # Also cache by source_identifier
4355
+ if row["source_identifier"]:
4356
+ self._location_cache[row["source_identifier"].lower()] = row["id"]
4357
+
4358
+ def close(self) -> None:
4359
+ """Clear connection reference."""
4360
+ self._conn = None
4361
+
4362
+ def get_or_create(
4363
+ self,
4364
+ name: str,
4365
+ location_type_id: int,
4366
+ source_id: int = 4, # wikidata
4367
+ qid: Optional[int] = None,
4368
+ source_identifier: Optional[str] = None,
4369
+ parent_ids: Optional[list[int]] = None,
4370
+ ) -> int:
4371
+ """
4372
+ Get or create a location record.
4373
+
4374
+ Args:
4375
+ name: Location name
4376
+ location_type_id: FK to location_types table
4377
+ source_id: FK to source_types table
4378
+ qid: Optional Wikidata QID as integer
4379
+ source_identifier: Optional source-specific identifier (e.g., "US", "CA")
4380
+ parent_ids: Optional list of parent location IDs
4381
+
4382
+ Returns:
4383
+ Location ID
4384
+ """
4385
+ if not name:
4386
+ raise ValueError("Location name cannot be empty")
4387
+
4388
+ conn = self._connect()
4389
+ name_normalized = name.lower().strip()
4390
+
4391
+ # Check cache by source_identifier first (more specific)
4392
+ if source_identifier:
4393
+ cache_key = source_identifier.lower()
4394
+ if cache_key in self._location_cache:
4395
+ return self._location_cache[cache_key]
4396
+
4397
+ # Check cache by normalized name
4398
+ if name_normalized in self._location_cache:
4399
+ return self._location_cache[name_normalized]
4400
+
4401
+ # Check database
4402
+ if source_identifier:
4403
+ cursor = conn.execute(
4404
+ "SELECT id FROM locations WHERE source_identifier = ? AND source_id = ?",
4405
+ (source_identifier, source_id)
4406
+ )
4407
+ else:
4408
+ cursor = conn.execute(
4409
+ "SELECT id FROM locations WHERE name_normalized = ? AND source_id = ?",
4410
+ (name_normalized, source_id)
4411
+ )
4412
+
4413
+ row = cursor.fetchone()
4414
+ if row:
4415
+ location_id = row["id"]
4416
+ self._location_cache[name_normalized] = location_id
4417
+ if source_identifier:
4418
+ self._location_cache[source_identifier.lower()] = location_id
4419
+ return location_id
4420
+
4421
+ # Create new location
4422
+ parent_ids_json = json.dumps(parent_ids) if parent_ids else None
4423
+ cursor = conn.execute(
4424
+ """
4425
+ INSERT INTO locations
4426
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids)
4427
+ VALUES (?, ?, ?, ?, ?, ?, ?)
4428
+ """,
4429
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids_json)
4430
+ )
4431
+ location_id = cursor.lastrowid
4432
+ assert location_id is not None
4433
+ conn.commit()
4434
+
4435
+ self._location_cache[name_normalized] = location_id
4436
+ if source_identifier:
4437
+ self._location_cache[source_identifier.lower()] = location_id
4438
+ return location_id
4439
+
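parent_ids lets callers record containment (city -> state -> country) as a JSON list of location IDs. A sketch using the type-name lookup defined later in this class (this assumes the seed data includes a "city" type; unknown names fall back to "other"):

    db = get_locations_database()
    us_id = db.get_or_create(
        "United States", db.get_location_type_id("country"),
        source_identifier="US",
    )
    city_id = db.get_or_create(
        "Portland", db.get_location_type_id("city"),
        parent_ids=[us_id],
    )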
4440
+ def get_or_create_by_qid(
4441
+ self,
4442
+ name: str,
4443
+ wikidata_type_qid: int,
4444
+ source_id: int = 4,
4445
+ entity_qid: Optional[int] = None,
4446
+ source_identifier: Optional[str] = None,
4447
+ parent_ids: Optional[list[int]] = None,
4448
+ ) -> int:
4449
+ """
4450
+ Get or create a location using Wikidata P31 type QID.
4451
+
4452
+ Args:
4453
+ name: Location name
4454
+ wikidata_type_qid: Wikidata instance-of QID (e.g., 515 for city)
4455
+ source_id: FK to source_types table
4456
+ entity_qid: Wikidata QID of the entity itself
4457
+ source_identifier: Optional source-specific identifier
4458
+ parent_ids: Optional list of parent location IDs
4459
+
4460
+ Returns:
4461
+ Location ID
4462
+ """
4463
+ location_type_id = self.get_location_type_id_from_qid(wikidata_type_qid)
4464
+ return self.get_or_create(
4465
+ name=name,
4466
+ location_type_id=location_type_id,
4467
+ source_id=source_id,
4468
+ qid=entity_qid,
4469
+ source_identifier=source_identifier,
4470
+ parent_ids=parent_ids,
4471
+ )
4472
+
4473
+ def get_by_id(self, location_id: int) -> Optional[LocationRecord]:
4474
+ """Get a location record by ID."""
4475
+ conn = self._connect()
4476
+
4477
+ cursor = conn.execute(
4478
+ """
4479
+ SELECT id, qid, name, source_id, source_identifier, location_type_id,
4480
+ parent_ids, from_date, to_date, record
4481
+ FROM locations WHERE id = ?
4482
+ """,
4483
+ (location_id,)
4484
+ )
4485
+ row = cursor.fetchone()
4486
+ if row:
4487
+ source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
4488
+ location_type_id = row["location_type_id"]
4489
+ location_type_name = self._get_location_type_name(location_type_id)
4490
+ simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
4491
+ simplified_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
4492
+
4493
+ parent_ids = json.loads(row["parent_ids"]) if row["parent_ids"] else []
4494
+
4495
+ return LocationRecord(
4496
+ name=row["name"],
4497
+ source=source_name,
4498
+ source_id=row["source_identifier"],
4499
+ qid=row["qid"],
4500
+ location_type=location_type_name,
4501
+ simplified_type=SimplifiedLocationType(simplified_name),
4502
+ parent_ids=parent_ids,
4503
+ from_date=row["from_date"],
4504
+ to_date=row["to_date"],
4505
+ record=json.loads(row["record"]) if row["record"] else {},
4506
+ )
4507
+ return None
4508
+
4509
+ def _get_location_type_name(self, type_id: int) -> str:
4510
+ """Get location type name from ID."""
4511
+ # Reverse lookup in cache
4512
+ for name, id_ in self._location_type_cache.items():
4513
+ if id_ == type_id:
4514
+ return name
4515
+ return "other"
4516
+
4517
+ def get_location_type_id(self, type_name: str) -> int:
4518
+ """Get location_type_id for a type name."""
4519
+ return self._location_type_cache.get(type_name, 36) # default to "other"
4520
+
4521
+ def get_location_type_id_from_qid(self, wikidata_qid: int) -> int:
4522
+ """Get location_type_id from Wikidata P31 QID."""
4523
+ return self._location_type_qid_cache.get(wikidata_qid, 36) # default to "other"
4524
+
4525
+ def get_simplified_type(self, location_type_id: int) -> str:
4526
+ """Get simplified type name for a location_type_id."""
4527
+ simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
4528
+ return SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
4529
+
4530
+ def resolve_region_text(self, text: str) -> Optional[int]:
4531
+ """
4532
+ Resolve a region/country text to a location ID.
4533
+
4534
+ Uses pycountry for country resolution, then falls back to search.
4535
+
4536
+ Args:
4537
+ text: Region text (country code, name, or QID)
4538
+
4539
+ Returns:
4540
+ Location ID or None if not resolved
4541
+ """
4542
+ if not text:
4543
+ return None
4544
+
4545
+ text_lower = text.lower().strip()
4546
+
4547
+ # Check cache first
4548
+ if text_lower in self._location_cache:
4549
+ return self._location_cache[text_lower]
4550
+
4551
+ # Try pycountry resolution
4552
+ alpha_2 = self._resolve_via_pycountry(text)
4553
+ if alpha_2:
4554
+ alpha_2_lower = alpha_2.lower()
4555
+ if alpha_2_lower in self._location_cache:
4556
+ location_id = self._location_cache[alpha_2_lower]
4557
+ self._location_cache[text_lower] = location_id # Cache the input too
4558
+ return location_id
4559
+
4560
+ # Country not in database yet, import it
4561
+ try:
4562
+ country = pycountry.countries.get(alpha_2=alpha_2)
4563
+ if country:
4564
+ country_type_id = self._location_type_cache.get("country", 2)
4565
+ location_id = self.get_or_create(
4566
+ name=country.name,
4567
+ location_type_id=country_type_id,
4568
+ source_id=4, # wikidata
4569
+ source_identifier=alpha_2,
4570
+ )
4571
+ self._location_cache[text_lower] = location_id
4572
+ return location_id
4573
+ except Exception:
4574
+ pass
4575
+
4576
+ return None
4577
+
4578
+ def _resolve_via_pycountry(self, region: str) -> Optional[str]:
4579
+ """Try to resolve region via pycountry."""
4580
+ region_clean = region.strip()
4581
+ if not region_clean:
4582
+ return None
4583
+
4584
+ # Try as 2-letter code
4585
+ if len(region_clean) == 2:
4586
+ country = pycountry.countries.get(alpha_2=region_clean.upper())
4587
+ if country:
4588
+ return country.alpha_2
4589
+
4590
+ # Try as 3-letter code
4591
+ if len(region_clean) == 3:
4592
+ country = pycountry.countries.get(alpha_3=region_clean.upper())
4593
+ if country:
4594
+ return country.alpha_2
4595
+
4596
+ # Try fuzzy search
4597
+ try:
4598
+ matches = pycountry.countries.search_fuzzy(region_clean)
4599
+ if matches:
4600
+ return matches[0].alpha_2
4601
+ except LookupError:
4602
+ pass
4603
+
4604
+ return None
4605
+
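resolve_region_text therefore accepts alpha-2 codes, alpha-3 codes, and free-text country names, importing the country row on first sight and caching every spelling it sees. A sketch:

    db = get_locations_database()
    # All three resolve through pycountry to alpha-2 "US", hence one row.
    a = db.resolve_region_text("US")
    b = db.resolve_region_text("USA")
    c = db.resolve_region_text("United States")
    assert a == b == c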
4606
+ def import_from_pycountry(self) -> int:
4607
+ """
4608
+ Import all countries from pycountry.
4609
+
4610
+ Returns:
4611
+ Number of locations imported
4612
+ """
4613
+ conn = self._connect()
4614
+ country_type_id = self._location_type_cache.get("country", 2)
4615
+ count = 0
4616
+
4617
+ for country in pycountry.countries:
4618
+ name = country.name
4619
+ alpha_2 = country.alpha_2
4620
+ name_normalized = name.lower()
4621
+
4622
+ # Check if already exists
4623
+ if alpha_2.lower() in self._location_cache:
4624
+ continue
4625
+
4626
+ cursor = conn.execute(
4627
+ """
4628
+ INSERT OR IGNORE INTO locations
4629
+ (name, name_normalized, source_id, source_identifier, location_type_id)
4630
+ VALUES (?, ?, 4, ?, ?)
4631
+ """,
4632
+ (name, name_normalized, alpha_2, country_type_id)
4633
+ )
4634
+
4635
+ if cursor.rowcount > 0:  # INSERT OR IGNORE leaves a stale lastrowid when it skips a duplicate
4636
+ self._location_cache[name_normalized] = cursor.lastrowid
4637
+ self._location_cache[alpha_2.lower()] = cursor.lastrowid
4638
+ count += 1
4639
+
4640
+ conn.commit()
4641
+ logger.info(f"Imported {count} countries from pycountry")
4642
+ return count
4643
+
4644
+ def search(
4645
+ self,
4646
+ query: str,
4647
+ top_k: int = 10,
4648
+ simplified_type: Optional[str] = None,
4649
+ ) -> list[tuple[int, str, float]]:
4650
+ """
4651
+ Search for locations by name.
4652
+
4653
+ Args:
4654
+ query: Search query
4655
+ top_k: Maximum results to return
4656
+ simplified_type: Optional filter by simplified type (e.g., "country", "city")
4657
+
4658
+ Returns:
4659
+ List of (location_id, location_name, score) tuples
4660
+ """
4661
+ conn = self._connect()
4662
+ query_normalized = query.lower().strip()
4663
+
4664
+ # Build query with optional type filter
4665
+ if simplified_type:
4666
+ # Get all location_type_ids for this simplified type
4667
+ simplified_id = {
4668
+ name: id_ for id_, name in SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.items()
4669
+ }.get(simplified_type)
4670
+ if simplified_id:
4671
+ type_ids = [
4672
+ type_id for type_id, simp_id in LOCATION_TYPE_TO_SIMPLIFIED.items()
4673
+ if simp_id == simplified_id
4674
+ ]
4675
+ if type_ids:
4676
+ placeholders = ",".join("?" * len(type_ids))
4677
+ cursor = conn.execute(
4678
+ f"""
4679
+ SELECT id, name FROM locations
4680
+ WHERE name_normalized LIKE ? AND location_type_id IN ({placeholders})
4681
+ ORDER BY length(name)
4682
+ LIMIT ?
4683
+ """,
4684
+ [f"%{query_normalized}%"] + type_ids + [top_k]
4685
+ )
4686
+ else:
4687
+ return []
4688
+ else:
4689
+ return []
4690
+ else:
4691
+ cursor = conn.execute(
4692
+ """
4693
+ SELECT id, name FROM locations
4694
+ WHERE name_normalized LIKE ?
4695
+ ORDER BY length(name)
4696
+ LIMIT ?
4697
+ """,
4698
+ (f"%{query_normalized}%", top_k)
4699
+ )
4700
+
4701
+ results = []
4702
+ for row in cursor:
4703
+ name_normalized = row["name"].lower()
4704
+ if query_normalized == name_normalized:
4705
+ score = 1.0
4706
+ elif name_normalized.startswith(query_normalized):
4707
+ score = 0.9
4708
+ else:
4709
+ score = 0.7
4710
+ results.append((row["id"], row["name"], score))
4711
+
4712
+ return results
4713
+
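With simplified_type set, the LIKE search is restricted to every location_type_id that maps into that simplified bucket; unknown bucket names return no results. A sketch, after seeding countries:

    db = get_locations_database()
    db.import_from_pycountry()
    # Only rows whose type maps to the "country" bucket are considered.
    for loc_id, name, score in db.search("united", top_k=5, simplified_type="country"):
        print(loc_id, name, score)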
4714
+ def get_stats(self) -> dict[str, Any]:
4715
+ """Get statistics about the locations table."""
4716
+ conn = self._connect()
4717
+
4718
+ cursor = conn.execute("SELECT COUNT(*) FROM locations")
4719
+ total = cursor.fetchone()[0]
4720
+
4721
+ cursor = conn.execute("SELECT COUNT(*) FROM locations WHERE canon_id IS NOT NULL")
4722
+ canonicalized = cursor.fetchone()[0]
4723
+
4724
+ cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM locations WHERE canon_id IS NOT NULL")
4725
+ groups = cursor.fetchone()[0]
4726
+
4727
+ # Count by simplified type
4728
+ by_type: dict[str, int] = {}
4729
+ cursor = conn.execute("""
4730
+ SELECT lt.simplified_id, COUNT(*) as cnt
4731
+ FROM locations l
4732
+ JOIN location_types lt ON l.location_type_id = lt.id
4733
+ GROUP BY lt.simplified_id
4734
+ """)
4735
+ for row in cursor:
4736
+ type_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(row["simplified_id"], "other")
4737
+ by_type[type_name] = row["cnt"]
4738
+
4739
+ return {
4740
+ "total_locations": total,
4741
+ "canonicalized": canonicalized,
4742
+ "canonical_groups": groups,
4743
+ "by_type": by_type,
4744
+ }
4745
+
4746
+ def insert_batch(self, records: list[LocationRecord]) -> int:
4747
+ """
4748
+ Insert a batch of location records.
4749
+
4750
+ Args:
4751
+ records: List of LocationRecord objects to insert
4752
+
4753
+ Returns:
4754
+ Number of records inserted
4755
+ """
4756
+ if not records:
4757
+ return 0
4758
+
4759
+ conn = self._connect()
4760
+ inserted = 0
4761
+
4762
+ for record in records:
4763
+ name_normalized = record.name.lower().strip()
4764
+ source_identifier = record.source_id # Q code in source_id field
4765
+
4766
+ # Check cache first
4767
+ cache_key = source_identifier.lower() if source_identifier else name_normalized
4768
+ if cache_key in self._location_cache:
4769
+ continue
4770
+
4771
+ # Get location_type_id from type name
4772
+ location_type_id = self._location_type_cache.get(record.location_type, 36) # default "other"
4773
+ source_id = SOURCE_NAME_TO_ID.get(record.source, 4) # default wikidata
4774
+
4775
+ parent_ids_json = json.dumps(record.parent_ids) if record.parent_ids else None
4776
+ record_json = json.dumps(record.record) if record.record else "{}"
4777
+
4778
+ try:
4779
+ cursor = conn.execute(
4780
+ """
4781
+ INSERT OR IGNORE INTO locations
4782
+ (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids, record, from_date, to_date)
4783
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
4784
+ """,
4785
+ (
4786
+ record.name,
4787
+ name_normalized,
4788
+ source_id,
4789
+ source_identifier,
4790
+ record.qid,
4791
+ location_type_id,
4792
+ parent_ids_json,
4793
+ record_json,
4794
+ record.from_date,
4795
+ record.to_date,
4796
+ )
4797
+ )
4798
+ if cursor.rowcount > 0:  # INSERT OR IGNORE leaves a stale lastrowid when it skips a duplicate
4799
+ self._location_cache[name_normalized] = cursor.lastrowid
4800
+ if source_identifier:
4801
+ self._location_cache[source_identifier.lower()] = cursor.lastrowid
4802
+ inserted += 1
4803
+ except Exception as e:
4804
+ logger.warning(f"Failed to insert location {record.name}: {e}")
4805
+
4806
+ conn.commit()
4807
+ return inserted
4808
+
4809
+ def get_all_source_ids(self, source: str = "wikidata") -> set[str]:
4810
+ """
4811
+ Get all source_identifiers for a given source.
4812
+
4813
+ Args:
4814
+ source: Source name (e.g., "wikidata")
4815
+
4816
+ Returns:
4817
+ Set of source_identifiers
4818
+ """
4819
+ conn = self._connect()
4820
+ source_id = SOURCE_NAME_TO_ID.get(source, 4)  # default wikidata
4821
+ cursor = conn.execute(
4822
+ "SELECT source_identifier FROM locations WHERE source_id = ? AND source_identifier IS NOT NULL",
4823
+ (source_id,)
4824
+ )
4825
+ return {row["source_identifier"] for row in cursor}