corp-extractor 0.9.3__py3-none-any.whl → 0.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.9.3.dist-info → corp_extractor-0.9.4.dist-info}/METADATA +33 -3
- {corp_extractor-0.9.3.dist-info → corp_extractor-0.9.4.dist-info}/RECORD +16 -12
- statement_extractor/cli.py +472 -45
- statement_extractor/database/embeddings.py +45 -0
- statement_extractor/database/hub.py +51 -9
- statement_extractor/database/importers/import_utils.py +264 -0
- statement_extractor/database/importers/wikidata_dump.py +334 -3
- statement_extractor/database/importers/wikidata_people.py +44 -0
- statement_extractor/database/migrate_v2.py +852 -0
- statement_extractor/database/models.py +125 -1
- statement_extractor/database/schema_v2.py +409 -0
- statement_extractor/database/seed_data.py +359 -0
- statement_extractor/database/store.py +2113 -322
- statement_extractor/plugins/qualifiers/person.py +109 -52
- {corp_extractor-0.9.3.dist-info → corp_extractor-0.9.4.dist-info}/WHEEL +0 -0
- {corp_extractor-0.9.3.dist-info → corp_extractor-0.9.4.dist-info}/entry_points.txt +0 -0
@@ -12,18 +12,41 @@ import re
 import sqlite3
 import time
 from pathlib import Path
-from typing import Iterator, Optional
+from typing import Any, Iterator, Optional

 import numpy as np
 import pycountry
 import sqlite_vec

-from .models import
+from .models import (
+    CompanyRecord,
+    DatabaseStats,
+    EntityType,
+    LocationRecord,
+    PersonRecord,
+    PersonType,
+    RoleRecord,
+    SimplifiedLocationType,
+)
+from .seed_data import (
+    LOCATION_TYPE_NAME_TO_ID,
+    LOCATION_TYPE_QID_TO_ID,
+    LOCATION_TYPE_TO_SIMPLIFIED,
+    ORG_TYPE_ID_TO_NAME,
+    ORG_TYPE_NAME_TO_ID,
+    PEOPLE_TYPE_ID_TO_NAME,
+    PEOPLE_TYPE_NAME_TO_ID,
+    SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME,
+    SOURCE_ID_TO_NAME,
+    SOURCE_NAME_TO_ID,
+    seed_all_enums,
+    seed_pycountry_locations,
+)

 logger = logging.getLogger(__name__)

 # Default database location
-DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
+DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities-v2.db"

 # Module-level shared connections by path (both databases share the same connection)
 _shared_connections: dict[str, sqlite3.Connection] = {}
@@ -500,6 +523,7 @@ class OrganizationDatabase:
         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
         self._embedding_dim = embedding_dim
         self._conn: Optional[sqlite3.Connection] = None
+        self._is_v2: Optional[bool] = None  # Detected on first connect

     def _ensure_dir(self) -> None:
         """Ensure database directory exists."""
@@ -512,11 +536,27 @@ class OrganizationDatabase:

         self._conn = _get_shared_connection(self._db_path, self._embedding_dim)

-        #
-
+        # Detect schema version BEFORE creating tables
+        # v2 has entity_type_id (FK) instead of entity_type (TEXT)
+        if self._is_v2 is None:
+            cursor = self._conn.execute("PRAGMA table_info(organizations)")
+            columns = {row["name"] for row in cursor}
+            self._is_v2 = "entity_type_id" in columns
+            if self._is_v2:
+                logger.debug("Detected v2 schema for organizations")
+
+        # Create tables (idempotent) - only for v1 schema or fresh databases
+        # v2 databases already have their schema from migration
+        if not self._is_v2:
+            self._create_tables()

         return self._conn

+    @property
+    def _org_table(self) -> str:
+        """Return table/view name for organization queries needing text fields."""
+        return "organizations_view" if self._is_v2 else "organizations"
+
     def _create_tables(self) -> None:
         """Create database tables including sqlite-vec virtual table."""
         conn = self._conn
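The schema detection added in this hunk relies on `PRAGMA table_info`, which returns one row per column of the named table; the presence of an `entity_type_id` column marks a migrated v2 database. A minimal standalone sketch of the same probe (hypothetical table, stdlib `sqlite3` only):

```python
import sqlite3

def has_column(conn: sqlite3.Connection, table: str, column: str) -> bool:
    # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column.
    return any(row[1] == column for row in conn.execute(f"PRAGMA table_info({table})"))

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE organizations (id INTEGER PRIMARY KEY, entity_type_id INTEGER)")
print(has_column(conn, "organizations", "entity_type_id"))  # True -> treat as v2 schema
```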
@@ -591,7 +631,7 @@ class OrganizationDatabase:
         conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_canon_id ON organizations(canon_id)")

-        # Create sqlite-vec virtual table for embeddings
+        # Create sqlite-vec virtual table for embeddings (float32)
         # vec0 is the recommended virtual table type
         conn.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
@@ -600,19 +640,34 @@ class OrganizationDatabase:
             )
         """)

+        # Create sqlite-vec virtual table for scalar embeddings (int8)
+        # Provides 75% storage reduction with ~92% recall at top-100
+        conn.execute(f"""
+            CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
+                org_id INTEGER PRIMARY KEY,
+                embedding int8[{self._embedding_dim}]
+            )
+        """)
+
         conn.commit()

     def close(self) -> None:
         """Clear connection reference (shared connection remains open)."""
         self._conn = None

-    def insert(
+    def insert(
+        self,
+        record: CompanyRecord,
+        embedding: np.ndarray,
+        scalar_embedding: Optional[np.ndarray] = None,
+    ) -> int:
         """
         Insert an organization record with its embedding.

         Args:
             record: Organization record to insert
-            embedding: Embedding vector for the organization name
+            embedding: Embedding vector for the organization name (float32)
+            scalar_embedding: Optional int8 scalar embedding for compact storage

         Returns:
             Row ID of inserted record
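The "75% storage reduction" claimed for the new int8 table is plain dtype arithmetic: each int8 component takes 1 byte against 4 bytes for float32, so a 768-dimension embedding (the default `embedding_dim` in this module) drops from 3072 bytes to 768 bytes per row. A quick numpy check:

```python
import numpy as np

dim = 768  # default embedding_dim used elsewhere in this module
fp32 = np.zeros(dim, dtype=np.float32)
int8 = np.zeros(dim, dtype=np.int8)
print(fp32.nbytes, int8.nbytes)       # 3072 768
print(1 - int8.nbytes / fp32.nbytes)  # 0.75, i.e. 75% smaller
```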
@@ -642,7 +697,7 @@ class OrganizationDatabase:
         row_id = cursor.lastrowid
         assert row_id is not None

-        # Insert embedding into vec table
+        # Insert embedding into vec table (float32)
         # sqlite-vec virtual tables don't support INSERT OR REPLACE, so delete first
         embedding_blob = embedding.astype(np.float32).tobytes()
         conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (row_id,))
@@ -651,6 +706,15 @@ class OrganizationDatabase:
             VALUES (?, ?)
         """, (row_id, embedding_blob))

+        # Insert scalar embedding if provided (int8)
+        if scalar_embedding is not None:
+            scalar_blob = scalar_embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+            conn.execute("""
+                INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (row_id, scalar_blob))
+
         conn.commit()
         return row_id

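Note the asymmetry in the two inserts above: float32 blobs are bound directly, while int8 blobs are wrapped in sqlite-vec's `vec_int8()` SQL function so the blob is typed to match the `int8[...]` column. A self-contained sketch of that round trip (toy 4-dimension table; assumes the `sqlite-vec` Python package, which this module already imports as `sqlite_vec`):

```python
import sqlite3

import numpy as np
import sqlite_vec

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)  # registers vec0, vec_int8(), vec_distance_cosine(), ...
db.enable_load_extension(False)

db.execute("CREATE VIRTUAL TABLE emb USING vec0(org_id INTEGER PRIMARY KEY, embedding int8[4])")
vec = np.array([12, -7, 127, 0], dtype=np.int8)
db.execute("INSERT INTO emb (org_id, embedding) VALUES (?, vec_int8(?))", (1, vec.tobytes()))
row = db.execute(
    "SELECT org_id, vec_distance_cosine(embedding, vec_int8(?)) FROM emb",
    (vec.tobytes(),),
).fetchone()
print(row)  # (1, 0.0): identical vector, zero cosine distance
```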
@@ -659,14 +723,16 @@ class OrganizationDatabase:
         records: list[CompanyRecord],
         embeddings: np.ndarray,
         batch_size: int = 1000,
+        scalar_embeddings: Optional[np.ndarray] = None,
     ) -> int:
         """
         Insert multiple organization records with embeddings.

         Args:
             records: List of organization records
-            embeddings: Matrix of embeddings (N x dim)
+            embeddings: Matrix of embeddings (N x dim) - float32
             batch_size: Commit batch size
+            scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)

         Returns:
             Number of records inserted
@@ -674,25 +740,54 @@ class OrganizationDatabase:
         conn = self._connect()
         count = 0

-        for record, embedding in zip(records, embeddings):
+        for i, (record, embedding) in enumerate(zip(records, embeddings)):
             record_json = json.dumps(record.record)
             name_normalized = _normalize_name(record.name)

-
-
-
-
-
-
-
-            record.
-
-
-
-
-
-
+            if self._is_v2:
+                # v2 schema: use FK IDs instead of TEXT columns
+                source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
+                entity_type_id = ORG_TYPE_NAME_TO_ID.get(record.entity_type.value, 17)  # 17 = unknown
+
+                # Resolve region to location_id if provided
+                region_id = None
+                if record.region:
+                    # Use locations database to resolve region
+                    locations_db = get_locations_database(db_path=self._db_path)
+                    region_id = locations_db.resolve_region_text(record.region)
+
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO organizations
+                    (name, name_normalized, source_id, source_identifier, region_id, entity_type_id, from_date, to_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    source_type_id,
+                    record.source_id,
+                    region_id,
+                    entity_type_id,
+                    record.from_date or "",
+                    record.to_date or "",
+                    record_json,
+                ))
+            else:
+                # v1 schema: use TEXT columns
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO organizations
+                    (name, name_normalized, source, source_id, region, entity_type, from_date, to_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    record.source,
+                    record.source_id,
+                    record.region,
+                    record.entity_type.value,
+                    record.from_date or "",
+                    record.to_date or "",
+                    record_json,
+                ))

             row_id = cursor.lastrowid
             assert row_id is not None
@@ -705,6 +800,15 @@ class OrganizationDatabase:
                 VALUES (?, ?)
             """, (row_id, embedding_blob))

+            # Insert scalar embedding if provided (int8)
+            if scalar_embeddings is not None:
+                scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
+                conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (row_id,))
+                conn.execute("""
+                    INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                    VALUES (?, vec_int8(?))
+                """, (row_id, scalar_blob))
+
             count += 1

             if count % batch_size == 0:
@@ -750,7 +854,13 @@ class OrganizationDatabase:
         if query_norm == 0:
             return []
         query_normalized = query_embedding / query_norm
-
+
+        # Use int8 quantized query if scalar table is available (75% storage savings)
+        if self._has_scalar_table():
+            query_int8 = self._quantize_query(query_normalized)
+            query_blob = query_int8.tobytes()
+        else:
+            query_blob = query_normalized.astype(np.float32).tobytes()

         # Stage 1: Text-based pre-filtering (if query_text provided)
         candidate_ids: Optional[set[int]] = None
@@ -980,6 +1090,19 @@ class OrganizationDatabase:
         cursor = conn.execute(query, params)
         return set(row["id"] for row in cursor)

+    def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
+        """Quantize query embedding to int8 for scalar search."""
+        return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
+
+    def _has_scalar_table(self) -> bool:
+        """Check if scalar embedding table exists."""
+        conn = self._conn
+        assert conn is not None
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='organization_embeddings_scalar'"
+        )
+        return cursor.fetchone() is not None
+
     def _vector_search_filtered(
         self,
         query_blob: bytes,
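`_quantize_query` is symmetric scalar quantization: the query is already unit-normalized, so its components lie in [-1, 1]; scaling by 127, rounding, and clamping maps them onto the int8 range, and cosine distance between two vectors quantized this way closely tracks the float32 value (which is why recall stays high, per the ~92% figure in the comments above). A numpy-only sketch of the round trip, not taken from the package's tests:

```python
import numpy as np

def quantize(v: np.ndarray) -> np.ndarray:
    # Same recipe as _quantize_query: scale unit-norm components into int8 range.
    return np.clip(np.round(v * 127), -127, 127).astype(np.int8)

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
a = rng.normal(size=768); a /= np.linalg.norm(a)
b = rng.normal(size=768); b /= np.linalg.norm(b)
print(cosine_distance(a, b))                      # float32 ground truth
print(cosine_distance(quantize(a), quantize(b)))  # int8 approximation, very close
```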
@@ -987,7 +1110,7 @@ class OrganizationDatabase:
         top_k: int,
         source_filter: Optional[str],
     ) -> list[tuple[CompanyRecord, float]]:
-        """Vector search within a filtered set of candidates."""
+        """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None

@@ -997,18 +1120,29 @@ class OrganizationDatabase:
         # Build IN clause for candidate IDs
         placeholders = ",".join("?" * len(candidate_ids))

-        #
-
-
-
-
-
-
-
-
-
-
-
+        # Use scalar embedding table if available (75% storage reduction)
+        if self._has_scalar_table():
+            # Query uses int8 embeddings with vec_int8() wrapper
+            query = f"""
+                SELECT
+                    e.org_id,
+                    vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                FROM organization_embeddings_scalar e
+                WHERE e.org_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """
+        else:
+            # Fall back to float32 embeddings
+            query = f"""
+                SELECT
+                    e.org_id,
+                    vec_distance_cosine(e.embedding, ?) as distance
+                FROM organization_embeddings e
+                WHERE e.org_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """

         cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])

@@ -1035,33 +1169,58 @@ class OrganizationDatabase:
         top_k: int,
         source_filter: Optional[str],
     ) -> list[tuple[CompanyRecord, float]]:
-        """Full vector search without text pre-filtering."""
+        """Full vector search without text pre-filtering using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None

+        # Use scalar embedding table if available (75% storage reduction)
+        use_scalar = self._has_scalar_table()
+
         # KNN search with sqlite-vec
         if source_filter:
             # Need to join with organizations table for source filter
-
-
-
-
-
-
-
-
-
-
+            if use_scalar:
+                query = """
+                    SELECT
+                        e.org_id,
+                        vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                    FROM organization_embeddings_scalar e
+                    JOIN organizations c ON e.org_id = c.id
+                    WHERE c.source = ?
+                    ORDER BY distance
+                    LIMIT ?
+                """
+            else:
+                query = """
+                    SELECT
+                        e.org_id,
+                        vec_distance_cosine(e.embedding, ?) as distance
+                    FROM organization_embeddings e
+                    JOIN organizations c ON e.org_id = c.id
+                    WHERE c.source = ?
+                    ORDER BY distance
+                    LIMIT ?
+                """
             cursor = conn.execute(query, (query_blob, source_filter, top_k))
         else:
-
-
-
-
-
-
-
-
+            if use_scalar:
+                query = """
+                    SELECT
+                        org_id,
+                        vec_distance_cosine(embedding, vec_int8(?)) as distance
+                    FROM organization_embeddings_scalar
+                    ORDER BY distance
+                    LIMIT ?
+                """
+            else:
+                query = """
+                    SELECT
+                        org_id,
+                        vec_distance_cosine(embedding, ?) as distance
+                    FROM organization_embeddings
+                    ORDER BY distance
+                    LIMIT ?
+                """
             cursor = conn.execute(query, (query_blob, top_k))

         results = []
@@ -1081,10 +1240,19 @@ class OrganizationDatabase:
         conn = self._conn
         assert conn is not None

-
-
-
-
+        if self._is_v2:
+            # v2 schema: use view for text fields, but need record from base table
+            cursor = conn.execute("""
+                SELECT v.id, v.name, v.source, v.source_identifier, v.region, v.entity_type, v.canon_id, o.record
+                FROM organizations_view v
+                JOIN organizations o ON v.id = o.id
+                WHERE v.id = ?
+            """, (org_id,))
+        else:
+            cursor = conn.execute("""
+                SELECT id, name, source, source_id, region, entity_type, record, canon_id
+                FROM organizations WHERE id = ?
+            """, (org_id,))

         row = cursor.fetchone()
         if row:
@@ -1092,10 +1260,11 @@ class OrganizationDatabase:
             # Add db_id and canon_id to record dict for canon-aware search
             record_data["db_id"] = row["id"]
             record_data["canon_id"] = row["canon_id"]
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return CompanyRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row[
+                source_id=row[source_id_field],
                 region=row["region"] or "",
                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                 record=record_data,
@@ -1106,18 +1275,29 @@ class OrganizationDatabase:
         """Get an organization record by source and source_id."""
         conn = self._connect()

-
-
-
-
-
+        if self._is_v2:
+            # v2 schema: join view with base table for record
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                FROM organizations_view v
+                JOIN organizations o ON v.id = o.id
+                WHERE o.source_id = ? AND o.source_identifier = ?
+            """, (source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, region, entity_type, record
+                FROM organizations
+                WHERE source = ? AND source_id = ?
+            """, (source, source_id))

         row = cursor.fetchone()
         if row:
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return CompanyRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row[
+                source_id=row[source_id_field],
                 region=row["region"] or "",
                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                 record=json.loads(row["record"]),
@@ -1128,10 +1308,17 @@ class OrganizationDatabase:
         """Get the internal database ID for an organization by source and source_id."""
         conn = self._connect()

-
-
-
-
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                SELECT id FROM organizations
+                WHERE source_id = ? AND source_identifier = ?
+            """, (source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                SELECT id FROM organizations
+                WHERE source = ? AND source_id = ?
+            """, (source, source_id))

         row = cursor.fetchone()
         if row:
@@ -1146,8 +1333,18 @@ class OrganizationDatabase:
         cursor = conn.execute("SELECT COUNT(*) FROM organizations")
         total = cursor.fetchone()[0]

-        # Count by source
-
+        # Count by source - handle both v1 and v2 schema
+        if self._is_v2:
+            # v2 schema - join with source_types
+            cursor = conn.execute("""
+                SELECT st.name as source, COUNT(*) as cnt
+                FROM organizations o
+                JOIN source_types st ON o.source_id = st.id
+                GROUP BY o.source_id
+            """)
+        else:
+            # v1 schema
+            cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
         by_source = {row["source"]: row["cnt"] for row in cursor}

         # Database file size
@@ -1167,20 +1364,31 @@ class OrganizationDatabase:
         Useful for resume operations to skip already-imported records.

         Args:
-            source: Optional source filter (e.g., "
+            source: Optional source filter (e.g., "wikidata" for Wikidata orgs)

         Returns:
             Set of source_id strings (e.g., Q codes for Wikidata)
         """
         conn = self._connect()

-        if
-
-
-            (source,)
-
+        if self._is_v2:
+            id_col = "source_identifier"
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute(
+                    f"SELECT DISTINCT {id_col} FROM organizations WHERE source_id = ?",
+                    (source_type_id,)
+                )
+            else:
+                cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM organizations")
         else:
-
+            if source:
+                cursor = conn.execute(
+                    "SELECT DISTINCT source_id FROM organizations WHERE source = ?",
+                    (source,)
+                )
+            else:
+                cursor = conn.execute("SELECT DISTINCT source_id FROM organizations")

         return {row[0] for row in cursor}

@@ -1188,27 +1396,51 @@ class OrganizationDatabase:
         """Iterate over all records, optionally filtered by source."""
         conn = self._connect()

-        if
-
-
-
-
-
+        if self._is_v2:
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                    FROM organizations_view v
+                    JOIN organizations o ON v.id = o.id
+                    WHERE o.source_id = ?
+                """, (source_type_id,))
+            else:
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.region, v.entity_type, o.record
+                    FROM organizations_view v
+                    JOIN organizations o ON v.id = o.id
+                """)
+            for row in cursor:
+                yield CompanyRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_identifier"],
+                    region=row["region"] or "",
+                    entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                    record=json.loads(row["record"]),
+                )
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if source:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, region, entity_type, record
+                    FROM organizations
+                    WHERE source = ?
+                """, (source,))
+            else:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, region, entity_type, record
+                    FROM organizations
+                """)
+            for row in cursor:
+                yield CompanyRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_id"],
+                    region=row["region"] or "",
+                    entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                    record=json.loads(row["record"]),
+                )

     def canonicalize(self, batch_size: int = 10000) -> dict[str, int]:
         """
@@ -1252,10 +1484,18 @@ class OrganizationDatabase:
         sources: dict[int, str] = {}  # org_id -> source
         all_org_ids: list[int] = []

-
-
-
-
+        if self._is_v2:
+            cursor = conn.execute("""
+                SELECT o.id, s.name as source, o.source_identifier as source_id, o.name, l.name as region, o.record
+                FROM organizations o
+                JOIN source_types s ON o.source_id = s.id
+                LEFT JOIN locations l ON o.region_id = l.id
+            """)
+        else:
+            cursor = conn.execute("""
+                SELECT id, source, source_id, name, region, record
+                FROM organizations
+            """)

         count = 0
         for row in cursor:
@@ -1586,17 +1826,32 @@ class OrganizationDatabase:
         """Delete all records from a specific source."""
         conn = self._connect()

-
-
-
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            # First get IDs to delete from vec table
+            cursor = conn.execute("SELECT id FROM organizations WHERE source_id = ?", (source_type_id,))
+            ids_to_delete = [row["id"] for row in cursor]
+
+            # Delete from vec table
+            if ids_to_delete:
+                placeholders = ",".join("?" * len(ids_to_delete))
+                conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+
+            # Delete from main table
+            cursor = conn.execute("DELETE FROM organizations WHERE source_id = ?", (source_type_id,))
+        else:
+            # First get IDs to delete from vec table
+            cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
+            ids_to_delete = [row["id"] for row in cursor]
+
+            # Delete from vec table
+            if ids_to_delete:
+                placeholders = ",".join("?" * len(ids_to_delete))
+                conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)

-
-
-        placeholders = ",".join("?" * len(ids_to_delete))
-        conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+            # Delete from main table
+            cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))

-        # Delete from main table
-        cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
         deleted = cursor.rowcount

         conn.commit()
@@ -1825,130 +2080,366 @@ class OrganizationDatabase:
         conn.commit()
         return count

-    def
-
-
-
-
+    def ensure_scalar_table_exists(self) -> None:
+        """Create scalar embedding table if it doesn't exist."""
+        conn = self._connect()
+        conn.execute(f"""
+            CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings_scalar USING vec0(
+                org_id INTEGER PRIMARY KEY,
+                embedding int8[{self._embedding_dim}]
+            )
+        """)
+        conn.commit()
+        logger.info("Ensured organization_embeddings_scalar table exists")
+
+    def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
         """
-
+        Yield batches of org IDs that have float32 but missing scalar embeddings.

         Args:
-
-            batch_size: Commit batch size
+            batch_size: Number of IDs per batch

-
-
+        Yields:
+            Lists of org_ids needing scalar embeddings
         """
         conn = self._connect()

-        #
-
-        cursor = conn.execute("""
-            SELECT id, region FROM organizations
-            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
-        """)
-        rows = cursor.fetchall()
+        # Ensure scalar table exists before querying
+        self.ensure_scalar_table_exists()

-
-
-
-
-
-
-
-
-
+        last_id = 0
+        while True:
+            cursor = conn.execute("""
+                SELECT e.org_id FROM organization_embeddings e
+                LEFT JOIN organization_embeddings_scalar s ON e.org_id = s.org_id
+                WHERE s.org_id IS NULL AND e.org_id > ?
+                ORDER BY e.org_id
+                LIMIT ?
+            """, (last_id, batch_size))

-
-
-
+            rows = cursor.fetchall()
+            if not rows:
+                break

-
-
-
+            ids = [row["org_id"] for row in rows]
+            yield ids
+            last_id = ids[-1]

-    def
+    def get_embeddings_by_ids(self, org_ids: list[int]) -> dict[int, np.ndarray]:
         """
-
+        Fetch float32 embeddings for given org IDs.
+
+        Args:
+            org_ids: List of organization IDs

         Returns:
-
+            Dict mapping org_id to float32 embedding array
         """
         conn = self._connect()
-        qids: set[str] = set()

-
-
-            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
-        """)
-        for row in cursor:
-            qids.add(row["region"])
+        if not org_ids:
+            return {}

-
+        placeholders = ",".join("?" * len(org_ids))
+        cursor = conn.execute(f"""
+            SELECT org_id, embedding FROM organization_embeddings
+            WHERE org_id IN ({placeholders})
+        """, org_ids)

+        result = {}
+        for row in cursor:
+            embedding_blob = row["embedding"]
+            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
+            result[row["org_id"]] = embedding
+        return result

-    def
-
-
+    def insert_scalar_embeddings_batch(self, org_ids: list[int], embeddings: np.ndarray) -> int:
+        """
+        Insert scalar (int8) embeddings for existing orgs.

-
-
-
+        Args:
+            org_ids: List of organization IDs
+            embeddings: Matrix of int8 embeddings (N x dim)

-
-
-
-
-
-        logger.debug(f"Creating new PersonDatabase instance for {path_key}")
-        _person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
-    return _person_database_instances[path_key]
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0

+        for org_id, embedding in zip(org_ids, embeddings):
+            scalar_blob = embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
+            conn.execute("""
+                INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (org_id, scalar_blob))
+            count += 1

-
-
-    SQLite database with sqlite-vec for person vector search.
+        conn.commit()
+        return count

-
-
-
+    def get_scalar_embedding_count(self) -> int:
+        """Get count of scalar embeddings."""
+        conn = self._connect()
+        if not self._has_scalar_table():
+            return 0
+        cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings_scalar")
+        return cursor.fetchone()[0]

-
-
+    def get_float32_embedding_count(self) -> int:
+        """Get count of float32 embeddings."""
+        conn = self._connect()
+        cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
+        return cursor.fetchone()[0]

-    def
-        self,
-        db_path: Optional[str | Path] = None,
-        embedding_dim: int = 768,  # Default for embeddinggemma-300m
-    ):
+    def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
         """
-
+        Yield batches of (org_id, name) tuples for records missing both float32 and scalar embeddings.

         Args:
-
-            embedding_dim: Dimension of embeddings to store
-        """
-        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
-        self._embedding_dim = embedding_dim
-        self._conn: Optional[sqlite3.Connection] = None
+            batch_size: Number of IDs per batch

-
-
-
+        Yields:
+            Lists of (org_id, name) tuples needing embeddings generated from scratch
+        """
+        conn = self._connect()

-
-
-        if self._conn is not None:
-            return self._conn
+        # Ensure scalar table exists
+        self.ensure_scalar_table_exists()

-
+        last_id = 0
+        while True:
+            cursor = conn.execute("""
+                SELECT o.id, o.name FROM organizations o
+                LEFT JOIN organization_embeddings e ON o.id = e.org_id
+                WHERE e.org_id IS NULL AND o.id > ?
+                ORDER BY o.id
+                LIMIT ?
+            """, (last_id, batch_size))

-
-
+            rows = cursor.fetchall()
+            if not rows:
+                break
+
+            results = [(row["id"], row["name"]) for row in rows]
+            yield results
+            last_id = results[-1][0]
+
+    def insert_both_embeddings_batch(
+        self,
+        org_ids: list[int],
+        fp32_embeddings: np.ndarray,
+        int8_embeddings: np.ndarray,
+    ) -> int:
+        """
+        Insert both float32 and int8 embeddings for existing orgs.
+
+        Args:
+            org_ids: List of organization IDs
+            fp32_embeddings: Matrix of float32 embeddings (N x dim)
+            int8_embeddings: Matrix of int8 embeddings (N x dim)
+
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0
+
+        for org_id, fp32, int8 in zip(org_ids, fp32_embeddings, int8_embeddings):
+            # Insert float32
+            fp32_blob = fp32.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM organization_embeddings WHERE org_id = ?", (org_id,))
+            conn.execute("""
+                INSERT INTO organization_embeddings (org_id, embedding)
+                VALUES (?, ?)
+            """, (org_id, fp32_blob))
+
+            # Insert int8
+            int8_blob = int8.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM organization_embeddings_scalar WHERE org_id = ?", (org_id,))
+            conn.execute("""
+                INSERT INTO organization_embeddings_scalar (org_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (org_id, int8_blob))
+
+            count += 1
+
+        conn.commit()
+        return count
+
+    def resolve_qid_labels(
+        self,
+        label_map: dict[str, str],
+        batch_size: int = 1000,
+    ) -> tuple[int, int]:
+        """
+        Update organization records that have QIDs instead of labels in region field.
+
+        If resolving would create a duplicate of an existing record with
+        resolved labels, the QID version is deleted instead.
+
+        Args:
+            label_map: Mapping of QID -> label for resolution
+            batch_size: Commit batch size
+
+        Returns:
+            Tuple of (records updated, duplicates deleted)
+        """
+        conn = self._connect()
+
+        # Find records with QIDs in region field (starts with 'Q' followed by digits)
+        region_updates = 0
+        cursor = conn.execute("""
+            SELECT id, region FROM organizations
+            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+        """)
+        rows = cursor.fetchall()
+
+        duplicates_deleted = 0
+        for row in rows:
+            org_id = row["id"]
+            qid = row["region"]
+            if qid in label_map:
+                resolved_region = label_map[qid]
+                # Check if this update would create a duplicate
+                # Get the name and source of the current record
+                org_cursor = conn.execute(
+                    "SELECT name, source FROM organizations WHERE id = ?",
+                    (org_id,)
+                )
+                org_row = org_cursor.fetchone()
+                if org_row is None:
+                    continue
+
+                org_name = org_row["name"]
+                org_source = org_row["source"]
+
+                # Check if a record with the resolved region already exists
+                existing_cursor = conn.execute(
+                    "SELECT id FROM organizations WHERE name = ? AND region = ? AND source = ? AND id != ?",
+                    (org_name, resolved_region, org_source, org_id)
+                )
+                existing = existing_cursor.fetchone()
+
+                if existing is not None:
+                    # Duplicate would be created - delete the QID-based record
+                    conn.execute("DELETE FROM organizations WHERE id = ?", (org_id,))
+                    duplicates_deleted += 1
+                else:
+                    # Safe to update
+                    conn.execute(
+                        "UPDATE organizations SET region = ? WHERE id = ?",
+                        (resolved_region, org_id)
+                    )
+                    region_updates += 1
+
+                if (region_updates + duplicates_deleted) % batch_size == 0:
+                    conn.commit()
+                    logger.info(f"Resolved QID labels: {region_updates} updates, {duplicates_deleted} deletes...")
+
+        conn.commit()
+        logger.info(f"Resolved QID labels: {region_updates} organization regions, {duplicates_deleted} duplicates deleted")
+        return region_updates, duplicates_deleted
+
+    def get_unresolved_qids(self) -> set[str]:
+        """
+        Get all QIDs that still need resolution in the organizations table.
+
+        Returns:
+            Set of QIDs (starting with 'Q') found in region field
+        """
+        conn = self._connect()
+        qids: set[str] = set()
+
+        cursor = conn.execute("""
+            SELECT DISTINCT region FROM organizations
+            WHERE region LIKE 'Q%' AND region GLOB 'Q[0-9]*'
+        """)
+        for row in cursor:
+            qids.add(row["region"])
+
+        return qids
+
+
+def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
+    """
+    Get a singleton PersonDatabase instance for the given path.
+
+    Args:
+        db_path: Path to database file
+        embedding_dim: Dimension of embeddings
+
+    Returns:
+        Shared PersonDatabase instance
+    """
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key not in _person_database_instances:
+        logger.debug(f"Creating new PersonDatabase instance for {path_key}")
+        _person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
+    return _person_database_instances[path_key]
+
+
+class PersonDatabase:
+    """
+    SQLite database with sqlite-vec for person vector search.
+
+    Uses hybrid text + vector search:
+    1. Text filtering with LIKE to reduce candidates
+    2. sqlite-vec for semantic similarity ranking
+
+    Stores people from sources like Wikidata with role/org context.
+    """
+
+    def __init__(
+        self,
+        db_path: Optional[str | Path] = None,
+        embedding_dim: int = 768,  # Default for embeddinggemma-300m
+    ):
+        """
+        Initialize the person database.
+
+        Args:
+            db_path: Path to database file (creates if not exists)
+            embedding_dim: Dimension of embeddings to store
+        """
+        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+        self._embedding_dim = embedding_dim
+        self._conn: Optional[sqlite3.Connection] = None
+        self._is_v2: Optional[bool] = None  # Detected on first connect
+
+    def _ensure_dir(self) -> None:
+        """Ensure database directory exists."""
+        self._db_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def _connect(self) -> sqlite3.Connection:
+        """Get or create database connection using shared connection pool."""
+        if self._conn is not None:
+            return self._conn
+
+        self._conn = _get_shared_connection(self._db_path, self._embedding_dim)
+
+        # Detect schema version BEFORE creating tables
+        # v2 has person_type_id (FK) instead of person_type (TEXT)
+        if self._is_v2 is None:
+            cursor = self._conn.execute("PRAGMA table_info(people)")
+            columns = {row["name"] for row in cursor}
+            self._is_v2 = "person_type_id" in columns
+            if self._is_v2:
+                logger.debug("Detected v2 schema for people")
+
+        # Create tables (idempotent) - only for v1 schema or fresh databases
+        # v2 databases already have their schema from migration
+        if not self._is_v2:
+            self._create_tables()

         return self._conn

+    @property
+    def _people_table(self) -> str:
+        """Return table/view name for people queries needing text fields."""
+        return "people_view" if self._is_v2 else "people"
+
     def _create_tables(self) -> None:
         """Create database tables including sqlite-vec virtual table."""
         conn = self._conn
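Both backfill iterators added above (`get_missing_scalar_embedding_ids` and `get_missing_all_embedding_ids`) page with a keyset cursor, `WHERE id > last_id ORDER BY id LIMIT n`, rather than OFFSET, so every batch is an index seek and an interrupted backfill resumes from the last id it saw. The pattern in isolation (hypothetical `items` table, stdlib only):

```python
import sqlite3
from typing import Iterator

def iter_id_batches(conn: sqlite3.Connection, batch_size: int = 1000) -> Iterator[list[int]]:
    # Keyset pagination: track the last id seen instead of using OFFSET, which
    # would rescan all previously skipped rows on every subsequent batch.
    last_id = 0
    while True:
        rows = conn.execute(
            "SELECT id FROM items WHERE id > ? ORDER BY id LIMIT ?",
            (last_id, batch_size),
        ).fetchall()
        if not rows:
            break
        ids = [row[0] for row in rows]
        yield ids
        last_id = ids[-1]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY)")
conn.executemany("INSERT INTO items (id) VALUES (?)", [(i,) for i in range(1, 2501)])
print([len(batch) for batch in iter_id_batches(conn)])  # [1000, 1000, 500]
```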
@@ -2051,7 +2542,7 @@ class PersonDatabase:
         except sqlite3.OperationalError:
             pass  # Column doesn't exist yet

-        # Create sqlite-vec virtual table for embeddings
+        # Create sqlite-vec virtual table for embeddings (float32)
         conn.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
                 person_id INTEGER PRIMARY KEY,
@@ -2059,6 +2550,15 @@ class PersonDatabase:
             )
         """)

+        # Create sqlite-vec virtual table for scalar embeddings (int8)
+        # Provides 75% storage reduction with ~92% recall at top-100
+        conn.execute(f"""
+            CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
+                person_id INTEGER PRIMARY KEY,
+                embedding int8[{self._embedding_dim}]
+            )
+        """)
+
         # Create QID labels lookup table for Wikidata QID -> label mappings
         conn.execute("""
             CREATE TABLE IF NOT EXISTS qid_labels (
@@ -2148,13 +2648,19 @@ class PersonDatabase:
         """Clear connection reference (shared connection remains open)."""
         self._conn = None

-    def insert(
+    def insert(
+        self,
+        record: PersonRecord,
+        embedding: np.ndarray,
+        scalar_embedding: Optional[np.ndarray] = None,
+    ) -> int:
         """
         Insert a person record with its embedding.

         Args:
             record: Person record to insert
-            embedding: Embedding vector for the person name
+            embedding: Embedding vector for the person name (float32)
+            scalar_embedding: Optional int8 scalar embedding for compact storage

         Returns:
             Row ID of inserted record
@@ -2191,7 +2697,7 @@ class PersonDatabase:
         row_id = cursor.lastrowid
         assert row_id is not None

-        # Insert embedding into vec table (
+        # Insert embedding into vec table (float32)
         embedding_blob = embedding.astype(np.float32).tobytes()
         conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (row_id,))
         conn.execute("""
@@ -2199,6 +2705,15 @@ class PersonDatabase:
             VALUES (?, ?)
         """, (row_id, embedding_blob))

+        # Insert scalar embedding if provided (int8)
+        if scalar_embedding is not None:
+            scalar_blob = scalar_embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings_scalar (person_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (row_id, scalar_blob))
+
         conn.commit()
         return row_id

@@ -2207,14 +2722,16 @@ class PersonDatabase:
         records: list[PersonRecord],
         embeddings: np.ndarray,
         batch_size: int = 1000,
+        scalar_embeddings: Optional[np.ndarray] = None,
     ) -> int:
         """
         Insert multiple person records with embeddings.

         Args:
             records: List of person records
-            embeddings: Matrix of embeddings (N x dim)
+            embeddings: Matrix of embeddings (N x dim) - float32
             batch_size: Commit batch size
+            scalar_embeddings: Optional matrix of int8 scalar embeddings (N x dim)

         Returns:
             Number of records inserted
@@ -2222,32 +2739,66 @@ class PersonDatabase:
         conn = self._connect()
         count = 0

-        for record, embedding in zip(records, embeddings):
+        for i, (record, embedding) in enumerate(zip(records, embeddings)):
             record_json = json.dumps(record.record)
             name_normalized = _normalize_person_name(record.name)

-
-
-
-
-
-
-
-            record.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self._is_v2:
+                # v2 schema: use FK IDs instead of TEXT columns
+                source_type_id = SOURCE_NAME_TO_ID.get(record.source, 4)
+                person_type_id = PEOPLE_TYPE_NAME_TO_ID.get(record.person_type.value, 15)  # 15 = unknown
+
+                # Resolve country to location_id if provided
+                country_id = None
+                if record.country:
+                    locations_db = get_locations_database(db_path=self._db_path)
+                    country_id = locations_db.resolve_region_text(record.country)
+
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO people
+                    (name, name_normalized, source_id, source_identifier, country_id, person_type_id,
+                     known_for_org, known_for_org_id, from_date, to_date,
+                     birth_date, death_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    source_type_id,
+                    record.source_id,
+                    country_id,
+                    person_type_id,
+                    record.known_for_org,
+                    record.known_for_org_id,  # Can be None
+                    record.from_date or "",
+                    record.to_date or "",
+                    record.birth_date or "",
+                    record.death_date or "",
+                    record_json,
+                ))
+            else:
+                # v1 schema: use TEXT columns
+                cursor = conn.execute("""
+                    INSERT OR REPLACE INTO people
+                    (name, name_normalized, source, source_id, country, person_type,
+                     known_for_role, known_for_org, known_for_org_id, from_date, to_date,
+                     birth_date, death_date, record)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """, (
+                    record.name,
+                    name_normalized,
+                    record.source,
+                    record.source_id,
+                    record.country,
+                    record.person_type.value,
+                    record.known_for_role,
+                    record.known_for_org,
+                    record.known_for_org_id,  # Can be None
+                    record.from_date or "",
+                    record.to_date or "",
+                    record.birth_date or "",
+                    record.death_date or "",
+                    record_json,
+                ))

             row_id = cursor.lastrowid
             assert row_id is not None
@@ -2260,6 +2811,15 @@ class PersonDatabase:
                 VALUES (?, ?)
             """, (row_id, embedding_blob))

+            # Insert scalar embedding if provided (int8)
+            if scalar_embeddings is not None:
+                scalar_blob = scalar_embeddings[i].astype(np.int8).tobytes()
+                conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (row_id,))
+                conn.execute("""
+                    INSERT INTO person_embeddings_scalar (person_id, embedding)
+                    VALUES (?, vec_int8(?))
+                """, (row_id, scalar_blob))
+
             count += 1

             if count % batch_size == 0:
@@ -2284,10 +2844,17 @@ class PersonDatabase:
         """
         conn = self._connect()

-
-
-
-
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                UPDATE people SET from_date = ?, to_date = ?
+                WHERE source_id = ? AND source_identifier = ?
+            """, (from_date or "", to_date or "", source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                UPDATE people SET from_date = ?, to_date = ?
+                WHERE source = ? AND source_id = ?
+            """, (from_date or "", to_date or "", source, source_id))

         conn.commit()
         return cursor.rowcount > 0
@@ -2322,10 +2889,17 @@ class PersonDatabase:
         conn = self._connect()

         # First get the person's internal ID
-
-
-            (
-
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            row = conn.execute(
+                "SELECT id FROM people WHERE source_id = ? AND source_identifier = ?",
+                (source_type_id, source_id)
+            ).fetchone()
+        else:
+            row = conn.execute(
+                "SELECT id FROM people WHERE source = ? AND source_id = ?",
+                (source, source_id)
+            ).fetchone()

         if not row:
             return False
@@ -2382,7 +2956,13 @@ class PersonDatabase:
         if query_norm == 0:
             return []
         query_normalized = query_embedding / query_norm
-
+
+        # Use int8 quantized query if scalar table is available (75% storage savings)
+        if self._has_scalar_table():
+            query_int8 = self._quantize_query(query_normalized)
+            query_blob = query_int8.tobytes()
+        else:
+            query_blob = query_normalized.astype(np.float32).tobytes()

         # Stage 1: Text-based pre-filtering (if query_text provided)
         candidate_ids: Optional[set[int]] = None
@@ -2453,13 +3033,26 @@ class PersonDatabase:
         cursor = conn.execute(query, params)
         return set(row["id"] for row in cursor)

+    def _quantize_query(self, embedding: np.ndarray) -> np.ndarray:
+        """Quantize query embedding to int8 for scalar search."""
+        return np.clip(np.round(embedding * 127), -127, 127).astype(np.int8)
+
+    def _has_scalar_table(self) -> bool:
+        """Check if scalar embedding table exists."""
+        conn = self._conn
+        assert conn is not None
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='person_embeddings_scalar'"
+        )
+        return cursor.fetchone() is not None
+
     def _vector_search_filtered(
         self,
         query_blob: bytes,
         candidate_ids: set[int],
         top_k: int,
     ) -> list[tuple[PersonRecord, float]]:
-        """Vector search within a filtered set of candidates."""
+        """Vector search within a filtered set of candidates using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None

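
_quantize_query above is symmetric scalar quantization: components of a unit-normalized embedding lie in [-1, 1], so scaling by 127, rounding, and clipping maps them onto the int8 range. A self-contained sketch of the round-trip, showing why cosine ranking survives the compression (a demonstration, not code from the package):

    import numpy as np

    rng = np.random.default_rng(0)
    v = rng.normal(size=384).astype(np.float32)
    v /= np.linalg.norm(v)                    # unit-normalize, as the search path does

    q = np.clip(np.round(v * 127), -127, 127).astype(np.int8)   # _quantize_query's formula
    v_back = q.astype(np.float32) / 127.0                       # dequantize

    cos = float(v @ (v_back / np.linalg.norm(v_back)))
    print(cos)   # typically > 0.999, so nearest-neighbor order barely moves
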
@@ -2469,15 +3062,27 @@ class PersonDatabase:
         # Build IN clause for candidate IDs
         placeholders = ",".join("?" * len(candidate_ids))

-        query = f"""
-            SELECT
-                e.person_id,
-                vec_distance_cosine(e.embedding, ?) as distance
-            FROM person_embeddings e
-            WHERE e.person_id IN ({placeholders})
-            ORDER BY distance
-            LIMIT ?
-        """
+        # Use scalar embedding table if available (75% storage reduction)
+        if self._has_scalar_table():
+            query = f"""
+                SELECT
+                    e.person_id,
+                    vec_distance_cosine(e.embedding, vec_int8(?)) as distance
+                FROM person_embeddings_scalar e
+                WHERE e.person_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """
+        else:
+            query = f"""
+                SELECT
+                    e.person_id,
+                    vec_distance_cosine(e.embedding, ?) as distance
+                FROM person_embeddings e
+                WHERE e.person_id IN ({placeholders})
+                ORDER BY distance
+                LIMIT ?
+            """

         cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])

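
The scalar branch above relies on two sqlite-vec pieces that also appear later in this diff: vec_int8() to tag a blob as an int8 vector, and vec_distance_cosine() for scoring. A minimal standalone sketch of the same pattern (table name and dimension are illustrative, not the package's):

    import sqlite3
    import numpy as np
    import sqlite_vec

    conn = sqlite3.connect(":memory:")
    conn.enable_load_extension(True)
    sqlite_vec.load(conn)

    # Same shape as person_embeddings_scalar, created further down in this diff
    conn.execute("CREATE VIRTUAL TABLE demo USING vec0(item_id INTEGER PRIMARY KEY, embedding int8[4])")

    stored = np.array([127, 0, 0, 0], dtype=np.int8)
    conn.execute("INSERT INTO demo (item_id, embedding) VALUES (?, vec_int8(?))", (1, stored.tobytes()))

    query = np.array([120, 20, 0, 0], dtype=np.int8)
    print(conn.execute(
        "SELECT item_id, vec_distance_cosine(embedding, vec_int8(?)) AS distance "
        "FROM demo ORDER BY distance LIMIT 1",
        (query.tobytes(),),
    ).fetchone())
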
@@ -2500,18 +3105,29 @@ class PersonDatabase:
         query_blob: bytes,
         top_k: int,
     ) -> list[tuple[PersonRecord, float]]:
-        """Full vector search without text pre-filtering."""
+        """Full vector search without text pre-filtering using scalar (int8) embeddings."""
         conn = self._conn
         assert conn is not None

-        query = """
-            SELECT
-                person_id,
-                vec_distance_cosine(embedding, ?) as distance
-            FROM person_embeddings
-            ORDER BY distance
-            LIMIT ?
-        """
+        # Use scalar embedding table if available (75% storage reduction)
+        if self._has_scalar_table():
+            query = """
+                SELECT
+                    person_id,
+                    vec_distance_cosine(embedding, vec_int8(?)) as distance
+                FROM person_embeddings_scalar
+                ORDER BY distance
+                LIMIT ?
+            """
+        else:
+            query = """
+                SELECT
+                    person_id,
+                    vec_distance_cosine(embedding, ?) as distance
+                FROM person_embeddings
+                ORDER BY distance
+                LIMIT ?
+            """
         cursor = conn.execute(query, (query_blob, top_k))

         results = []
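
Note for readers of both search paths: vec_distance_cosine returns a distance, so lower is better. Callers that want a similarity score typically invert it; the conversion below is the standard identity, not something shown in this diff:

    def to_similarity(cosine_distance: float) -> float:
        # cosine distance = 1 - cosine similarity, so
        # distance 0.0 -> similarity 1.0, distance 2.0 -> similarity -1.0
        return 1.0 - cosine_distance
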
@@ -2531,17 +3147,29 @@ class PersonDatabase:
         conn = self._conn
         assert conn is not None

-        cursor = conn.execute("""
-            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
-            FROM people WHERE id = ?
-        """, (person_id,))
+        if self._is_v2:
+            # v2 schema: join view with base table for record
+            cursor = conn.execute("""
+                SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                       v.known_for_role, v.known_for_org, v.known_for_org_id,
+                       v.birth_date, v.death_date, p.record
+                FROM people_view v
+                JOIN people p ON v.id = p.id
+                WHERE v.id = ?
+            """, (person_id,))
+        else:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                FROM people WHERE id = ?
+            """, (person_id,))

         row = cursor.fetchone()
         if row:
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return PersonRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row["source_id"],
+                source_id=row[source_id_field],
                 country=row["country"] or "",
                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                 known_for_role=row["known_for_role"] or "",
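
The v2 read path above goes through a people_view that re-exposes v1-style text columns over the normalized base table, so PersonRecord construction stays almost unchanged. The view's real definition lives in schema_v2.py and is not part of this diff; a plausible minimal shape, for orientation only:

    # Illustrative guess at the view; the shipped DDL is in schema_v2.py.
    CREATE_PEOPLE_VIEW = """
    CREATE VIEW IF NOT EXISTS people_view AS
    SELECT p.id, p.name,
           st.name AS source,
           p.source_identifier,
           p.country,
           pt.name AS person_type,
           p.known_for_role, p.known_for_org, p.known_for_org_id,
           p.birth_date, p.death_date
    FROM people p
    JOIN source_types st ON p.source_id = st.id
    JOIN people_types pt ON p.person_type_id = pt.id
    """
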
@@ -2557,18 +3185,30 @@ class PersonDatabase:
         """Get a person record by source and source_id."""
         conn = self._connect()

-        cursor = conn.execute("""
-            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
-            FROM people
-            WHERE source = ? AND source_id = ?
-        """, (source, source_id))
+        if self._is_v2:
+            source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+            cursor = conn.execute("""
+                SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                       v.known_for_role, v.known_for_org, v.known_for_org_id,
+                       v.birth_date, v.death_date, p.record
+                FROM people_view v
+                JOIN people p ON v.id = p.id
+                WHERE p.source_id = ? AND p.source_identifier = ?
+            """, (source_type_id, source_id))
+        else:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                FROM people
+                WHERE source = ? AND source_id = ?
+            """, (source, source_id))

         row = cursor.fetchone()
         if row:
+            source_id_field = "source_identifier" if self._is_v2 else "source_id"
             return PersonRecord(
                 name=row["name"],
                 source=row["source"],
-                source_id=row["source_id"],
+                source_id=row[source_id_field],
                 country=row["country"] or "",
                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
                 known_for_role=row["known_for_role"] or "",
@@ -2588,12 +3228,32 @@ class PersonDatabase:
         cursor = conn.execute("SELECT COUNT(*) FROM people")
         total = cursor.fetchone()[0]

-        # Count by person_type
-        cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
+        # Count by person_type - handle both v1 and v2 schema
+        if self._is_v2:
+            # v2 schema - join with people_types
+            cursor = conn.execute("""
+                SELECT pt.name as person_type, COUNT(*) as cnt
+                FROM people p
+                JOIN people_types pt ON p.person_type_id = pt.id
+                GROUP BY p.person_type_id
+            """)
+        else:
+            # v1 schema
+            cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
         by_type = {row["person_type"]: row["cnt"] for row in cursor}

-        # Count by source
-        cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
+        # Count by source - handle both v1 and v2 schema
+        if self._is_v2:
+            # v2 schema - join with source_types
+            cursor = conn.execute("""
+                SELECT st.name as source, COUNT(*) as cnt
+                FROM people p
+                JOIN source_types st ON p.source_id = st.id
+                GROUP BY p.source_id
+            """)
+        else:
+            # v1 schema
+            cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
         by_source = {row["source"]: row["cnt"] for row in cursor}

         return {
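
Because the v2 queries alias the joined enum names back to person_type and source, both schema branches feed the same dict comprehensions and callers see one shape. Illustratively (constructor defaults and all figures are invented):

    db = PersonDatabase()          # assumed default constructor
    stats = db.get_stats()
    stats["by_type"]     # e.g. {"politician": 210_004, "executive": 87_121}
    stats["by_source"]   # e.g. {"wikidata": 1_104_200, "manual": 131}
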
@@ -2602,60 +3262,293 @@ class PersonDatabase:
             "by_source": by_source,
         }

-    def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
-        """
-        Get all source_ids from the people table.
+    def ensure_scalar_table_exists(self) -> None:
+        """Create scalar embedding table if it doesn't exist."""
+        conn = self._connect()
+        conn.execute(f"""
+            CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings_scalar USING vec0(
+                person_id INTEGER PRIMARY KEY,
+                embedding int8[{self._embedding_dim}]
+            )
+        """)
+        conn.commit()
+        logger.info("Ensured person_embeddings_scalar table exists")

-        Useful for resume operations to skip already-imported records.
+    def get_missing_scalar_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[int]]:
+        """
+        Yield batches of person IDs that have float32 but missing scalar embeddings.

         Args:
-            source: Optional source filter (e.g., "wikidata")
+            batch_size: Number of IDs per batch

-        Returns:
-            Set of source_id strings (e.g., Q codes for Wikidata)
+        Yields:
+            Lists of person_ids needing scalar embeddings
         """
         conn = self._connect()

-        if source:
-            cursor = conn.execute(
-                "SELECT DISTINCT source_id FROM people WHERE source = ?",
-                (source,)
-            )
-        else:
-            cursor = conn.execute("SELECT DISTINCT source_id FROM people")
-
-        return {row[0] for row in cursor}
+        # Ensure scalar table exists before querying
+        self.ensure_scalar_table_exists()

-    def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
-        """Iterate over all person records, optionally filtered by source."""
+        last_id = 0
+        while True:
+            cursor = conn.execute("""
+                SELECT e.person_id FROM person_embeddings e
+                LEFT JOIN person_embeddings_scalar s ON e.person_id = s.person_id
+                WHERE s.person_id IS NULL AND e.person_id > ?
+                ORDER BY e.person_id
+                LIMIT ?
+            """, (last_id, batch_size))
+
+            rows = cursor.fetchall()
+            if not rows:
+                break
+
+            ids = [row["person_id"] for row in rows]
+            yield ids
+            last_id = ids[-1]
+
+    def get_embeddings_by_ids(self, person_ids: list[int]) -> dict[int, np.ndarray]:
+        """
+        Fetch float32 embeddings for given person IDs.
+
+        Args:
+            person_ids: List of person IDs
+
+        Returns:
+            Dict mapping person_id to float32 embedding array
+        """
         conn = self._connect()

-        if source:
+        if not person_ids:
+            return {}
+
+        placeholders = ",".join("?" * len(person_ids))
+        cursor = conn.execute(f"""
+            SELECT person_id, embedding FROM person_embeddings
+            WHERE person_id IN ({placeholders})
+        """, person_ids)
+
+        result = {}
+        for row in cursor:
+            embedding_blob = row["embedding"]
+            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
+            result[row["person_id"]] = embedding
+        return result
+
+    def insert_scalar_embeddings_batch(self, person_ids: list[int], embeddings: np.ndarray) -> int:
+        """
+        Insert scalar (int8) embeddings for existing people.
+
+        Args:
+            person_ids: List of person IDs
+            embeddings: Matrix of int8 embeddings (N x dim)
+
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0
+
+        for person_id, embedding in zip(person_ids, embeddings):
+            scalar_blob = embedding.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings_scalar (person_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (person_id, scalar_blob))
+            count += 1
+
+        conn.commit()
+        return count
+
+    def get_scalar_embedding_count(self) -> int:
+        """Get count of scalar embeddings."""
+        conn = self._connect()
+        if not self._has_scalar_table():
+            return 0
+        cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings_scalar")
+        return cursor.fetchone()[0]
+
+    def get_float32_embedding_count(self) -> int:
+        """Get count of float32 embeddings."""
+        conn = self._connect()
+        cursor = conn.execute("SELECT COUNT(*) FROM person_embeddings")
+        return cursor.fetchone()[0]
+
+    def get_missing_all_embedding_ids(self, batch_size: int = 1000) -> Iterator[list[tuple[int, str]]]:
+        """
+        Yield batches of (person_id, name) tuples for records missing both float32 and scalar embeddings.
+
+        Args:
+            batch_size: Number of IDs per batch
+
+        Yields:
+            Lists of (person_id, name) tuples needing embeddings generated from scratch
+        """
+        conn = self._connect()
+
+        # Ensure scalar table exists
+        self.ensure_scalar_table_exists()
+
+        last_id = 0
+        while True:
             cursor = conn.execute("""
-                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
-                FROM people
-                WHERE source = ?
-            """, (source,))
+                SELECT p.id, p.name FROM people p
+                LEFT JOIN person_embeddings e ON p.id = e.person_id
+                WHERE e.person_id IS NULL AND p.id > ?
+                ORDER BY p.id
+                LIMIT ?
+            """, (last_id, batch_size))
+
+            rows = cursor.fetchall()
+            if not rows:
+                break
+
+            results = [(row["id"], row["name"]) for row in rows]
+            yield results
+            last_id = results[-1][0]
+
+    def insert_both_embeddings_batch(
+        self,
+        person_ids: list[int],
+        fp32_embeddings: np.ndarray,
+        int8_embeddings: np.ndarray,
+    ) -> int:
+        """
+        Insert both float32 and int8 embeddings for existing people.
+
+        Args:
+            person_ids: List of person IDs
+            fp32_embeddings: Matrix of float32 embeddings (N x dim)
+            int8_embeddings: Matrix of int8 embeddings (N x dim)
+
+        Returns:
+            Number of embeddings inserted
+        """
+        conn = self._connect()
+        count = 0
+
+        for person_id, fp32, int8 in zip(person_ids, fp32_embeddings, int8_embeddings):
+            # Insert float32
+            fp32_blob = fp32.astype(np.float32).tobytes()
+            conn.execute("DELETE FROM person_embeddings WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings (person_id, embedding)
+                VALUES (?, ?)
+            """, (person_id, fp32_blob))
+
+            # Insert int8
+            int8_blob = int8.astype(np.int8).tobytes()
+            conn.execute("DELETE FROM person_embeddings_scalar WHERE person_id = ?", (person_id,))
+            conn.execute("""
+                INSERT INTO person_embeddings_scalar (person_id, embedding)
+                VALUES (?, vec_int8(?))
+            """, (person_id, int8_blob))
+
+            count += 1
+
+        conn.commit()
+        return count
+
+    def get_all_source_ids(self, source: Optional[str] = None) -> set[str]:
+        """
+        Get all source_ids from the people table.
+
+        Useful for resume operations to skip already-imported records.
+
+        Args:
+            source: Optional source filter (e.g., "wikidata")
+
+        Returns:
+            Set of source_id strings (e.g., Q codes for Wikidata)
+        """
+        conn = self._connect()
+
+        if self._is_v2:
+            id_col = "source_identifier"
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute(
+                    f"SELECT DISTINCT {id_col} FROM people WHERE source_id = ?",
+                    (source_type_id,)
+                )
+            else:
+                cursor = conn.execute(f"SELECT DISTINCT {id_col} FROM people")
         else:
-            cursor = conn.execute("""
-                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
-                FROM people
-            """)
+            if source:
+                cursor = conn.execute(
+                    "SELECT DISTINCT source_id FROM people WHERE source = ?",
+                    (source,)
+                )
+            else:
+                cursor = conn.execute("SELECT DISTINCT source_id FROM people")

-        for row in cursor:
-            yield PersonRecord(
-                name=row["name"],
-                source=row["source"],
-                source_id=row["source_id"],
-                country=row["country"] or "",
-                person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
-                known_for_role=row["known_for_role"] or "",
-                known_for_org=row["known_for_org"] or "",
-                known_for_org_id=row["known_for_org_id"],  # Can be None
-                birth_date=row["birth_date"] or "",
-                death_date=row["death_date"] or "",
-                record=json.loads(row["record"]),
-            )
+        return {row[0] for row in cursor}
+
+    def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
+        """Iterate over all person records, optionally filtered by source."""
+        conn = self._connect()
+
+        if self._is_v2:
+            if source:
+                source_type_id = SOURCE_NAME_TO_ID.get(source, 4)
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                           v.known_for_role, v.known_for_org, v.known_for_org_id,
+                           v.birth_date, v.death_date, p.record
+                    FROM people_view v
+                    JOIN people p ON v.id = p.id
+                    WHERE p.source_id = ?
+                """, (source_type_id,))
+            else:
+                cursor = conn.execute("""
+                    SELECT v.name, v.source, v.source_identifier, v.country, v.person_type,
+                           v.known_for_role, v.known_for_org, v.known_for_org_id,
+                           v.birth_date, v.death_date, p.record
+                    FROM people_view v
+                    JOIN people p ON v.id = p.id
+                """)
+            for row in cursor:
+                yield PersonRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_identifier"],
+                    country=row["country"] or "",
+                    person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                    known_for_role=row["known_for_role"] or "",
+                    known_for_org=row["known_for_org"] or "",
+                    known_for_org_id=row["known_for_org_id"],  # Can be None
+                    birth_date=row["birth_date"] or "",
+                    death_date=row["death_date"] or "",
+                    record=json.loads(row["record"]),
+                )
+        else:
+            if source:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                    FROM people
+                    WHERE source = ?
+                """, (source,))
+            else:
+                cursor = conn.execute("""
+                    SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, known_for_org_id, birth_date, death_date, record
+                    FROM people
+                """)
+
+            for row in cursor:
+                yield PersonRecord(
+                    name=row["name"],
+                    source=row["source"],
+                    source_id=row["source_id"],
+                    country=row["country"] or "",
+                    person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                    known_for_role=row["known_for_role"] or "",
+                    known_for_org=row["known_for_org"] or "",
+                    known_for_org_id=row["known_for_org_id"],  # Can be None
+                    birth_date=row["birth_date"] or "",
+                    death_date=row["death_date"] or "",
+                    record=json.loads(row["record"]),
+                )

     def resolve_qid_labels(
         self,
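
The helpers added in this hunk compose into a resumable backfill: the iterators use keyset pagination (WHERE id > last_id ORDER BY id LIMIT n), which stays fast where OFFSET would re-scan skipped rows, and the batch inserts are idempotent thanks to the DELETE-then-INSERT. A sketch of the driver loop they appear designed for (the loop itself is not in the diff; it assumes stored vectors are unit-normalized, as the query path implies):

    import numpy as np

    def backfill_scalar_embeddings(db) -> int:
        # db: a PersonDatabase; only methods introduced above are used
        db.ensure_scalar_table_exists()
        total = 0
        for id_batch in db.get_missing_scalar_embedding_ids(batch_size=1000):
            fp32 = db.get_embeddings_by_ids(id_batch)      # {person_id: float32 vector}
            ids = list(fp32)
            mat = np.stack([fp32[i] for i in ids])
            int8 = np.clip(np.round(mat * 127), -127, 127).astype(np.int8)  # _quantize_query's formula
            total += db.insert_scalar_embeddings_batch(ids, int8)
        return total
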
@@ -2680,6 +3573,11 @@ class PersonDatabase:
         """
         conn = self._connect()

+        # v2 schema stores QIDs as integers, not text - this method doesn't apply
+        if self._is_v2:
+            logger.debug("Skipping resolve_qid_labels for v2 schema (QIDs stored as integers)")
+            return 0, 0
+
         # Find all records with QIDs in any field (role or org - these are in unique constraint)
         # Country is not part of unique constraint so can be updated directly
         cursor = conn.execute("""
@@ -2752,6 +3650,11 @@ class PersonDatabase:
         Set of QIDs (starting with 'Q') found in country, role, or org fields
         """
         conn = self._connect()
+
+        # v2 schema stores QIDs as integers, not text - this method doesn't apply
+        if self._is_v2:
+            return set()
+
         qids: set[str] = set()

         # Get QIDs from country field
@@ -2797,11 +3700,27 @@ class PersonDatabase:
         """
         conn = self._connect()
         count = 0
+        skipped = 0

         for qid, label in label_map.items():
+            # Skip non-Q IDs (e.g., property IDs like P19)
+            if not qid.startswith("Q"):
+                skipped += 1
+                continue
+
+            # v2 schema stores QID as integer without Q prefix
+            if self._is_v2:
+                try:
+                    qid_val: str | int = int(qid[1:])
+                except ValueError:
+                    skipped += 1
+                    continue
+            else:
+                qid_val = qid
+
             conn.execute(
                 "INSERT OR REPLACE INTO qid_labels (qid, label) VALUES (?, ?)",
-                (qid, label)
+                (qid_val, label)
             )
             count += 1

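
These two hunks fix the QID convention at the storage boundary: v2 keeps Wikidata QIDs as bare integers, so "Q42" becomes 42 on the way in and presumably regains its prefix on the way out. The conversions, extracted as a sketch (the outbound direction is assumed, not shown in the diff):

    def qid_to_int(qid: str) -> int:
        # "Q42" -> 42; malformed input raises ValueError, which the
        # importer above counts as skipped
        return int(qid[1:]) if qid.startswith("Q") else int(qid)

    def int_to_qid(qid_val: int) -> str:
        # 42 -> "Q42"; assumed reverse direction for display
        return f"Q{qid_val}"
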
@@ -2824,9 +3743,16 @@ class PersonDatabase:
         Label string or None if not found
         """
         conn = self._connect()
+
+        # v2 schema stores QID as integer without Q prefix
+        if self._is_v2:
+            qid_val: str | int = int(qid[1:]) if qid.startswith("Q") else int(qid)
+        else:
+            qid_val = qid
+
         cursor = conn.execute(
             "SELECT label FROM qid_labels WHERE qid = ?",
-            (qid,)
+            (qid_val,)
         )
         row = cursor.fetchone()
         return row["label"] if row else None
@@ -2879,11 +3805,19 @@ class PersonDatabase:
         logger.info("Phase 1: Building person index...")

         # Load all people with their normalized names and org info
-        cursor = conn.execute("""
-            SELECT id, name, name_normalized, source, source_id,
-                   known_for_org, known_for_org_id, from_date, to_date
-            FROM people
-        """)
+        if self._is_v2:
+            cursor = conn.execute("""
+                SELECT p.id, p.name, p.name_normalized, s.name as source, p.source_identifier as source_id,
+                       p.known_for_org, p.known_for_org_id, p.from_date, p.to_date
+                FROM people p
+                JOIN source_types s ON p.source_id = s.id
+            """)
+        else:
+            cursor = conn.execute("""
+                SELECT id, name, name_normalized, source, source_id,
+                       known_for_org, known_for_org_id, from_date, to_date
+                FROM people
+            """)

         people: list[dict] = []
         for row in cursor:
@@ -3032,3 +3966,860 @@ class PersonDatabase:
             "UPDATE people SET canon_id = ?, canon_size = ? WHERE id = ?",
             (canon_id, canon_size, person_id)
         )
+
+
+# =============================================================================
+# Module-level singletons for new v2 databases
+# =============================================================================
+
+_roles_database_instances: dict[str, "RolesDatabase"] = {}
+_locations_database_instances: dict[str, "LocationsDatabase"] = {}
+
+
+def get_roles_database(db_path: Optional[str | Path] = None) -> "RolesDatabase":
+    """
+    Get a singleton RolesDatabase instance for the given path.
+
+    Args:
+        db_path: Path to database file
+
+    Returns:
+        Shared RolesDatabase instance
+    """
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key not in _roles_database_instances:
+        logger.debug(f"Creating new RolesDatabase instance for {path_key}")
+        _roles_database_instances[path_key] = RolesDatabase(db_path=db_path)
+    return _roles_database_instances[path_key]
+
+
+def get_locations_database(db_path: Optional[str | Path] = None) -> "LocationsDatabase":
+    """
+    Get a singleton LocationsDatabase instance for the given path.
+
+    Args:
+        db_path: Path to database file
+
+    Returns:
+        Shared LocationsDatabase instance
+    """
+    path_key = str(db_path or DEFAULT_DB_PATH)
+    if path_key not in _locations_database_instances:
+        logger.debug(f"Creating new LocationsDatabase instance for {path_key}")
+        _locations_database_instances[path_key] = LocationsDatabase(db_path=db_path)
+    return _locations_database_instances[path_key]
+
+
+# =============================================================================
+# ROLES DATABASE (v2)
+# =============================================================================
+
+
+class RolesDatabase:
+    """
+    SQLite database for job titles/roles.
+
+    Stores normalized role records with source tracking and supports
+    canonicalization to group equivalent roles (e.g., CEO, Chief Executive).
+    """
+
+    def __init__(self, db_path: Optional[str | Path] = None):
+        """
+        Initialize the roles database.
+
+        Args:
+            db_path: Path to database file (creates if not exists)
+        """
+        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+        self._conn: Optional[sqlite3.Connection] = None
+        self._role_cache: dict[str, int] = {}  # name_normalized -> role_id
+
+    def _connect(self) -> sqlite3.Connection:
+        """Get or create database connection using shared connection pool."""
+        if self._conn is not None:
+            return self._conn
+
+        self._conn = _get_shared_connection(self._db_path)
+        self._create_tables()
+        return self._conn
+
+    def _create_tables(self) -> None:
+        """Create roles table and indexes."""
+        conn = self._conn
+        assert conn is not None
+
+        # Check if enum tables exist, create and seed if not
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
+        )
+        if not cursor.fetchone():
+            logger.info("Creating enum tables for v2 schema...")
+            from .schema_v2 import (
+                CREATE_SOURCE_TYPES,
+                CREATE_PEOPLE_TYPES,
+                CREATE_ORGANIZATION_TYPES,
+                CREATE_SIMPLIFIED_LOCATION_TYPES,
+                CREATE_LOCATION_TYPES,
+            )
+            conn.execute(CREATE_SOURCE_TYPES)
+            conn.execute(CREATE_PEOPLE_TYPES)
+            conn.execute(CREATE_ORGANIZATION_TYPES)
+            conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
+            conn.execute(CREATE_LOCATION_TYPES)
+            seed_all_enums(conn)
+
+        # Create roles table
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS roles (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                qid INTEGER,
+                name TEXT NOT NULL,
+                name_normalized TEXT NOT NULL,
+                source_id INTEGER NOT NULL DEFAULT 4,
+                source_identifier TEXT,
+                record TEXT NOT NULL DEFAULT '{}',
+                canon_id INTEGER DEFAULT NULL,
+                canon_size INTEGER DEFAULT 1,
+                UNIQUE(name_normalized, source_id)
+            )
+        """)
+
+        # Create indexes
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name ON roles(name)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_name_normalized ON roles(name_normalized)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_qid ON roles(qid)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_source_id ON roles(source_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_roles_canon_id ON roles(canon_id)")
+
+        conn.commit()
+
+    def close(self) -> None:
+        """Clear connection reference."""
+        self._conn = None
+
+    def get_or_create(
+        self,
+        name: str,
+        source_id: int = 4,  # wikidata
+        qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+    ) -> int:
+        """
+        Get or create a role record.
+
+        Args:
+            name: Role/title name
+            source_id: FK to source_types table
+            qid: Optional Wikidata QID as integer
+            source_identifier: Optional source-specific identifier
+
+        Returns:
+            Role ID
+        """
+        if not name:
+            raise ValueError("Role name cannot be empty")
+
+        conn = self._connect()
+        name_normalized = name.lower().strip()
+
+        # Check cache
+        cache_key = f"{name_normalized}:{source_id}"
+        if cache_key in self._role_cache:
+            return self._role_cache[cache_key]
+
+        # Check database
+        cursor = conn.execute(
+            "SELECT id FROM roles WHERE name_normalized = ? AND source_id = ?",
+            (name_normalized, source_id)
+        )
+        row = cursor.fetchone()
+        if row:
+            role_id = row["id"]
+            self._role_cache[cache_key] = role_id
+            return role_id
+
+        # Create new role
+        cursor = conn.execute(
+            """
+            INSERT INTO roles (name, name_normalized, source_id, qid, source_identifier)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (name, name_normalized, source_id, qid, source_identifier)
+        )
+        role_id = cursor.lastrowid
+        assert role_id is not None
+        conn.commit()
+
+        self._role_cache[cache_key] = role_id
+        return role_id
+
+    def get_by_id(self, role_id: int) -> Optional[RoleRecord]:
+        """Get a role record by ID."""
+        conn = self._connect()
+
+        cursor = conn.execute(
+            "SELECT id, qid, name, source_id, source_identifier, record FROM roles WHERE id = ?",
+            (role_id,)
+        )
+        row = cursor.fetchone()
+        if row:
+            source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
+            return RoleRecord(
+                name=row["name"],
+                source=source_name,
+                source_id=row["source_identifier"],
+                qid=row["qid"],
+                record=json.loads(row["record"]) if row["record"] else {},
+            )
+        return None
+
+    def search(
+        self,
+        query: str,
+        top_k: int = 10,
+    ) -> list[tuple[int, str, float]]:
+        """
+        Search for roles by name.
+
+        Args:
+            query: Search query
+            top_k: Maximum results to return
+
+        Returns:
+            List of (role_id, role_name, score) tuples
+        """
+        conn = self._connect()
+        query_normalized = query.lower().strip()
+
+        # Exact match first
+        cursor = conn.execute(
+            "SELECT id, name FROM roles WHERE name_normalized = ? LIMIT 1",
+            (query_normalized,)
+        )
+        row = cursor.fetchone()
+        if row:
+            return [(row["id"], row["name"], 1.0)]
+
+        # LIKE match
+        cursor = conn.execute(
+            """
+            SELECT id, name FROM roles
+            WHERE name_normalized LIKE ?
+            ORDER BY length(name)
+            LIMIT ?
+            """,
+            (f"%{query_normalized}%", top_k)
+        )
+
+        results = []
+        for row in cursor:
+            # Simple score based on match quality
+            name_normalized = row["name"].lower()
+            if query_normalized == name_normalized:
+                score = 1.0
+            elif name_normalized.startswith(query_normalized):
+                score = 0.9
+            else:
+                score = 0.7
+            results.append((row["id"], row["name"], score))
+
+        return results
+
+    def get_stats(self) -> dict[str, int]:
+        """Get statistics about the roles table."""
+        conn = self._connect()
+
+        cursor = conn.execute("SELECT COUNT(*) FROM roles")
+        total = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(*) FROM roles WHERE canon_id IS NOT NULL")
+        canonicalized = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM roles WHERE canon_id IS NOT NULL")
+        groups = cursor.fetchone()[0]
+
+        return {
+            "total_roles": total,
+            "canonicalized": canonicalized,
+            "canonical_groups": groups,
+        }
+
+
+# =============================================================================
+# LOCATIONS DATABASE (v2)
+# =============================================================================
+
+
+class LocationsDatabase:
+    """
+    SQLite database for geopolitical locations.
+
+    Stores countries, states, cities with hierarchical relationships
+    and type classification. Supports pycountry integration.
+    """
+
+    def __init__(self, db_path: Optional[str | Path] = None):
+        """
+        Initialize the locations database.
+
+        Args:
+            db_path: Path to database file (creates if not exists)
+        """
+        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+        self._conn: Optional[sqlite3.Connection] = None
+        self._location_cache: dict[str, int] = {}  # lookup_key -> location_id
+        self._location_type_cache: dict[str, int] = {}  # type_name -> type_id
+        self._location_type_qid_cache: dict[int, int] = {}  # qid -> type_id
+
+    def _connect(self) -> sqlite3.Connection:
+        """Get or create database connection using shared connection pool."""
+        if self._conn is not None:
+            return self._conn
+
+        self._conn = _get_shared_connection(self._db_path)
+        self._create_tables()
+        self._build_caches()
+        return self._conn
+
+    def _create_tables(self) -> None:
+        """Create locations table and indexes."""
+        conn = self._conn
+        assert conn is not None
+
+        # Check if enum tables exist, create and seed if not
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='source_types'"
+        )
+        if not cursor.fetchone():
+            logger.info("Creating enum tables for v2 schema...")
+            from .schema_v2 import (
+                CREATE_SOURCE_TYPES,
+                CREATE_PEOPLE_TYPES,
+                CREATE_ORGANIZATION_TYPES,
+                CREATE_SIMPLIFIED_LOCATION_TYPES,
+                CREATE_LOCATION_TYPES,
+            )
+            conn.execute(CREATE_SOURCE_TYPES)
+            conn.execute(CREATE_PEOPLE_TYPES)
+            conn.execute(CREATE_ORGANIZATION_TYPES)
+            conn.execute(CREATE_SIMPLIFIED_LOCATION_TYPES)
+            conn.execute(CREATE_LOCATION_TYPES)
+            seed_all_enums(conn)
+
+        # Create locations table
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS locations (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                qid INTEGER,
+                name TEXT NOT NULL,
+                name_normalized TEXT NOT NULL,
+                source_id INTEGER NOT NULL DEFAULT 4,
+                source_identifier TEXT,
+                parent_ids TEXT,
+                location_type_id INTEGER NOT NULL DEFAULT 2,
+                record TEXT NOT NULL DEFAULT '{}',
+                from_date TEXT DEFAULT NULL,
+                to_date TEXT DEFAULT NULL,
+                canon_id INTEGER DEFAULT NULL,
+                canon_size INTEGER DEFAULT 1,
+                UNIQUE(source_identifier, source_id)
+            )
+        """)
+
+        # Create indexes
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name ON locations(name)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_name_normalized ON locations(name_normalized)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_qid ON locations(qid)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_source_id ON locations(source_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_location_type_id ON locations(location_type_id)")
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_locations_canon_id ON locations(canon_id)")
+
+        conn.commit()
+
+    def _build_caches(self) -> None:
+        """Build lookup caches from database and seed data."""
+        # Load location type caches from seed data
+        self._location_type_cache = dict(LOCATION_TYPE_NAME_TO_ID)
+        self._location_type_qid_cache = dict(LOCATION_TYPE_QID_TO_ID)
+
+        # Load existing locations into cache
+        conn = self._conn
+        if conn:
+            cursor = conn.execute(
+                "SELECT id, name_normalized, source_identifier FROM locations"
+            )
+            for row in cursor:
+                # Cache by normalized name
+                self._location_cache[row["name_normalized"]] = row["id"]
+                # Also cache by source_identifier
+                if row["source_identifier"]:
+                    self._location_cache[row["source_identifier"].lower()] = row["id"]
+
+    def close(self) -> None:
+        """Clear connection reference."""
+        self._conn = None
+
+    def get_or_create(
+        self,
+        name: str,
+        location_type_id: int,
+        source_id: int = 4,  # wikidata
+        qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+        parent_ids: Optional[list[int]] = None,
+    ) -> int:
+        """
+        Get or create a location record.
+
+        Args:
+            name: Location name
+            location_type_id: FK to location_types table
+            source_id: FK to source_types table
+            qid: Optional Wikidata QID as integer
+            source_identifier: Optional source-specific identifier (e.g., "US", "CA")
+            parent_ids: Optional list of parent location IDs
+
+        Returns:
+            Location ID
+        """
+        if not name:
+            raise ValueError("Location name cannot be empty")
+
+        conn = self._connect()
+        name_normalized = name.lower().strip()
+
+        # Check cache by source_identifier first (more specific)
+        if source_identifier:
+            cache_key = source_identifier.lower()
+            if cache_key in self._location_cache:
+                return self._location_cache[cache_key]
+
+        # Check cache by normalized name
+        if name_normalized in self._location_cache:
+            return self._location_cache[name_normalized]
+
+        # Check database
+        if source_identifier:
+            cursor = conn.execute(
+                "SELECT id FROM locations WHERE source_identifier = ? AND source_id = ?",
+                (source_identifier, source_id)
+            )
+        else:
+            cursor = conn.execute(
+                "SELECT id FROM locations WHERE name_normalized = ? AND source_id = ?",
+                (name_normalized, source_id)
+            )
+
+        row = cursor.fetchone()
+        if row:
+            location_id = row["id"]
+            self._location_cache[name_normalized] = location_id
+            if source_identifier:
+                self._location_cache[source_identifier.lower()] = location_id
+            return location_id
+
+        # Create new location
+        parent_ids_json = json.dumps(parent_ids) if parent_ids else None
+        cursor = conn.execute(
+            """
+            INSERT INTO locations
+            (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
+            """,
+            (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids_json)
+        )
+        location_id = cursor.lastrowid
+        assert location_id is not None
+        conn.commit()
+
+        self._location_cache[name_normalized] = location_id
+        if source_identifier:
+            self._location_cache[source_identifier.lower()] = location_id
+        return location_id
+
+    def get_or_create_by_qid(
+        self,
+        name: str,
+        wikidata_type_qid: int,
+        source_id: int = 4,
+        entity_qid: Optional[int] = None,
+        source_identifier: Optional[str] = None,
+        parent_ids: Optional[list[int]] = None,
+    ) -> int:
+        """
+        Get or create a location using Wikidata P31 type QID.
+
+        Args:
+            name: Location name
+            wikidata_type_qid: Wikidata instance-of QID (e.g., 515 for city)
+            source_id: FK to source_types table
+            entity_qid: Wikidata QID of the entity itself
+            source_identifier: Optional source-specific identifier
+            parent_ids: Optional list of parent location IDs
+
+        Returns:
+            Location ID
+        """
+        location_type_id = self.get_location_type_id_from_qid(wikidata_type_qid)
+        return self.get_or_create(
+            name=name,
+            location_type_id=location_type_id,
+            source_id=source_id,
+            qid=entity_qid,
+            source_identifier=source_identifier,
+            parent_ids=parent_ids,
+        )
+
+    def get_by_id(self, location_id: int) -> Optional[LocationRecord]:
+        """Get a location record by ID."""
+        conn = self._connect()
+
+        cursor = conn.execute(
+            """
+            SELECT id, qid, name, source_id, source_identifier, location_type_id,
+                   parent_ids, from_date, to_date, record
+            FROM locations WHERE id = ?
+            """,
+            (location_id,)
+        )
+        row = cursor.fetchone()
+        if row:
+            source_name = SOURCE_ID_TO_NAME.get(row["source_id"], "wikidata")
+            location_type_id = row["location_type_id"]
+            location_type_name = self._get_location_type_name(location_type_id)
+            simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
+            simplified_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
+
+            parent_ids = json.loads(row["parent_ids"]) if row["parent_ids"] else []
+
+            return LocationRecord(
+                name=row["name"],
+                source=source_name,
+                source_id=row["source_identifier"],
+                qid=row["qid"],
+                location_type=location_type_name,
+                simplified_type=SimplifiedLocationType(simplified_name),
+                parent_ids=parent_ids,
+                from_date=row["from_date"],
+                to_date=row["to_date"],
+                record=json.loads(row["record"]) if row["record"] else {},
+            )
+        return None
+
+    def _get_location_type_name(self, type_id: int) -> str:
+        """Get location type name from ID."""
+        # Reverse lookup in cache
+        for name, id_ in self._location_type_cache.items():
+            if id_ == type_id:
+                return name
+        return "other"
+
+    def get_location_type_id(self, type_name: str) -> int:
+        """Get location_type_id for a type name."""
+        return self._location_type_cache.get(type_name, 36)  # default to "other"
+
+    def get_location_type_id_from_qid(self, wikidata_qid: int) -> int:
+        """Get location_type_id from Wikidata P31 QID."""
+        return self._location_type_qid_cache.get(wikidata_qid, 36)  # default to "other"
+
+    def get_simplified_type(self, location_type_id: int) -> str:
+        """Get simplified type name for a location_type_id."""
+        simplified_id = LOCATION_TYPE_TO_SIMPLIFIED.get(location_type_id, 7)
+        return SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(simplified_id, "other")
+
+    def resolve_region_text(self, text: str) -> Optional[int]:
+        """
+        Resolve a region/country text to a location ID.
+
+        Uses pycountry for country resolution, then falls back to search.
+
+        Args:
+            text: Region text (country code, name, or QID)
+
+        Returns:
+            Location ID or None if not resolved
+        """
+        if not text:
+            return None
+
+        text_lower = text.lower().strip()
+
+        # Check cache first
+        if text_lower in self._location_cache:
+            return self._location_cache[text_lower]
+
+        # Try pycountry resolution
+        alpha_2 = self._resolve_via_pycountry(text)
+        if alpha_2:
+            alpha_2_lower = alpha_2.lower()
+            if alpha_2_lower in self._location_cache:
+                location_id = self._location_cache[alpha_2_lower]
+                self._location_cache[text_lower] = location_id  # Cache the input too
+                return location_id
+
+            # Country not in database yet, import it
+            try:
+                country = pycountry.countries.get(alpha_2=alpha_2)
+                if country:
+                    country_type_id = self._location_type_cache.get("country", 2)
+                    location_id = self.get_or_create(
+                        name=country.name,
+                        location_type_id=country_type_id,
+                        source_id=4,  # wikidata
+                        source_identifier=alpha_2,
+                    )
+                    self._location_cache[text_lower] = location_id
+                    return location_id
+            except Exception:
+                pass
+
+        return None
+
+    def _resolve_via_pycountry(self, region: str) -> Optional[str]:
+        """Try to resolve region via pycountry."""
+        region_clean = region.strip()
+        if not region_clean:
+            return None
+
+        # Try as 2-letter code
+        if len(region_clean) == 2:
+            country = pycountry.countries.get(alpha_2=region_clean.upper())
+            if country:
+                return country.alpha_2
+
+        # Try as 3-letter code
+        if len(region_clean) == 3:
+            country = pycountry.countries.get(alpha_3=region_clean.upper())
+            if country:
+                return country.alpha_2
+
+        # Try fuzzy search
+        try:
+            matches = pycountry.countries.search_fuzzy(region_clean)
+            if matches:
+                return matches[0].alpha_2
+        except LookupError:
+            pass
+
+        return None
+
+    def import_from_pycountry(self) -> int:
+        """
+        Import all countries from pycountry.
+
+        Returns:
+            Number of locations imported
+        """
+        conn = self._connect()
+        country_type_id = self._location_type_cache.get("country", 2)
+        count = 0
+
+        for country in pycountry.countries:
+            name = country.name
+            alpha_2 = country.alpha_2
+            name_normalized = name.lower()
+
+            # Check if already exists
+            if alpha_2.lower() in self._location_cache:
+                continue
+
+            cursor = conn.execute(
+                """
+                INSERT OR IGNORE INTO locations
+                (name, name_normalized, source_id, source_identifier, location_type_id)
+                VALUES (?, ?, 4, ?, ?)
+                """,
+                (name, name_normalized, alpha_2, country_type_id)
+            )
+
+            if cursor.lastrowid:
+                self._location_cache[name_normalized] = cursor.lastrowid
+                self._location_cache[alpha_2.lower()] = cursor.lastrowid
+                count += 1
+
+        conn.commit()
+        logger.info(f"Imported {count} countries from pycountry")
+        return count
+
+    def search(
+        self,
+        query: str,
+        top_k: int = 10,
+        simplified_type: Optional[str] = None,
+    ) -> list[tuple[int, str, float]]:
+        """
+        Search for locations by name.
+
+        Args:
+            query: Search query
+            top_k: Maximum results to return
+            simplified_type: Optional filter by simplified type (e.g., "country", "city")
+
+        Returns:
+            List of (location_id, location_name, score) tuples
+        """
+        conn = self._connect()
+        query_normalized = query.lower().strip()
+
+        # Build query with optional type filter
+        if simplified_type:
+            # Get all location_type_ids for this simplified type
+            simplified_id = {
+                name: id_ for id_, name in SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.items()
+            }.get(simplified_type)
+            if simplified_id:
+                type_ids = [
+                    type_id for type_id, simp_id in LOCATION_TYPE_TO_SIMPLIFIED.items()
+                    if simp_id == simplified_id
+                ]
+                if type_ids:
+                    placeholders = ",".join("?" * len(type_ids))
+                    cursor = conn.execute(
+                        f"""
+                        SELECT id, name FROM locations
+                        WHERE name_normalized LIKE ? AND location_type_id IN ({placeholders})
+                        ORDER BY length(name)
+                        LIMIT ?
+                        """,
+                        [f"%{query_normalized}%"] + type_ids + [top_k]
+                    )
+                else:
+                    return []
+            else:
+                return []
+        else:
+            cursor = conn.execute(
+                """
+                SELECT id, name FROM locations
+                WHERE name_normalized LIKE ?
+                ORDER BY length(name)
+                LIMIT ?
+                """,
+                (f"%{query_normalized}%", top_k)
+            )
+
+        results = []
+        for row in cursor:
+            name_normalized = row["name"].lower()
+            if query_normalized == name_normalized:
+                score = 1.0
+            elif name_normalized.startswith(query_normalized):
+                score = 0.9
+            else:
+                score = 0.7
+            results.append((row["id"], row["name"], score))
+
+        return results
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get statistics about the locations table."""
+        conn = self._connect()
+
+        cursor = conn.execute("SELECT COUNT(*) FROM locations")
+        total = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(*) FROM locations WHERE canon_id IS NOT NULL")
+        canonicalized = cursor.fetchone()[0]
+
+        cursor = conn.execute("SELECT COUNT(DISTINCT canon_id) FROM locations WHERE canon_id IS NOT NULL")
+        groups = cursor.fetchone()[0]
+
+        # Count by simplified type
+        by_type: dict[str, int] = {}
+        cursor = conn.execute("""
+            SELECT lt.simplified_id, COUNT(*) as cnt
+            FROM locations l
+            JOIN location_types lt ON l.location_type_id = lt.id
+            GROUP BY lt.simplified_id
+        """)
+        for row in cursor:
+            type_name = SIMPLIFIED_LOCATION_TYPE_ID_TO_NAME.get(row["simplified_id"], "other")
+            by_type[type_name] = row["cnt"]
+
+        return {
+            "total_locations": total,
+            "canonicalized": canonicalized,
+            "canonical_groups": groups,
+            "by_type": by_type,
+        }
+
+    def insert_batch(self, records: list[LocationRecord]) -> int:
+        """
+        Insert a batch of location records.
+
+        Args:
+            records: List of LocationRecord objects to insert
+
+        Returns:
+            Number of records inserted
+        """
+        if not records:
+            return 0
+
+        conn = self._connect()
+        inserted = 0
+
+        for record in records:
+            name_normalized = record.name.lower().strip()
+            source_identifier = record.source_id  # Q code in source_id field
+
+            # Check cache first
+            cache_key = source_identifier.lower() if source_identifier else name_normalized
+            if cache_key in self._location_cache:
+                continue
+
+            # Get location_type_id from type name
+            location_type_id = self._location_type_cache.get(record.location_type, 36)  # default "other"
+            source_id = SOURCE_NAME_TO_ID.get(record.source, 4)  # default wikidata
+
+            parent_ids_json = json.dumps(record.parent_ids) if record.parent_ids else None
+            record_json = json.dumps(record.record) if record.record else "{}"
+
+            try:
+                cursor = conn.execute(
+                    """
+                    INSERT OR IGNORE INTO locations
+                    (name, name_normalized, source_id, source_identifier, qid, location_type_id, parent_ids, record, from_date, to_date)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                    """,
+                    (
+                        record.name,
+                        name_normalized,
+                        source_id,
+                        source_identifier,
+                        record.qid,
+                        location_type_id,
+                        parent_ids_json,
+                        record_json,
+                        record.from_date,
+                        record.to_date,
+                    )
+                )
+                if cursor.lastrowid:
+                    self._location_cache[name_normalized] = cursor.lastrowid
+                    if source_identifier:
+                        self._location_cache[source_identifier.lower()] = cursor.lastrowid
+                    inserted += 1
+            except Exception as e:
+                logger.warning(f"Failed to insert location {record.name}: {e}")
+
+        conn.commit()
+        return inserted
+
+    def get_all_source_ids(self, source: str = "wikidata") -> set[str]:
+        """
+        Get all source_identifiers for a given source.
+
+        Args:
+            source: Source name (e.g., "wikidata")
+
+        Returns:
+            Set of source_identifiers
+        """
+        conn = self._connect()
+        source_id = SOURCE_NAME_TO_ID.get(source, 4)
+        cursor = conn.execute(
+            "SELECT source_identifier FROM locations WHERE source_id = ? AND source_identifier IS NOT NULL",
+            (source_id,)
+        )
+        return {row["source_identifier"] for row in cursor}
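
Taken together, the new getters and classes give importers one shared, cached handle per database path. A closing usage sketch built only from calls defined above (argument values are illustrative):

    from statement_extractor.database.store import (
        get_locations_database,
        get_roles_database,
    )

    roles = get_roles_database()     # singleton per resolved db path
    role_id = roles.get_or_create("Chief Executive Officer", source_id=4)
    roles.search("chief exec", top_k=5)          # [(id, name, score), ...]

    locations = get_locations_database()
    locations.import_from_pycountry()            # seed ISO countries once
    us_id = locations.resolve_region_text("US")  # alpha-2, alpha-3, or fuzzy name
    locations.get_stats()["by_type"]
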