corp-extractor 0.5.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +191 -24
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +1 -1
- statement_extractor/cli.py +1227 -10
- statement_extractor/data/statement_taxonomy.json +6949 -1159
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/models/__init__.py +16 -1
- statement_extractor/models/canonical.py +44 -1
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/labels.py +47 -18
- statement_extractor/models/qualifiers.py +51 -3
- statement_extractor/models/statement.py +26 -0
- statement_extractor/pipeline/config.py +6 -11
- statement_extractor/pipeline/orchestrator.py +80 -111
- statement_extractor/pipeline/registry.py +52 -46
- statement_extractor/plugins/__init__.py +20 -8
- statement_extractor/plugins/base.py +334 -64
- statement_extractor/plugins/extractors/gliner2.py +10 -0
- statement_extractor/plugins/labelers/taxonomy.py +18 -5
- statement_extractor/plugins/labelers/taxonomy_embedding.py +17 -6
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +11 -0
- statement_extractor/plugins/qualifiers/companies_house.py +14 -3
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +14 -3
- statement_extractor/plugins/qualifiers/person.py +578 -14
- statement_extractor/plugins/qualifiers/sec_edgar.py +14 -3
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/t5_gemma.py +158 -53
- statement_extractor/plugins/taxonomy/embedding.py +193 -46
- statement_extractor/plugins/taxonomy/mnli.py +16 -4
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.5.0.dist-info/RECORD +0 -55
- statement_extractor/plugins/canonicalizers/__init__.py +0 -17
- statement_extractor/plugins/canonicalizers/base.py +0 -9
- statement_extractor/plugins/canonicalizers/location.py +0 -219
- statement_extractor/plugins/canonicalizers/organization.py +0 -230
- statement_extractor/plugins/canonicalizers/person.py +0 -242
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
statement_extractor/database/store.py (new file)
@@ -0,0 +1,1609 @@

```python
"""
Entity/Organization database with sqlite-vec for vector search.

Uses a hybrid approach:
1. Text-based filtering to narrow candidates (Levenshtein-like)
2. sqlite-vec vector search for semantic ranking
"""

import json
import logging
import re
import sqlite3
import time
from pathlib import Path
from typing import Iterator, Optional

import numpy as np
import sqlite_vec

from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType

logger = logging.getLogger(__name__)

# Default database location
DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"

# Module-level singleton for OrganizationDatabase to prevent multiple loads
_database_instances: dict[str, "OrganizationDatabase"] = {}

# Module-level singleton for PersonDatabase
_person_database_instances: dict[str, "PersonDatabase"] = {}

# Comprehensive set of corporate legal suffixes (international)
COMPANY_SUFFIXES: set[str] = {
    'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
    'AG & KG', 'AG & KGaA', 'AG & PARTNER', 'ATE', 'ASA', 'B.V.', 'BV', 'Class A', 'Class B',
    'Class C', 'Class D', 'Class E', 'Class F', 'Class G', 'CO', 'Co', 'Co.', 'Company',
    'Corp', 'Corp.', 'Corporation', 'DAC', 'GmbH', 'Inc', 'Inc.', 'Incorporated', 'KGaA',
    'Limited', 'LLC', 'LLP', 'LP', 'Ltd', 'Ltd.', 'N.V.', 'NV', 'Plc', 'PC', 'plc', 'PLC',
    'Pty Ltd', 'Pty', 'Pty. Ltd.', 'S.A.', 'S.A.B. de C.V.', 'SAB de CV', 'S.A.B.', 'S.A.P.I.',
    'NV/SA', 'SDI', 'SpA', 'S.L.', 'S.p.A.', 'SA', 'SE', 'Tbk PT', 'U.A.',
    # Additional common suffixes
    'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
}

# Pre-compile the suffix pattern for performance
_SUFFIX_PATTERN = re.compile(
    r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
    re.IGNORECASE
)


def _clean_org_name(name: str | None) -> str:
    """
    Remove special characters and formatting from organization name.

    Removes brackets, parentheses, quotes, and other formatting artifacts.
    """
    if not name:
        return ""
    # Remove special characters, keeping only alphanumeric and spaces
    cleaned = re.sub(r'[•;:\'"\[\](){}<>`~!@#$%^&*\-_=+\\|/?!`~]+', ' ', name)
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()
    # Recurse if changes were made (handles nested special chars)
    return _clean_org_name(cleaned) if cleaned != name else cleaned


def _remove_suffix(name: str) -> str:
    """
    Remove corporate legal suffixes from company name.

    Iteratively removes suffixes until no more are found.
    Also removes possessive 's and trailing punctuation.
    """
    cleaned = name.strip()
    cleaned = re.sub(r'\s+', ' ', cleaned)
    # Remove possessive 's (e.g., "Amazon's" -> "Amazon")
    cleaned = re.sub(r"'s\b", "", cleaned)
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()

    while True:
        new_name = _SUFFIX_PATTERN.sub('', cleaned)
        # Remove trailing punctuation
        new_name = re.sub(r'[ .,;&\n\t/)]$', '', new_name)

        if new_name == cleaned:
            break
        cleaned = new_name.strip()

    return cleaned.strip()


def _normalize_name(name: str) -> str:
    """
    Normalize company name for text matching.

    1. Remove possessive 's (before cleaning removes apostrophe)
    2. Clean special characters
    3. Remove legal suffixes
    4. Lowercase
    5. If result is empty, use cleaned lowercase original

    Always returns a non-empty string for valid input.
    """
    if not name:
        return ""
    # Remove possessive 's first (before cleaning removes the apostrophe)
    normalized = re.sub(r"'s\b", "", name)
    # Clean special characters
    cleaned = _clean_org_name(normalized)
    # Remove legal suffixes
    normalized = _remove_suffix(cleaned)
    # Lowercase for matching
    normalized = normalized.lower()
    # If normalized is empty (e.g., name was just "Ltd"), use the cleaned name
    if not normalized:
        normalized = cleaned.lower() if cleaned else name.lower()
    return normalized
```
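Traced by hand, the three helpers above compose like this; the expected values below follow from the code, they are not output captured from the package:

```python
# Hand-traced examples for _clean_org_name / _remove_suffix / _normalize_name:
_normalize_name("Amazon's Inc.")        # -> "amazon"   (possessive and legal suffix stripped)
_normalize_name("Siemens AG & Co. KG")  # -> "siemens"  (multi-word suffix matched as a whole)
_normalize_name("Ltd")                  # -> "ltd"      (the suffix pattern requires a preceding
                                        #    word, so a bare suffix is kept)
```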
```python
def _extract_search_terms(query: str) -> list[str]:
    """
    Extract search terms from a query for SQL LIKE matching.

    Returns list of terms to search for, ordered by length (longest first).
    """
    # Split into words
    words = query.split()

    # Filter out very short words (< 3 chars) unless it's the only word
    if len(words) > 1:
        words = [w for w in words if len(w) >= 3]

    # Sort by length descending (longer words are more specific)
    words.sort(key=len, reverse=True)

    return words[:3]  # Limit to top 3 terms


# Person name normalization patterns
_PERSON_PREFIXES = {
    "dr.", "dr", "prof.", "prof", "professor",
    "mr.", "mr", "mrs.", "mrs", "ms.", "ms", "miss",
    "sir", "dame", "lord", "lady",
    "rev.", "rev", "reverend",
    "hon.", "hon", "honorable",
    "gen.", "gen", "general",
    "col.", "col", "colonel",
    "capt.", "capt", "captain",
    "lt.", "lt", "lieutenant",
    "sgt.", "sgt", "sergeant",
}

_PERSON_SUFFIXES = {
    "jr.", "jr", "junior",
    "sr.", "sr", "senior",
    "ii", "iii", "iv", "v",
    "2nd", "3rd", "4th", "5th",
    "phd", "ph.d.", "ph.d",
    "md", "m.d.", "m.d",
    "esq", "esq.",
    "mba", "m.b.a.",
    "cpa", "c.p.a.",
    "jd", "j.d.",
}


def _normalize_person_name(name: str) -> str:
    """
    Normalize person name for text matching.

    1. Remove honorific prefixes (Dr., Prof., Mr., etc.)
    2. Remove generational suffixes (Jr., Sr., III, PhD, etc.)
    3. Keep name particles (von, van, de, al-, etc.)
    4. Lowercase and strip

    Always returns a non-empty string for valid input.
    """
    if not name:
        return ""

    # Lowercase for matching
    normalized = name.lower().strip()

    # Split into words
    words = normalized.split()
    if not words:
        return ""

    # Remove prefix if first word is a title
    while words and words[0].rstrip(".") in _PERSON_PREFIXES:
        words.pop(0)
    if not words:
        return name.lower().strip()  # Fallback if name was just a title

    # Remove suffix if last word is a suffix
    while words and words[-1].rstrip(".") in _PERSON_SUFFIXES:
        words.pop()
    if not words:
        return name.lower().strip()  # Fallback if name was just suffixes

    # Rejoin remaining words
    normalized = " ".join(words)

    # Clean up extra spaces
    normalized = re.sub(r'\s+', ' ', normalized).strip()

    return normalized if normalized else name.lower().strip()
```
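The same hand-tracing for the person-name helpers and the term extractor (again derived from the code, not captured output):

```python
_normalize_person_name("Dr. Angela Merkel")     # -> "angela merkel"   (honorific removed)
_normalize_person_name("Sundar Pichai MBA")     # -> "sundar pichai"   (credential suffix removed)
_normalize_person_name("Ludwig van Beethoven")  # -> "ludwig van beethoven" (particles kept)

_extract_search_terms("international business machines")
# -> ["international", "business", "machines"]  (>= 3 chars, longest first, top 3)
```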
```python
def get_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "OrganizationDatabase":
    """
    Get a singleton OrganizationDatabase instance for the given path.

    Args:
        db_path: Path to database file
        embedding_dim: Dimension of embeddings

    Returns:
        Shared OrganizationDatabase instance
    """
    path_key = str(db_path or DEFAULT_DB_PATH)
    if path_key not in _database_instances:
        logger.debug(f"Creating new OrganizationDatabase instance for {path_key}")
        _database_instances[path_key] = OrganizationDatabase(db_path=db_path, embedding_dim=embedding_dim)
    return _database_instances[path_key]


class OrganizationDatabase:
    """
    SQLite database with sqlite-vec for organization vector search.

    Uses hybrid text + vector search:
    1. Text filtering with Levenshtein distance to reduce candidates
    2. sqlite-vec for semantic similarity ranking
    """

    def __init__(
        self,
        db_path: Optional[str | Path] = None,
        embedding_dim: int = 768,  # Default for embeddinggemma-300m
    ):
        """
        Initialize the organization database.

        Args:
            db_path: Path to database file (creates if not exists)
            embedding_dim: Dimension of embeddings to store
        """
        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self._embedding_dim = embedding_dim
        self._conn: Optional[sqlite3.Connection] = None

    def _ensure_dir(self) -> None:
        """Ensure database directory exists."""
        self._db_path.parent.mkdir(parents=True, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Get or create database connection with sqlite-vec loaded."""
        if self._conn is not None:
            return self._conn

        self._ensure_dir()
        self._conn = sqlite3.connect(str(self._db_path))
        self._conn.row_factory = sqlite3.Row

        # Load sqlite-vec extension
        self._conn.enable_load_extension(True)
        sqlite_vec.load(self._conn)
        self._conn.enable_load_extension(False)

        # Create tables
        self._create_tables()

        return self._conn

    def _create_tables(self) -> None:
        """Create database tables including sqlite-vec virtual table."""
        conn = self._conn
        assert conn is not None

        # Main organization records table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS organizations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                name_normalized TEXT NOT NULL,
                source TEXT NOT NULL,
                source_id TEXT NOT NULL,
                region TEXT NOT NULL DEFAULT '',
                entity_type TEXT NOT NULL DEFAULT 'unknown',
                record TEXT NOT NULL,
                UNIQUE(source, source_id)
            )
        """)

        # Add region column if it doesn't exist (migration for existing DBs)
        try:
            conn.execute("ALTER TABLE organizations ADD COLUMN region TEXT NOT NULL DEFAULT ''")
            logger.info("Added region column to organizations table")
        except sqlite3.OperationalError:
            pass  # Column already exists

        # Add entity_type column if it doesn't exist (migration for existing DBs)
        try:
            conn.execute("ALTER TABLE organizations ADD COLUMN entity_type TEXT NOT NULL DEFAULT 'unknown'")
            logger.info("Added entity_type column to organizations table")
        except sqlite3.OperationalError:
            pass  # Column already exists

        # Create indexes on main table
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source ON organizations(source)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source_id ON organizations(source, source_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
        conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")

        # Create sqlite-vec virtual table for embeddings
        # vec0 is the recommended virtual table type
        conn.execute(f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
                org_id INTEGER PRIMARY KEY,
                embedding float[{self._embedding_dim}]
            )
        """)

        conn.commit()

    def close(self) -> None:
        """Close database connection."""
        if self._conn:
            self._conn.close()
            self._conn = None

    def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
        """
        Insert an organization record with its embedding.

        Args:
            record: Organization record to insert
            embedding: Embedding vector for the organization name

        Returns:
            Row ID of inserted record
        """
        conn = self._connect()

        # Serialize record
        record_json = json.dumps(record.record)
        name_normalized = _normalize_name(record.name)

        cursor = conn.execute("""
            INSERT OR REPLACE INTO organizations
            (name, name_normalized, source, source_id, region, entity_type, record)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            record.name,
            name_normalized,
            record.source,
            record.source_id,
            record.region,
            record.entity_type.value,
            record_json,
        ))

        row_id = cursor.lastrowid
        assert row_id is not None

        # Insert embedding into vec table
        # sqlite-vec expects the embedding as a blob
        embedding_blob = embedding.astype(np.float32).tobytes()
        conn.execute("""
            INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
            VALUES (?, ?)
        """, (row_id, embedding_blob))

        conn.commit()
        return row_id

    def insert_batch(
        self,
        records: list[CompanyRecord],
        embeddings: np.ndarray,
        batch_size: int = 1000,
    ) -> int:
        """
        Insert multiple organization records with embeddings.

        Args:
            records: List of organization records
            embeddings: Matrix of embeddings (N x dim)
            batch_size: Commit batch size

        Returns:
            Number of records inserted
        """
        conn = self._connect()
        count = 0

        for record, embedding in zip(records, embeddings):
            record_json = json.dumps(record.record)
            name_normalized = _normalize_name(record.name)

            cursor = conn.execute("""
                INSERT OR REPLACE INTO organizations
                (name, name_normalized, source, source_id, region, entity_type, record)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (
                record.name,
                name_normalized,
                record.source,
                record.source_id,
                record.region,
                record.entity_type.value,
                record_json,
            ))

            row_id = cursor.lastrowid
            assert row_id is not None

            # Insert embedding
            embedding_blob = embedding.astype(np.float32).tobytes()
            conn.execute("""
                INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
                VALUES (?, ?)
            """, (row_id, embedding_blob))

            count += 1

            if count % batch_size == 0:
                conn.commit()
                logger.info(f"Inserted {count} records...")

        conn.commit()
        return count
```
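Both insert paths serialize vectors the same way: the sqlite-vec `vec0` table takes the raw little-endian `float32` bytes. A quick round-trip check of that convention:

```python
import numpy as np

vec = np.random.default_rng(0).standard_normal(768)
blob = vec.astype(np.float32).tobytes()          # what insert()/insert_batch() store
assert len(blob) == 768 * 4                      # four bytes per float32 dimension
restored = np.frombuffer(blob, dtype=np.float32)
assert np.allclose(restored, vec.astype(np.float32))
```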
```python
    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 20,
        source_filter: Optional[str] = None,
        query_text: Optional[str] = None,
        max_text_candidates: int = 5000,
    ) -> list[tuple[CompanyRecord, float]]:
        """
        Search for similar organizations using hybrid text + vector search.

        Two-stage approach:
        1. If query_text provided, use SQL LIKE to find candidates containing search terms
        2. Use sqlite-vec for vector similarity ranking on filtered candidates

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results to return
            source_filter: Optional filter by source (gleif, sec_edgar, etc.)
            query_text: Optional query text for text-based pre-filtering
            max_text_candidates: Max candidates to keep after text filtering

        Returns:
            List of (CompanyRecord, similarity_score) tuples
        """
        start = time.time()
        self._connect()

        # Normalize query embedding
        query_norm = np.linalg.norm(query_embedding)
        if query_norm == 0:
            return []
        query_normalized = query_embedding / query_norm
        query_blob = query_normalized.astype(np.float32).tobytes()

        # Stage 1: Text-based pre-filtering (if query_text provided)
        candidate_ids: Optional[set[int]] = None
        if query_text:
            query_normalized_text = _normalize_name(query_text)
            if query_normalized_text:
                candidate_ids = self._text_filter_candidates(
                    query_normalized_text,
                    max_candidates=max_text_candidates,
                    source_filter=source_filter,
                )
                logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")

        # Stage 2: Vector search
        if candidate_ids is not None and len(candidate_ids) == 0:
            # No text matches, return empty
            return []

        if candidate_ids is not None:
            # Search within text-filtered candidates
            results = self._vector_search_filtered(
                query_blob, candidate_ids, top_k, source_filter
            )
        else:
            # Full vector search
            results = self._vector_search_full(query_blob, top_k, source_filter)

        elapsed = time.time() - start
        logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
        return results
```
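A minimal end-to-end sketch of the hybrid search. `embed` here is a stand-in for whatever 768-dimensional model the caller pairs with the database (the defaults above mention embeddinggemma-300m); it is not an API shipped by this module:

```python
import numpy as np
from statement_extractor.database.store import get_database

def embed(text: str) -> np.ndarray:
    # Stand-in: a real caller would run an embedding model here.
    return np.random.default_rng(abs(hash(text)) % 2**32).standard_normal(768)

db = get_database()  # singleton backed by ~/.cache/corp-extractor/entities.db
results = db.search(
    query_embedding=embed("Apple Inc."),
    query_text="Apple Inc.",        # enables the LIKE pre-filter stage
    top_k=5,
    source_filter="sec_edgar",      # optional source restriction
)
for record, similarity in results:
    print(f"{similarity:.3f}  {record.name}  ({record.source}:{record.source_id})")
```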
```python
    def _text_filter_candidates(
        self,
        query_normalized: str,
        max_candidates: int,
        source_filter: Optional[str] = None,
    ) -> set[int]:
        """
        Filter candidates using SQL LIKE for fast text matching.

        This is a generous pre-filter to reduce the embedding search space.
        Returns set of organization IDs that contain any search term.
        Uses `name_normalized` column for consistent matching.
        """
        conn = self._conn
        assert conn is not None

        # Extract search terms from the normalized query
        search_terms = _extract_search_terms(query_normalized)
        if not search_terms:
            return set()

        logger.debug(f"Text filter search terms: {search_terms}")

        # Build OR clause for LIKE matching on any term
        # Use name_normalized for consistent matching (already lowercased, suffixes removed)
        like_clauses = []
        params: list = []
        for term in search_terms:
            like_clauses.append("name_normalized LIKE ?")
            params.append(f"%{term}%")

        where_clause = " OR ".join(like_clauses)

        # Add source filter if specified
        if source_filter:
            query = f"""
                SELECT id FROM organizations
                WHERE ({where_clause}) AND source = ?
                LIMIT ?
            """
            params.append(source_filter)
        else:
            query = f"""
                SELECT id FROM organizations
                WHERE {where_clause}
                LIMIT ?
            """

        params.append(max_candidates)

        cursor = conn.execute(query, params)
        return set(row["id"] for row in cursor)
```
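For a normalized query like "international business machines" with a source filter, the pre-filter ends up issuing SQL of roughly this shape (reconstructed from the string-building above, not a captured statement):

```python
sql = """
    SELECT id FROM organizations
    WHERE (name_normalized LIKE ? OR name_normalized LIKE ? OR name_normalized LIKE ?)
          AND source = ?
    LIMIT ?
"""
params = ["%international%", "%business%", "%machines%", "gleif", 5000]
```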
```python
    def _vector_search_filtered(
        self,
        query_blob: bytes,
        candidate_ids: set[int],
        top_k: int,
        source_filter: Optional[str],
    ) -> list[tuple[CompanyRecord, float]]:
        """Vector search within a filtered set of candidates."""
        conn = self._conn
        assert conn is not None

        if not candidate_ids:
            return []

        # Build IN clause for candidate IDs
        placeholders = ",".join("?" * len(candidate_ids))

        # Query sqlite-vec with KNN search, filtered by candidate IDs
        # Using distance function - lower is more similar for L2
        # We'll use cosine distance
        query = f"""
            SELECT
                e.org_id,
                vec_distance_cosine(e.embedding, ?) as distance
            FROM organization_embeddings e
            WHERE e.org_id IN ({placeholders})
            ORDER BY distance
            LIMIT ?
        """

        cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])

        results = []
        for row in cursor:
            org_id = row["org_id"]
            distance = row["distance"]
            # Convert cosine distance to similarity (1 - distance)
            similarity = 1.0 - distance

            # Fetch full record
            record = self._get_record_by_id(org_id)
            if record:
                # Apply source filter if specified
                if source_filter and record.source != source_filter:
                    continue
                results.append((record, similarity))

        return results

    def _vector_search_full(
        self,
        query_blob: bytes,
        top_k: int,
        source_filter: Optional[str],
    ) -> list[tuple[CompanyRecord, float]]:
        """Full vector search without text pre-filtering."""
        conn = self._conn
        assert conn is not None

        # KNN search with sqlite-vec
        if source_filter:
            # Need to join with organizations table for source filter
            query = """
                SELECT
                    e.org_id,
                    vec_distance_cosine(e.embedding, ?) as distance
                FROM organization_embeddings e
                JOIN organizations c ON e.org_id = c.id
                WHERE c.source = ?
                ORDER BY distance
                LIMIT ?
            """
            cursor = conn.execute(query, (query_blob, source_filter, top_k))
        else:
            query = """
                SELECT
                    org_id,
                    vec_distance_cosine(embedding, ?) as distance
                FROM organization_embeddings
                ORDER BY distance
                LIMIT ?
            """
            cursor = conn.execute(query, (query_blob, top_k))

        results = []
        for row in cursor:
            org_id = row["org_id"]
            distance = row["distance"]
            similarity = 1.0 - distance

            record = self._get_record_by_id(org_id)
            if record:
                results.append((record, similarity))

        return results
```
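`vec_distance_cosine` in sqlite-vec reports cosine distance, i.e. one minus cosine similarity, so the `1.0 - distance` conversion above recovers a similarity in [-1, 1]. A numpy sanity check of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal(768), rng.standard_normal(768)

cos_sim = float(a @ b) / float(np.linalg.norm(a) * np.linalg.norm(b))
cos_dist = 1.0 - cos_sim       # what the SQL's vec_distance_cosine reports
assert abs((1.0 - cos_dist) - cos_sim) < 1e-9
```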
```python
    def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
        """Get an organization record by ID."""
        conn = self._conn
        assert conn is not None

        cursor = conn.execute("""
            SELECT name, source, source_id, region, entity_type, record
            FROM organizations WHERE id = ?
        """, (org_id,))

        row = cursor.fetchone()
        if row:
            return CompanyRecord(
                name=row["name"],
                source=row["source"],
                source_id=row["source_id"],
                region=row["region"] or "",
                entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                record=json.loads(row["record"]),
            )
        return None

    def get_by_source_id(self, source: str, source_id: str) -> Optional[CompanyRecord]:
        """Get an organization record by source and source_id."""
        conn = self._connect()

        cursor = conn.execute("""
            SELECT name, source, source_id, region, entity_type, record
            FROM organizations
            WHERE source = ? AND source_id = ?
        """, (source, source_id))

        row = cursor.fetchone()
        if row:
            return CompanyRecord(
                name=row["name"],
                source=row["source"],
                source_id=row["source_id"],
                region=row["region"] or "",
                entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                record=json.loads(row["record"]),
            )
        return None

    def get_stats(self) -> DatabaseStats:
        """Get database statistics."""
        conn = self._connect()

        # Total count
        cursor = conn.execute("SELECT COUNT(*) FROM organizations")
        total = cursor.fetchone()[0]

        # Count by source
        cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
        by_source = {row["source"]: row["cnt"] for row in cursor}

        # Database file size
        db_size = self._db_path.stat().st_size if self._db_path.exists() else 0

        return DatabaseStats(
            total_records=total,
            by_source=by_source,
            embedding_dimension=self._embedding_dim,
            database_size_bytes=db_size,
        )

    def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
        """Iterate over all records, optionally filtered by source."""
        conn = self._connect()

        if source:
            cursor = conn.execute("""
                SELECT name, source, source_id, region, entity_type, record
                FROM organizations
                WHERE source = ?
            """, (source,))
        else:
            cursor = conn.execute("""
                SELECT name, source, source_id, region, entity_type, record
                FROM organizations
            """)

        for row in cursor:
            yield CompanyRecord(
                name=row["name"],
                source=row["source"],
                source_id=row["source_id"],
                region=row["region"] or "",
                entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
                record=json.loads(row["record"]),
            )

    def migrate_name_normalized(self, batch_size: int = 50000) -> int:
        """
        Populate the name_normalized column for all records.

        This is a one-time migration for databases that don't have
        normalized names populated.

        Args:
            batch_size: Number of records to process per batch

        Returns:
            Number of records updated
        """
        conn = self._connect()

        # Check how many need migration (empty, null, or placeholder "-")
        cursor = conn.execute(
            "SELECT COUNT(*) FROM organizations WHERE name_normalized = '' OR name_normalized IS NULL OR name_normalized = '-'"
        )
        empty_count = cursor.fetchone()[0]

        if empty_count == 0:
            logger.info("All records already have name_normalized populated")
            return 0

        logger.info(f"Populating name_normalized for {empty_count} records...")

        updated = 0
        last_id = 0

        while True:
            # Get batch of records that need normalization, ordered by ID
            cursor = conn.execute("""
                SELECT id, name FROM organizations
                WHERE id > ? AND (name_normalized = '' OR name_normalized IS NULL OR name_normalized = '-')
                ORDER BY id
                LIMIT ?
            """, (last_id, batch_size))

            rows = cursor.fetchall()
            if not rows:
                break

            # Update each record
            for row in rows:
                # _normalize_name now always returns non-empty for valid input
                normalized = _normalize_name(row["name"])
                conn.execute(
                    "UPDATE organizations SET name_normalized = ? WHERE id = ?",
                    (normalized, row["id"])
                )
                last_id = row["id"]

            conn.commit()
            updated += len(rows)
            logger.info(f" Updated {updated}/{empty_count} records...")

        logger.info(f"Migration complete: {updated} name_normalized values populated")
        return updated

    def migrate_to_sqlite_vec(self, batch_size: int = 10000) -> int:
        """
        Migrate embeddings from BLOB column to sqlite-vec virtual table.

        This is a one-time migration for databases created before sqlite-vec support.

        Args:
            batch_size: Number of records to process per batch

        Returns:
            Number of embeddings migrated
        """
        conn = self._connect()

        # Check if migration is needed
        cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
        vec_count = cursor.fetchone()[0]

        cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE embedding IS NOT NULL")
        blob_count = cursor.fetchone()[0]

        if vec_count >= blob_count:
            logger.info(f"Migration not needed: sqlite-vec has {vec_count} embeddings, BLOB has {blob_count}")
            return 0

        logger.info(f"Migrating {blob_count} embeddings from BLOB to sqlite-vec...")

        # Get IDs that need migration (in sqlite-vec but not in organizations)
        cursor = conn.execute("""
            SELECT c.id, c.embedding
            FROM organizations c
            LEFT JOIN organization_embeddings e ON c.id = e.org_id
            WHERE c.embedding IS NOT NULL AND e.org_id IS NULL
        """)

        migrated = 0
        batch = []

        for row in cursor:
            org_id = row["id"]
            embedding_blob = row["embedding"]

            if embedding_blob:
                batch.append((org_id, embedding_blob))

            if len(batch) >= batch_size:
                self._insert_vec_batch(batch)
                migrated += len(batch)
                logger.info(f" Migrated {migrated}/{blob_count} embeddings...")
                batch = []

        # Insert remaining batch
        if batch:
            self._insert_vec_batch(batch)
            migrated += len(batch)

        logger.info(f"Migration complete: {migrated} embeddings migrated to sqlite-vec")
        return migrated

    def _insert_vec_batch(self, batch: list[tuple[int, bytes]]) -> None:
        """Insert a batch of embeddings into sqlite-vec table."""
        conn = self._conn
        assert conn is not None

        for org_id, embedding_blob in batch:
            conn.execute("""
                INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
                VALUES (?, ?)
            """, (org_id, embedding_blob))

        conn.commit()

    def delete_source(self, source: str) -> int:
        """Delete all records from a specific source."""
        conn = self._connect()

        # First get IDs to delete from vec table
        cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
        ids_to_delete = [row["id"] for row in cursor]

        # Delete from vec table
        if ids_to_delete:
            placeholders = ",".join("?" * len(ids_to_delete))
            conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)

        # Delete from main table
        cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
        deleted = cursor.rowcount

        conn.commit()

        logger.info(f"Deleted {deleted} records from source '{source}'")
        return deleted

    def migrate_from_legacy_schema(self) -> dict[str, str]:
        """
        Migrate database from legacy schema (companies/company_embeddings tables)
        to new schema (organizations/organization_embeddings tables).

        This handles:
        - Renaming 'companies' table to 'organizations'
        - Renaming 'company_embeddings' table to 'organization_embeddings'
        - Renaming 'company_id' column to 'org_id' in embeddings table
        - Updating indexes to use new naming

        Returns:
            Dict of migrations performed (table_name -> action)
        """
        conn = self._connect()
        migrations = {}

        # Check what tables exist
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        existing_tables = {row[0] for row in cursor}

        has_companies = "companies" in existing_tables
        has_organizations = "organizations" in existing_tables
        has_company_embeddings = "company_embeddings" in existing_tables
        has_org_embeddings = "organization_embeddings" in existing_tables

        if not has_companies and not has_company_embeddings:
            if has_organizations and has_org_embeddings:
                logger.info("Database already uses new schema, no migration needed")
                return {}
            else:
                logger.info("No legacy tables found, database will use new schema")
                return {}

        logger.info("Migrating database from legacy schema...")
        conn.execute("BEGIN")

        try:
            # Migrate companies -> organizations
            if has_companies:
                if has_organizations:
                    # Both exist - merge data from companies into organizations
                    logger.info("Merging companies table into organizations...")
                    conn.execute("""
                        INSERT OR IGNORE INTO organizations
                        (name, name_normalized, source, source_id, region, entity_type, record)
                        SELECT name, name_normalized, source, source_id,
                               COALESCE(region, ''), COALESCE(entity_type, 'unknown'), record
                        FROM companies
                    """)
                    conn.execute("DROP TABLE companies")
                    migrations["companies"] = "merged_into_organizations"
                else:
                    # Just rename
                    logger.info("Renaming companies table to organizations...")
                    conn.execute("ALTER TABLE companies RENAME TO organizations")
                    migrations["companies"] = "renamed_to_organizations"

                # Update indexes
                for old_idx in ["idx_companies_name", "idx_companies_name_normalized",
                                "idx_companies_source", "idx_companies_source_id",
                                "idx_companies_region", "idx_companies_entity_type",
                                "idx_companies_name_region_source"]:
                    try:
                        conn.execute(f"DROP INDEX IF EXISTS {old_idx}")
                    except Exception:
                        pass

                # Create new indexes
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source ON organizations(source)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source_id ON organizations(source, source_id)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
                conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")

            # Migrate company_embeddings -> organization_embeddings
            if has_company_embeddings:
                if has_org_embeddings:
                    # Both exist - merge
                    logger.info("Merging company_embeddings into organization_embeddings...")
                    # Get column info to check for company_id vs org_id
                    cursor = conn.execute("PRAGMA table_info(company_embeddings)")
                    cols = {row[1] for row in cursor}
                    id_col = "company_id" if "company_id" in cols else "org_id"

                    conn.execute(f"""
                        INSERT OR IGNORE INTO organization_embeddings (org_id, embedding)
                        SELECT {id_col}, embedding FROM company_embeddings
                    """)
                    conn.execute("DROP TABLE company_embeddings")
                    migrations["company_embeddings"] = "merged_into_organization_embeddings"
                else:
                    # Need to recreate with new column name
                    logger.info("Migrating company_embeddings to organization_embeddings...")

                    # Check if it has company_id or org_id column
                    cursor = conn.execute("PRAGMA table_info(company_embeddings)")
                    cols = {row[1] for row in cursor}
                    id_col = "company_id" if "company_id" in cols else "org_id"

                    # Create new virtual table
                    conn.execute(f"""
                        CREATE VIRTUAL TABLE organization_embeddings USING vec0(
                            org_id INTEGER PRIMARY KEY,
                            embedding float[{self._embedding_dim}]
                        )
                    """)

                    # Copy data
                    conn.execute(f"""
                        INSERT INTO organization_embeddings (org_id, embedding)
                        SELECT {id_col}, embedding FROM company_embeddings
                    """)

                    # Drop old table
                    conn.execute("DROP TABLE company_embeddings")
                    migrations["company_embeddings"] = "renamed_to_organization_embeddings"

            conn.execute("COMMIT")
            logger.info(f"Migration complete: {migrations}")

        except Exception as e:
            conn.execute("ROLLBACK")
            logger.error(f"Migration failed: {e}")
            raise

        # Vacuum to clean up - outside try block since COMMIT already succeeded
        try:
            conn.execute("VACUUM")
        except Exception as e:
            logger.warning(f"VACUUM failed (migration was successful): {e}")

        return migrations

    def get_missing_embedding_count(self) -> int:
        """Get count of organizations without embeddings in organization_embeddings table."""
        conn = self._connect()

        cursor = conn.execute("""
            SELECT COUNT(*) FROM organizations c
            LEFT JOIN organization_embeddings e ON c.id = e.org_id
            WHERE e.org_id IS NULL
        """)
        return cursor.fetchone()[0]

    def get_organizations_without_embeddings(
        self,
        batch_size: int = 1000,
        source: Optional[str] = None,
    ) -> Iterator[tuple[int, str]]:
        """
        Iterate over organizations that don't have embeddings.

        Args:
            batch_size: Number of records per batch
            source: Optional source filter

        Yields:
            Tuples of (org_id, name)
        """
        conn = self._connect()

        last_id = 0
        while True:
            if source:
                cursor = conn.execute("""
                    SELECT c.id, c.name FROM organizations c
                    LEFT JOIN organization_embeddings e ON c.id = e.org_id
                    WHERE e.org_id IS NULL AND c.id > ? AND c.source = ?
                    ORDER BY c.id
                    LIMIT ?
                """, (last_id, source, batch_size))
            else:
                cursor = conn.execute("""
                    SELECT c.id, c.name FROM organizations c
                    LEFT JOIN organization_embeddings e ON c.id = e.org_id
                    WHERE e.org_id IS NULL AND c.id > ?
                    ORDER BY c.id
                    LIMIT ?
                """, (last_id, batch_size))

            rows = cursor.fetchall()
            if not rows:
                break

            for row in rows:
                yield (row[0], row[1])
                last_id = row[0]

    def insert_embeddings_batch(
        self,
        org_ids: list[int],
        embeddings: np.ndarray,
    ) -> int:
        """
        Insert embeddings for existing organizations.

        Args:
            org_ids: List of organization IDs
            embeddings: Matrix of embeddings (N x dim)

        Returns:
            Number of embeddings inserted
        """
        conn = self._connect()
        count = 0

        for org_id, embedding in zip(org_ids, embeddings):
            embedding_blob = embedding.astype(np.float32).tobytes()
            conn.execute("""
                INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
                VALUES (?, ?)
            """, (org_id, embedding_blob))
            count += 1

        conn.commit()
        return count
```
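A sketch of how the backfill helpers are meant to compose, with `embed_batch` again a stand-in for a real embedding model rather than anything this module provides:

```python
import numpy as np

def embed_batch(names: list[str]) -> np.ndarray:
    # Stand-in: returns an (N, 768) float array; a real caller would batch a model here.
    return np.random.default_rng(0).standard_normal((len(names), 768))

db = get_database()
print(f"{db.get_missing_embedding_count()} organizations lack embeddings")

ids, names = [], []
for org_id, name in db.get_organizations_without_embeddings(batch_size=1000):
    ids.append(org_id)
    names.append(name)
    if len(ids) == 256:
        db.insert_embeddings_batch(ids, embed_batch(names))
        ids, names = [], []
if ids:
    db.insert_embeddings_batch(ids, embed_batch(names))
```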
1120
|
+
def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
|
|
1121
|
+
"""
|
|
1122
|
+
Get a singleton PersonDatabase instance for the given path.
|
|
1123
|
+
|
|
1124
|
+
Args:
|
|
1125
|
+
db_path: Path to database file
|
|
1126
|
+
embedding_dim: Dimension of embeddings
|
|
1127
|
+
|
|
1128
|
+
Returns:
|
|
1129
|
+
Shared PersonDatabase instance
|
|
1130
|
+
"""
|
|
1131
|
+
path_key = str(db_path or DEFAULT_DB_PATH)
|
|
1132
|
+
if path_key not in _person_database_instances:
|
|
1133
|
+
logger.debug(f"Creating new PersonDatabase instance for {path_key}")
|
|
1134
|
+
_person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
|
|
1135
|
+
return _person_database_instances[path_key]
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
class PersonDatabase:
|
|
1139
|
+
"""
|
|
1140
|
+
SQLite database with sqlite-vec for person vector search.
|
|
1141
|
+
|
|
1142
|
+
Uses hybrid text + vector search:
|
|
1143
|
+
1. Text filtering with LIKE to reduce candidates
|
|
1144
|
+
2. sqlite-vec for semantic similarity ranking
|
|
1145
|
+
|
|
1146
|
+
Stores people from sources like Wikidata with role/org context.
|
|
1147
|
+
"""
|
|
1148
|
+
|
|
1149
|
+
def __init__(
|
|
1150
|
+
self,
|
|
1151
|
+
db_path: Optional[str | Path] = None,
|
|
1152
|
+
embedding_dim: int = 768, # Default for embeddinggemma-300m
|
|
1153
|
+
):
|
|
1154
|
+
"""
|
|
1155
|
+
Initialize the person database.
|
|
1156
|
+
|
|
1157
|
+
Args:
|
|
1158
|
+
db_path: Path to database file (creates if not exists)
|
|
1159
|
+
embedding_dim: Dimension of embeddings to store
|
|
1160
|
+
"""
|
|
1161
|
+
self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
|
1162
|
+
self._embedding_dim = embedding_dim
|
|
1163
|
+
self._conn: Optional[sqlite3.Connection] = None
|
|
1164
|
+
|
|
1165
|
+
def _ensure_dir(self) -> None:
|
|
1166
|
+
"""Ensure database directory exists."""
|
|
1167
|
+
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1168
|
+
|
|
1169
|
+
def _connect(self) -> sqlite3.Connection:
|
|
1170
|
+
"""Get or create database connection with sqlite-vec loaded."""
|
|
1171
|
+
if self._conn is not None:
|
|
1172
|
+
return self._conn
|
|
1173
|
+
|
|
1174
|
+
self._ensure_dir()
|
|
1175
|
+
self._conn = sqlite3.connect(str(self._db_path))
|
|
1176
|
+
self._conn.row_factory = sqlite3.Row
|
|
1177
|
+
|
|
1178
|
+
# Load sqlite-vec extension
|
|
1179
|
+
self._conn.enable_load_extension(True)
|
|
1180
|
+
sqlite_vec.load(self._conn)
|
|
1181
|
+
self._conn.enable_load_extension(False)
|
|
1182
|
+
|
|
1183
|
+
# Create tables
|
|
1184
|
+
self._create_tables()
|
|
1185
|
+
|
|
1186
|
+
return self._conn
|
|
1187
|
+
|
|
1188
|
+
def _create_tables(self) -> None:
|
|
1189
|
+
"""Create database tables including sqlite-vec virtual table."""
|
|
1190
|
+
conn = self._conn
|
|
1191
|
+
assert conn is not None
|
|
1192
|
+
|
|
1193
|
+
# Main people records table
|
|
1194
|
+
conn.execute("""
|
|
1195
|
+
CREATE TABLE IF NOT EXISTS people (
|
|
1196
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
1197
|
+
name TEXT NOT NULL,
|
|
1198
|
+
name_normalized TEXT NOT NULL,
|
|
1199
|
+
source TEXT NOT NULL DEFAULT 'wikidata',
|
|
1200
|
+
source_id TEXT NOT NULL,
|
|
1201
|
+
country TEXT NOT NULL DEFAULT '',
|
|
1202
|
+
person_type TEXT NOT NULL DEFAULT 'unknown',
|
|
1203
|
+
known_for_role TEXT NOT NULL DEFAULT '',
|
|
1204
|
+
known_for_org TEXT NOT NULL DEFAULT '',
|
|
1205
|
+
record TEXT NOT NULL,
|
|
1206
|
+
UNIQUE(source, source_id)
|
|
1207
|
+
)
|
|
1208
|
+
""")
|
|
1209
|
+
|
|
1210
|
+
# Create indexes on main table
|
|
1211
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
|
|
1212
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
|
|
1213
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
|
|
1214
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
|
|
1215
|
+
conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
|
|
1216
|
+
|
|
1217
|
+
# Create sqlite-vec virtual table for embeddings
|
|
1218
|
+
conn.execute(f"""
|
|
1219
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
|
|
1220
|
+
person_id INTEGER PRIMARY KEY,
|
|
1221
|
+
embedding float[{self._embedding_dim}]
|
|
1222
|
+
)
|
|
1223
|
+
""")
|
|
1224
|
+
|
|
1225
|
+
conn.commit()
|
|
1226
|
+
|
|
1227
|
+
def close(self) -> None:
|
|
1228
|
+
"""Close database connection."""
|
|
1229
|
+
if self._conn:
|
|
1230
|
+
self._conn.close()
|
|
1231
|
+
self._conn = None
|
|
1232
|
+
|
|
1233
|
+
def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
|
|
1234
|
+
"""
|
|
1235
|
+
Insert a person record with its embedding.
|
|
1236
|
+
|
|
1237
|
+
Args:
|
|
1238
|
+
record: Person record to insert
|
|
1239
|
+
embedding: Embedding vector for the person name
|
|
1240
|
+
|
|
1241
|
+
Returns:
|
|
1242
|
+
Row ID of inserted record
|
|
1243
|
+
"""
|
|
1244
|
+
conn = self._connect()
|
|
1245
|
+
|
|
1246
|
+
# Serialize record
|
|
1247
|
+
record_json = json.dumps(record.record)
|
|
1248
|
+
name_normalized = _normalize_person_name(record.name)
|
|
1249
|
+
|
|
1250
|
+
cursor = conn.execute("""
|
|
1251
|
+
INSERT OR REPLACE INTO people
|
|
1252
|
+
(name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
|
|
1253
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
1254
|
+
""", (
|
|
1255
|
+
record.name,
|
|
1256
|
+
name_normalized,
|
|
1257
|
+
record.source,
|
|
1258
|
+
record.source_id,
|
|
1259
|
+
record.country,
|
|
1260
|
+
record.person_type.value,
|
|
1261
|
+
record.known_for_role,
|
|
1262
|
+
record.known_for_org,
|
|
1263
|
+
record_json,
|
|
1264
|
+
))
|
|
1265
|
+
|
|
1266
|
+
row_id = cursor.lastrowid
|
|
1267
|
+
assert row_id is not None
|
|
1268
|
+
|
|
1269
|
+
# Insert embedding into vec table
|
|
1270
|
+
embedding_blob = embedding.astype(np.float32).tobytes()
|
|
1271
|
+
conn.execute("""
|
|
1272
|
+
INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
|
|
1273
|
+
VALUES (?, ?)
|
|
1274
|
+
""", (row_id, embedding_blob))
|
|
1275
|
+
|
|
1276
|
+
conn.commit()
|
|
1277
|
+
return row_id
|
|
1278
|
+
|
|
1279
|
+
def insert_batch(
|
|
1280
|
+
self,
|
|
1281
|
+
records: list[PersonRecord],
|
|
1282
|
+
embeddings: np.ndarray,
|
|
1283
|
+
batch_size: int = 1000,
|
|
1284
|
+
) -> int:
|
|
1285
|
+
"""
|
|
1286
|
+
Insert multiple person records with embeddings.
|
|
1287
|
+
|
|
1288
|
+
Args:
|
|
1289
|
+
records: List of person records
|
|
1290
|
+
embeddings: Matrix of embeddings (N x dim)
|
|
1291
|
+
batch_size: Commit batch size
|
|
1292
|
+
|
|
1293
|
+
Returns:
|
|
1294
|
+
Number of records inserted
|
|
1295
|
+
"""
|
|
1296
|
+
conn = self._connect()
|
|
1297
|
+
count = 0
|
|
1298
|
+
|
|
1299
|
+
for record, embedding in zip(records, embeddings):
|
|
1300
|
+
record_json = json.dumps(record.record)
|
|
1301
|
+
name_normalized = _normalize_person_name(record.name)
|
|
1302
|
+
|
|
1303
|
+
cursor = conn.execute("""
|
|
1304
|
+
INSERT OR REPLACE INTO people
|
|
1305
|
+
(name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
|
|
1306
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
1307
|
+
""", (
|
|
1308
|
+
record.name,
|
|
1309
|
+
name_normalized,
|
|
1310
|
+
record.source,
|
|
1311
|
+
record.source_id,
|
|
1312
|
+
record.country,
|
|
1313
|
+
record.person_type.value,
|
|
1314
|
+
record.known_for_role,
|
|
1315
|
+
record.known_for_org,
|
|
1316
|
+
record_json,
|
|
1317
|
+
))
|
|
1318
|
+
|
|
1319
|
+
row_id = cursor.lastrowid
|
|
1320
|
+
assert row_id is not None
|
|
1321
|
+
|
|
1322
|
+
# Insert embedding
|
|
1323
|
+
embedding_blob = embedding.astype(np.float32).tobytes()
|
|
1324
|
+
conn.execute("""
|
|
1325
|
+
INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
|
|
1326
|
+
VALUES (?, ?)
|
|
1327
|
+
""", (row_id, embedding_blob))
|
|
1328
|
+
|
|
1329
|
+
count += 1
|
|
1330
|
+
|
|
1331
|
+
if count % batch_size == 0:
|
|
1332
|
+
conn.commit()
|
|
1333
|
+
logger.info(f"Inserted {count} person records...")
|
|
1334
|
+
|
|
1335
|
+
conn.commit()
|
|
1336
|
+
return count
|
|
1337
|
+
|
|
+    def search(
+        self,
+        query_embedding: np.ndarray,
+        top_k: int = 20,
+        query_text: Optional[str] = None,
+        max_text_candidates: int = 5000,
+    ) -> list[tuple[PersonRecord, float]]:
+        """
+        Search for similar people using hybrid text + vector search.
+
+        Two-stage approach:
+        1. If query_text provided, use SQL LIKE to find candidates containing search terms
+        2. Use sqlite-vec for vector similarity ranking on filtered candidates
+
+        Args:
+            query_embedding: Query embedding vector
+            top_k: Number of results to return
+            query_text: Optional query text for text-based pre-filtering
+            max_text_candidates: Max candidates to keep after text filtering
+
+        Returns:
+            List of (PersonRecord, similarity_score) tuples
+        """
+        start = time.time()
+        self._connect()
+
+        # Normalize query embedding
+        query_norm = np.linalg.norm(query_embedding)
+        if query_norm == 0:
+            return []
+        query_normalized = query_embedding / query_norm
+        query_blob = query_normalized.astype(np.float32).tobytes()
+
+        # Stage 1: Text-based pre-filtering (if query_text provided)
+        candidate_ids: Optional[set[int]] = None
+        if query_text:
+            query_normalized_text = _normalize_person_name(query_text)
+            if query_normalized_text:
+                candidate_ids = self._text_filter_candidates(
+                    query_normalized_text,
+                    max_candidates=max_text_candidates,
+                )
+                logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
+
+        # Stage 2: Vector search
+        if candidate_ids is not None and len(candidate_ids) == 0:
+            # No text matches, return empty
+            return []
+
+        if candidate_ids is not None:
+            # Search within text-filtered candidates
+            results = self._vector_search_filtered(
+                query_blob, candidate_ids, top_k
+            )
+        else:
+            # Full vector search
+            results = self._vector_search_full(query_blob, top_k)
+
+        elapsed = time.time() - start
+        logger.debug(f"Person search took {elapsed:.3f}s (results={len(results)})")
+        return results
+
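`search()` wires the two stages together: an optional `LIKE` pre-filter over `name_normalized`, then sqlite-vec cosine ranking, either within the filtered candidates or over the whole table. Note that it normalizes the query embedding itself, so callers need not pre-normalize. A calling sketch, reusing the hypothetical `db` handle from above; `encode()` stands in for whatever text-embedding function the caller uses and is not part of this diff.

```python
query_vec = encode("Tim Cook, CEO of Apple")  # hypothetical embedder -> np.ndarray

# Hybrid search: LIKE pre-filter on the person name, then cosine ranking.
for person, score in db.search(query_vec, top_k=5, query_text="Tim Cook"):
    print(f"{score:.3f}  {person.name}  ({person.source}:{person.source_id})")

# Pure vector search: omit query_text to rank every row in person_embeddings.
results = db.search(query_vec, top_k=5)
```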
+    def _text_filter_candidates(
+        self,
+        query_normalized: str,
+        max_candidates: int,
+    ) -> set[int]:
+        """
+        Filter candidates using SQL LIKE for fast text matching.
+
+        Uses `name_normalized` column for consistent matching.
+        """
+        conn = self._conn
+        assert conn is not None
+
+        # Extract search terms from the normalized query
+        search_terms = _extract_search_terms(query_normalized)
+        if not search_terms:
+            return set()
+
+        logger.debug(f"Person text filter search terms: {search_terms}")
+
+        # Build OR clause for LIKE matching on any term
+        like_clauses = []
+        params: list = []
+        for term in search_terms:
+            like_clauses.append("name_normalized LIKE ?")
+            params.append(f"%{term}%")
+
+        where_clause = " OR ".join(like_clauses)
+
+        query = f"""
+            SELECT id FROM people
+            WHERE {where_clause}
+            LIMIT ?
+        """
+
+        params.append(max_candidates)
+
+        cursor = conn.execute(query, params)
+        return set(row["id"] for row in cursor)
+
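The pre-filter depends on two module-level helpers, `_normalize_person_name` and `_extract_search_terms`, whose definitions fall outside this hunk. A plausible sketch of the contract the filter assumes (lowercased, punctuation-free names split into multi-character tokens) is below; the real implementations may well differ, e.g. with Unicode folding or honorific stripping.

```python
import re

# Illustrative stand-ins only; the actual helpers are defined elsewhere in this module.
def _normalize_person_name(name: str) -> str:
    """Lowercase and strip punctuation so LIKE matching is case/format-insensitive."""
    return re.sub(r"[^a-z0-9 ]+", " ", name.lower()).strip()

def _extract_search_terms(normalized: str) -> list[str]:
    """Split a normalized name into tokens worth matching, skipping 1-char fragments."""
    return [t for t in normalized.split() if len(t) > 1]
```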
+    def _vector_search_filtered(
+        self,
+        query_blob: bytes,
+        candidate_ids: set[int],
+        top_k: int,
+    ) -> list[tuple[PersonRecord, float]]:
+        """Vector search within a filtered set of candidates."""
+        conn = self._conn
+        assert conn is not None
+
+        if not candidate_ids:
+            return []
+
+        # Build IN clause for candidate IDs
+        placeholders = ",".join("?" * len(candidate_ids))
+
+        query = f"""
+            SELECT
+                e.person_id,
+                vec_distance_cosine(e.embedding, ?) as distance
+            FROM person_embeddings e
+            WHERE e.person_id IN ({placeholders})
+            ORDER BY distance
+            LIMIT ?
+        """
+
+        cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
+
+        results = []
+        for row in cursor:
+            person_id = row["person_id"]
+            distance = row["distance"]
+            # Convert cosine distance to similarity (1 - distance)
+            similarity = 1.0 - distance
+
+            # Fetch full record
+            record = self._get_record_by_id(person_id)
+            if record:
+                results.append((record, similarity))
+
+        return results
+
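Two details of the filtered path are worth flagging. First, the `IN` clause binds one host parameter per candidate, so `max_text_candidates` has to stay below SQLite's variable limit (999 by default before SQLite 3.32.0, 32766 since); the default of 5000 assumes a reasonably recent SQLite build. Second, both search paths report `1.0 - vec_distance_cosine(...)`, and since cosine distance is defined as one minus cosine similarity, that conversion recovers the familiar similarity score, as the quick numpy identity below illustrates (illustrative dimensions).

```python
import numpy as np

a = np.random.rand(384).astype(np.float32)
b = np.random.rand(384).astype(np.float32)

cos_sim = float(a @ b) / float(np.linalg.norm(a) * np.linalg.norm(b))
cos_dist = 1.0 - cos_sim                        # what vec_distance_cosine computes
assert abs((1.0 - cos_dist) - cos_sim) < 1e-6   # the 1 - distance conversion above
```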
+    def _vector_search_full(
+        self,
+        query_blob: bytes,
+        top_k: int,
+    ) -> list[tuple[PersonRecord, float]]:
+        """Full vector search without text pre-filtering."""
+        conn = self._conn
+        assert conn is not None
+
+        query = """
+            SELECT
+                person_id,
+                vec_distance_cosine(embedding, ?) as distance
+            FROM person_embeddings
+            ORDER BY distance
+            LIMIT ?
+        """
+        cursor = conn.execute(query, (query_blob, top_k))
+
+        results = []
+        for row in cursor:
+            person_id = row["person_id"]
+            distance = row["distance"]
+            similarity = 1.0 - distance
+
+            record = self._get_record_by_id(person_id)
+            if record:
+                results.append((record, similarity))
+
+        return results
+
+    def _get_record_by_id(self, person_id: int) -> Optional[PersonRecord]:
+        """Get a person record by ID."""
+        conn = self._conn
+        assert conn is not None
+
+        cursor = conn.execute("""
+            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+            FROM people WHERE id = ?
+        """, (person_id,))
+
+        row = cursor.fetchone()
+        if row:
+            return PersonRecord(
+                name=row["name"],
+                source=row["source"],
+                source_id=row["source_id"],
+                country=row["country"] or "",
+                person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                known_for_role=row["known_for_role"] or "",
+                known_for_org=row["known_for_org"] or "",
+                record=json.loads(row["record"]),
+            )
+        return None
+
+    def get_by_source_id(self, source: str, source_id: str) -> Optional[PersonRecord]:
+        """Get a person record by source and source_id."""
+        conn = self._connect()
+
+        cursor = conn.execute("""
+            SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+            FROM people
+            WHERE source = ? AND source_id = ?
+        """, (source, source_id))
+
+        row = cursor.fetchone()
+        if row:
+            return PersonRecord(
+                name=row["name"],
+                source=row["source"],
+                source_id=row["source_id"],
+                country=row["country"] or "",
+                person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                known_for_role=row["known_for_role"] or "",
+                known_for_org=row["known_for_org"] or "",
+                record=json.loads(row["record"]),
+            )
+        return None
+
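`get_by_source_id` gives importers an exact-match lookup keyed by the originating registry rather than the embedding. A short sketch with the hypothetical `db` handle (the Wikidata QID is illustrative):

```python
person = db.get_by_source_id("wikidata", "Q7259")  # hypothetical identifiers
if person is not None:
    print(person.name, person.person_type, person.known_for_role)
```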
+    def get_stats(self) -> dict:
+        """Get database statistics for people table."""
+        conn = self._connect()
+
+        # Total count
+        cursor = conn.execute("SELECT COUNT(*) FROM people")
+        total = cursor.fetchone()[0]
+
+        # Count by person_type
+        cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
+        by_type = {row["person_type"]: row["cnt"] for row in cursor}
+
+        # Count by source
+        cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
+        by_source = {row["source"]: row["cnt"] for row in cursor}
+
+        return {
+            "total_records": total,
+            "by_type": by_type,
+            "by_source": by_source,
+        }
+
+    def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
+        """Iterate over all person records, optionally filtered by source."""
+        conn = self._connect()
+
+        if source:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                FROM people
+                WHERE source = ?
+            """, (source,))
+        else:
+            cursor = conn.execute("""
+                SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                FROM people
+            """)
+
+        for row in cursor:
+            yield PersonRecord(
+                name=row["name"],
+                source=row["source"],
+                source_id=row["source_id"],
+                country=row["country"] or "",
+                person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                known_for_role=row["known_for_role"] or "",
+                known_for_org=row["known_for_org"] or "",
+                record=json.loads(row["record"]),
+            )
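Finally, `get_stats` and `iter_records` support reporting and re-export; `iter_records` yields lazily from the cursor, so a full-table pass does not load everything into memory. A closing sketch, again with the hypothetical `db` handle:

```python
stats = db.get_stats()
print(stats["total_records"], stats["by_type"], stats["by_source"])

# Stream records from one importer without materializing the table.
for person in db.iter_records(source="wikidata"):
    process(person)  # hypothetical consumer
```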