corp-extractor 0.5.0-py3-none-any.whl → 0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +191 -24
  2. corp_extractor-0.9.0.dist-info/RECORD +76 -0
  3. statement_extractor/__init__.py +1 -1
  4. statement_extractor/cli.py +1227 -10
  5. statement_extractor/data/statement_taxonomy.json +6949 -1159
  6. statement_extractor/database/__init__.py +52 -0
  7. statement_extractor/database/embeddings.py +186 -0
  8. statement_extractor/database/hub.py +520 -0
  9. statement_extractor/database/importers/__init__.py +24 -0
  10. statement_extractor/database/importers/companies_house.py +545 -0
  11. statement_extractor/database/importers/gleif.py +538 -0
  12. statement_extractor/database/importers/sec_edgar.py +375 -0
  13. statement_extractor/database/importers/wikidata.py +1012 -0
  14. statement_extractor/database/importers/wikidata_people.py +632 -0
  15. statement_extractor/database/models.py +230 -0
  16. statement_extractor/database/resolver.py +245 -0
  17. statement_extractor/database/store.py +1609 -0
  18. statement_extractor/document/__init__.py +62 -0
  19. statement_extractor/document/chunker.py +410 -0
  20. statement_extractor/document/context.py +171 -0
  21. statement_extractor/document/deduplicator.py +173 -0
  22. statement_extractor/document/html_extractor.py +246 -0
  23. statement_extractor/document/loader.py +303 -0
  24. statement_extractor/document/pipeline.py +388 -0
  25. statement_extractor/document/summarizer.py +195 -0
  26. statement_extractor/models/__init__.py +16 -1
  27. statement_extractor/models/canonical.py +44 -1
  28. statement_extractor/models/document.py +308 -0
  29. statement_extractor/models/labels.py +47 -18
  30. statement_extractor/models/qualifiers.py +51 -3
  31. statement_extractor/models/statement.py +26 -0
  32. statement_extractor/pipeline/config.py +6 -11
  33. statement_extractor/pipeline/orchestrator.py +80 -111
  34. statement_extractor/pipeline/registry.py +52 -46
  35. statement_extractor/plugins/__init__.py +20 -8
  36. statement_extractor/plugins/base.py +334 -64
  37. statement_extractor/plugins/extractors/gliner2.py +10 -0
  38. statement_extractor/plugins/labelers/taxonomy.py +18 -5
  39. statement_extractor/plugins/labelers/taxonomy_embedding.py +17 -6
  40. statement_extractor/plugins/pdf/__init__.py +10 -0
  41. statement_extractor/plugins/pdf/pypdf.py +291 -0
  42. statement_extractor/plugins/qualifiers/__init__.py +11 -0
  43. statement_extractor/plugins/qualifiers/companies_house.py +14 -3
  44. statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
  45. statement_extractor/plugins/qualifiers/gleif.py +14 -3
  46. statement_extractor/plugins/qualifiers/person.py +578 -14
  47. statement_extractor/plugins/qualifiers/sec_edgar.py +14 -3
  48. statement_extractor/plugins/scrapers/__init__.py +10 -0
  49. statement_extractor/plugins/scrapers/http.py +236 -0
  50. statement_extractor/plugins/splitters/t5_gemma.py +158 -53
  51. statement_extractor/plugins/taxonomy/embedding.py +193 -46
  52. statement_extractor/plugins/taxonomy/mnli.py +16 -4
  53. statement_extractor/scoring.py +8 -8
  54. corp_extractor-0.5.0.dist-info/RECORD +0 -55
  55. statement_extractor/plugins/canonicalizers/__init__.py +0 -17
  56. statement_extractor/plugins/canonicalizers/base.py +0 -9
  57. statement_extractor/plugins/canonicalizers/location.py +0 -219
  58. statement_extractor/plugins/canonicalizers/organization.py +0 -230
  59. statement_extractor/plugins/canonicalizers/person.py +0 -242
  60. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
  61. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
statement_extractor/database/store.py
@@ -0,0 +1,1609 @@
+ """
+ Entity/Organization database with sqlite-vec for vector search.
+
+ Uses a hybrid approach:
+ 1. Text-based filtering (SQL LIKE) to narrow candidates
+ 2. sqlite-vec vector search for semantic ranking
+ """
+
+ import json
+ import logging
+ import re
+ import sqlite3
+ import time
+ from pathlib import Path
+ from typing import Iterator, Optional
+
+ import numpy as np
+ import sqlite_vec
+
+ from .models import CompanyRecord, DatabaseStats, EntityType, PersonRecord, PersonType
+
+ logger = logging.getLogger(__name__)
+
+ # Default database location
+ DEFAULT_DB_PATH = Path.home() / ".cache" / "corp-extractor" / "entities.db"
+
+ # Module-level singleton for OrganizationDatabase to prevent multiple loads
+ _database_instances: dict[str, "OrganizationDatabase"] = {}
+
+ # Module-level singleton for PersonDatabase
+ _person_database_instances: dict[str, "PersonDatabase"] = {}
+
+ # Comprehensive set of corporate legal suffixes (international)
+ COMPANY_SUFFIXES: set[str] = {
+     'A/S', 'AB', 'AG', 'AO', 'AG & Co', 'AG &', 'AG & CO.', 'AG & CO. KG', 'AG & CO. KGaA',
+     'AG & KG', 'AG & KGaA', 'AG & PARTNER', 'ATE', 'ASA', 'B.V.', 'BV', 'Class A', 'Class B',
+     'Class C', 'Class D', 'Class E', 'Class F', 'Class G', 'CO', 'Co', 'Co.', 'Company',
+     'Corp', 'Corp.', 'Corporation', 'DAC', 'GmbH', 'Inc', 'Inc.', 'Incorporated', 'KGaA',
+     'Limited', 'LLC', 'LLP', 'LP', 'Ltd', 'Ltd.', 'N.V.', 'NV', 'Plc', 'PC', 'plc', 'PLC',
+     'Pty Ltd', 'Pty', 'Pty. Ltd.', 'S.A.', 'S.A.B. de C.V.', 'SAB de CV', 'S.A.B.', 'S.A.P.I.',
+     'NV/SA', 'SDI', 'SpA', 'S.L.', 'S.p.A.', 'SA', 'SE', 'Tbk PT', 'U.A.',
+     # Additional common suffixes
+     'Group', 'Holdings', 'Holding', 'Partners', 'Trust', 'Fund', 'Bank', 'N.A.', 'The',
+ }
+
+ # Pre-compile the suffix pattern for performance
+ _SUFFIX_PATTERN = re.compile(
+     r'\s+(' + '|'.join(re.escape(suffix) for suffix in COMPANY_SUFFIXES) + r')\.?$',
+     re.IGNORECASE
+ )
+
+
+ def _clean_org_name(name: str | None) -> str:
+     """
+     Remove special characters and formatting from organization name.
+
+     Removes brackets, parentheses, quotes, and other formatting artifacts.
+     """
+     if not name:
+         return ""
+     # Remove special characters, keeping only alphanumeric and spaces
+     cleaned = re.sub(r'[•;:\'"\[\](){}<>`~!@#$%^&*\-_=+\\|/?]+', ' ', name)
+     cleaned = re.sub(r'\s+', ' ', cleaned).strip()
+     # Recurse if changes were made (handles nested special chars)
+     return _clean_org_name(cleaned) if cleaned != name else cleaned
+
+
+ def _remove_suffix(name: str) -> str:
+     """
+     Remove corporate legal suffixes from company name.
+
+     Iteratively removes suffixes until no more are found.
+     Also removes possessive 's and trailing punctuation.
+     """
+     cleaned = name.strip()
+     cleaned = re.sub(r'\s+', ' ', cleaned)
+     # Remove possessive 's (e.g., "Amazon's" -> "Amazon")
+     cleaned = re.sub(r"'s\b", "", cleaned)
+     cleaned = re.sub(r'\s+', ' ', cleaned).strip()
+
+     while True:
+         new_name = _SUFFIX_PATTERN.sub('', cleaned)
+         # Remove trailing punctuation
+         new_name = re.sub(r'[ .,;&\n\t/)]$', '', new_name)
+
+         if new_name == cleaned:
+             break
+         cleaned = new_name.strip()
+
+     return cleaned.strip()
+
+
+ def _normalize_name(name: str) -> str:
+     """
+     Normalize company name for text matching.
+
+     1. Remove possessive 's (before cleaning removes apostrophe)
+     2. Clean special characters
+     3. Remove legal suffixes
+     4. Lowercase
+     5. If result is empty, use cleaned lowercase original
+
+     Always returns a non-empty string for valid input.
+     """
+     if not name:
+         return ""
+     # Remove possessive 's first (before cleaning removes the apostrophe)
+     normalized = re.sub(r"'s\b", "", name)
+     # Clean special characters
+     cleaned = _clean_org_name(normalized)
+     # Remove legal suffixes
+     normalized = _remove_suffix(cleaned)
+     # Lowercase for matching
+     normalized = normalized.lower()
+     # If normalized is empty (e.g., name was just "Ltd"), use the cleaned name
+     if not normalized:
+         normalized = cleaned.lower() if cleaned else name.lower()
+     return normalized
+
+
+ def _extract_search_terms(query: str) -> list[str]:
+     """
+     Extract search terms from a query for SQL LIKE matching.
+
+     Returns list of terms to search for, ordered by length (longest first).
+     """
+     # Split into words
+     words = query.split()
+
+     # Filter out very short words (< 3 chars) unless it's the only word
+     if len(words) > 1:
+         words = [w for w in words if len(w) >= 3]
+
+     # Sort by length descending (longer words are more specific)
+     words.sort(key=len, reverse=True)
+
+     return words[:3]  # Limit to top 3 terms
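
Taken together, these helpers feed the LIKE pre-filter used later in the file. A minimal sketch of what they produce, with outputs traced by hand from the logic above rather than taken from the package's tests:

    from statement_extractor.database.store import _extract_search_terms, _normalize_name

    # Possessive and legal suffix are stripped, then lowercased
    print(_normalize_name("Amazon's Inc."))   # -> "amazon"
    print(_normalize_name("Siemens AG"))      # -> "siemens"

    # Terms are length-filtered, sorted longest-first, and capped at three
    print(_extract_search_terms("international business machines"))
    # -> ['international', 'business', 'machines']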
+
+
+ # Person name normalization patterns
+ _PERSON_PREFIXES = {
+     "dr.", "dr", "prof.", "prof", "professor",
+     "mr.", "mr", "mrs.", "mrs", "ms.", "ms", "miss",
+     "sir", "dame", "lord", "lady",
+     "rev.", "rev", "reverend",
+     "hon.", "hon", "honorable",
+     "gen.", "gen", "general",
+     "col.", "col", "colonel",
+     "capt.", "capt", "captain",
+     "lt.", "lt", "lieutenant",
+     "sgt.", "sgt", "sergeant",
+ }
+
+ _PERSON_SUFFIXES = {
+     "jr.", "jr", "junior",
+     "sr.", "sr", "senior",
+     "ii", "iii", "iv", "v",
+     "2nd", "3rd", "4th", "5th",
+     "phd", "ph.d.", "ph.d",
+     "md", "m.d.", "m.d",
+     "esq", "esq.",
+     "mba", "m.b.a.",
+     "cpa", "c.p.a.",
+     "jd", "j.d.",
+ }
+
+
+ def _normalize_person_name(name: str) -> str:
+     """
+     Normalize person name for text matching.
+
+     1. Remove honorific prefixes (Dr., Prof., Mr., etc.)
+     2. Remove generational suffixes (Jr., Sr., III, PhD, etc.)
+     3. Keep name particles (von, van, de, al-, etc.)
+     4. Lowercase and strip
+
+     Always returns a non-empty string for valid input.
+     """
+     if not name:
+         return ""
+
+     # Lowercase for matching
+     normalized = name.lower().strip()
+
+     # Split into words
+     words = normalized.split()
+     if not words:
+         return ""
+
+     # Remove prefix if first word is a title
+     while words and words[0].rstrip(".") in _PERSON_PREFIXES:
+         words.pop(0)
+     if not words:
+         return name.lower().strip()  # Fallback if name was just a title
+
+     # Remove suffix if last word is a suffix
+     while words and words[-1].rstrip(".") in _PERSON_SUFFIXES:
+         words.pop()
+     if not words:
+         return name.lower().strip()  # Fallback if name was just suffixes
+
+     # Rejoin remaining words
+     normalized = " ".join(words)
+
+     # Clean up extra spaces
+     normalized = re.sub(r'\s+', ' ', normalized).strip()
+
+     return normalized if normalized else name.lower().strip()
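
The person normalizer behaves analogously; a hand-traced sketch under the same caveat:

    from statement_extractor.database.store import _normalize_person_name

    print(_normalize_person_name("Dr. Angela Merkel"))     # -> "angela merkel"
    print(_normalize_person_name("Tim Cook Jr."))          # -> "tim cook"
    print(_normalize_person_name("Ludwig van Beethoven"))  # -> "ludwig van beethoven" (particles kept)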
+
+
+ def get_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "OrganizationDatabase":
+     """
+     Get a singleton OrganizationDatabase instance for the given path.
+
+     Args:
+         db_path: Path to database file
+         embedding_dim: Dimension of embeddings
+
+     Returns:
+         Shared OrganizationDatabase instance
+     """
+     path_key = str(db_path or DEFAULT_DB_PATH)
+     if path_key not in _database_instances:
+         logger.debug(f"Creating new OrganizationDatabase instance for {path_key}")
+         _database_instances[path_key] = OrganizationDatabase(db_path=db_path, embedding_dim=embedding_dim)
+     return _database_instances[path_key]
+
+
+ class OrganizationDatabase:
+     """
+     SQLite database with sqlite-vec for organization vector search.
+
+     Uses hybrid text + vector search:
+     1. Text filtering with SQL LIKE to reduce candidates
+     2. sqlite-vec for semantic similarity ranking
+     """
+
+     def __init__(
+         self,
+         db_path: Optional[str | Path] = None,
+         embedding_dim: int = 768,  # Default for embeddinggemma-300m
+     ):
+         """
+         Initialize the organization database.
+
+         Args:
+             db_path: Path to database file (creates if not exists)
+             embedding_dim: Dimension of embeddings to store
+         """
+         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+         self._embedding_dim = embedding_dim
+         self._conn: Optional[sqlite3.Connection] = None
+
+     def _ensure_dir(self) -> None:
+         """Ensure database directory exists."""
+         self._db_path.parent.mkdir(parents=True, exist_ok=True)
+
+     def _connect(self) -> sqlite3.Connection:
+         """Get or create database connection with sqlite-vec loaded."""
+         if self._conn is not None:
+             return self._conn
+
+         self._ensure_dir()
+         self._conn = sqlite3.connect(str(self._db_path))
+         self._conn.row_factory = sqlite3.Row
+
+         # Load sqlite-vec extension
+         self._conn.enable_load_extension(True)
+         sqlite_vec.load(self._conn)
+         self._conn.enable_load_extension(False)
+
+         # Create tables
+         self._create_tables()
+
+         return self._conn
+
+     def _create_tables(self) -> None:
+         """Create database tables including sqlite-vec virtual table."""
+         conn = self._conn
+         assert conn is not None
+
+         # Main organization records table
+         conn.execute("""
+             CREATE TABLE IF NOT EXISTS organizations (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 name TEXT NOT NULL,
+                 name_normalized TEXT NOT NULL,
+                 source TEXT NOT NULL,
+                 source_id TEXT NOT NULL,
+                 region TEXT NOT NULL DEFAULT '',
+                 entity_type TEXT NOT NULL DEFAULT 'unknown',
+                 record TEXT NOT NULL,
+                 UNIQUE(source, source_id)
+             )
+         """)
+
+         # Add region column if it doesn't exist (migration for existing DBs)
+         try:
+             conn.execute("ALTER TABLE organizations ADD COLUMN region TEXT NOT NULL DEFAULT ''")
+             logger.info("Added region column to organizations table")
+         except sqlite3.OperationalError:
+             pass  # Column already exists
+
+         # Add entity_type column if it doesn't exist (migration for existing DBs)
+         try:
+             conn.execute("ALTER TABLE organizations ADD COLUMN entity_type TEXT NOT NULL DEFAULT 'unknown'")
+             logger.info("Added entity_type column to organizations table")
+         except sqlite3.OperationalError:
+             pass  # Column already exists
+
+         # Create indexes on main table
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source ON organizations(source)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source_id ON organizations(source, source_id)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
+         conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
+
+         # Create sqlite-vec virtual table for embeddings
+         # vec0 is the recommended virtual table type
+         conn.execute(f"""
+             CREATE VIRTUAL TABLE IF NOT EXISTS organization_embeddings USING vec0(
+                 org_id INTEGER PRIMARY KEY,
+                 embedding float[{self._embedding_dim}]
+             )
+         """)
+
+         conn.commit()
+
+     def close(self) -> None:
+         """Close database connection."""
+         if self._conn:
+             self._conn.close()
+             self._conn = None
+
+     def insert(self, record: CompanyRecord, embedding: np.ndarray) -> int:
+         """
+         Insert an organization record with its embedding.
+
+         Args:
+             record: Organization record to insert
+             embedding: Embedding vector for the organization name
+
+         Returns:
+             Row ID of inserted record
+         """
+         conn = self._connect()
+
+         # Serialize record
+         record_json = json.dumps(record.record)
+         name_normalized = _normalize_name(record.name)
+
+         cursor = conn.execute("""
+             INSERT OR REPLACE INTO organizations
+             (name, name_normalized, source, source_id, region, entity_type, record)
+             VALUES (?, ?, ?, ?, ?, ?, ?)
+         """, (
+             record.name,
+             name_normalized,
+             record.source,
+             record.source_id,
+             record.region,
+             record.entity_type.value,
+             record_json,
+         ))
+
+         row_id = cursor.lastrowid
+         assert row_id is not None
+
+         # Insert embedding into vec table
+         # sqlite-vec expects the embedding as a blob
+         embedding_blob = embedding.astype(np.float32).tobytes()
+         conn.execute("""
+             INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
+             VALUES (?, ?)
+         """, (row_id, embedding_blob))
+
+         conn.commit()
+         return row_id
+
+     def insert_batch(
+         self,
+         records: list[CompanyRecord],
+         embeddings: np.ndarray,
+         batch_size: int = 1000,
+     ) -> int:
+         """
+         Insert multiple organization records with embeddings.
+
+         Args:
+             records: List of organization records
+             embeddings: Matrix of embeddings (N x dim)
+             batch_size: Commit batch size
+
+         Returns:
+             Number of records inserted
+         """
+         conn = self._connect()
+         count = 0
+
+         for record, embedding in zip(records, embeddings):
+             record_json = json.dumps(record.record)
+             name_normalized = _normalize_name(record.name)
+
+             cursor = conn.execute("""
+                 INSERT OR REPLACE INTO organizations
+                 (name, name_normalized, source, source_id, region, entity_type, record)
+                 VALUES (?, ?, ?, ?, ?, ?, ?)
+             """, (
+                 record.name,
+                 name_normalized,
+                 record.source,
+                 record.source_id,
+                 record.region,
+                 record.entity_type.value,
+                 record_json,
+             ))
+
+             row_id = cursor.lastrowid
+             assert row_id is not None
+
+             # Insert embedding
+             embedding_blob = embedding.astype(np.float32).tobytes()
+             conn.execute("""
+                 INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
+                 VALUES (?, ?)
+             """, (row_id, embedding_blob))
+
+             count += 1
+
+             if count % batch_size == 0:
+                 conn.commit()
+                 logger.info(f"Inserted {count} records...")
+
+         conn.commit()
+         return count
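
A plausible bulk-load sketch for the two insert paths above. Here embed() is a hypothetical helper standing in for whatever encoder produced the stored vectors (presumably the package's database/embeddings.py module), and the source IDs are made up:

    import numpy as np
    from statement_extractor.database.models import CompanyRecord, EntityType
    from statement_extractor.database.store import get_database

    records = [
        CompanyRecord(name="Siemens AG", source="gleif", source_id="LEI-EXAMPLE-1",
                      region="DE", entity_type=EntityType.UNKNOWN, record={}),
        CompanyRecord(name="Apple Inc.", source="sec_edgar", source_id="0000320193",
                      region="US", entity_type=EntityType.UNKNOWN, record={}),
    ]
    embeddings = np.stack([embed(r.name) for r in records])  # embed() is hypothetical
    db = get_database()
    db.insert_batch(records, embeddings)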
+
+     def search(
+         self,
+         query_embedding: np.ndarray,
+         top_k: int = 20,
+         source_filter: Optional[str] = None,
+         query_text: Optional[str] = None,
+         max_text_candidates: int = 5000,
+     ) -> list[tuple[CompanyRecord, float]]:
+         """
+         Search for similar organizations using hybrid text + vector search.
+
+         Two-stage approach:
+         1. If query_text provided, use SQL LIKE to find candidates containing search terms
+         2. Use sqlite-vec for vector similarity ranking on filtered candidates
+
+         Args:
+             query_embedding: Query embedding vector
+             top_k: Number of results to return
+             source_filter: Optional filter by source (gleif, sec_edgar, etc.)
+             query_text: Optional query text for text-based pre-filtering
+             max_text_candidates: Max candidates to keep after text filtering
+
+         Returns:
+             List of (CompanyRecord, similarity_score) tuples
+         """
+         start = time.time()
+         self._connect()
+
+         # Normalize query embedding
+         query_norm = np.linalg.norm(query_embedding)
+         if query_norm == 0:
+             return []
+         query_normalized = query_embedding / query_norm
+         query_blob = query_normalized.astype(np.float32).tobytes()
+
+         # Stage 1: Text-based pre-filtering (if query_text provided)
+         candidate_ids: Optional[set[int]] = None
+         if query_text:
+             query_normalized_text = _normalize_name(query_text)
+             if query_normalized_text:
+                 candidate_ids = self._text_filter_candidates(
+                     query_normalized_text,
+                     max_candidates=max_text_candidates,
+                     source_filter=source_filter,
+                 )
+                 logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
+
+         # Stage 2: Vector search
+         if candidate_ids is not None and len(candidate_ids) == 0:
+             # No text matches, return empty
+             return []
+
+         if candidate_ids is not None:
+             # Search within text-filtered candidates
+             results = self._vector_search_filtered(
+                 query_blob, candidate_ids, top_k, source_filter
+             )
+         else:
+             # Full vector search
+             results = self._vector_search_full(query_blob, top_k, source_filter)
+
+         elapsed = time.time() - start
+         logger.debug(f"Hybrid search took {elapsed:.3f}s (results={len(results)})")
+         return results
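
Usage sketch for the hybrid search, again with a hypothetical embed(). Passing query_text enables the LIKE pre-filter; omitting it falls back to a full vector scan:

    from statement_extractor.database.store import get_database

    db = get_database()
    matches = db.search(
        query_embedding=embed("Acme Corporation"),  # embed() is hypothetical
        top_k=5,
        query_text="Acme Corporation",  # stage 1: LIKE pre-filter
        source_filter="gleif",          # optional source restriction
    )
    for record, similarity in matches:
        print(f"{similarity:.3f}  {record.name} ({record.source}:{record.source_id})")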
+
+     def _text_filter_candidates(
+         self,
+         query_normalized: str,
+         max_candidates: int,
+         source_filter: Optional[str] = None,
+     ) -> set[int]:
+         """
+         Filter candidates using SQL LIKE for fast text matching.
+
+         This is a generous pre-filter to reduce the embedding search space.
+         Returns set of organization IDs that contain any search term.
+         Uses `name_normalized` column for consistent matching.
+         """
+         conn = self._conn
+         assert conn is not None
+
+         # Extract search terms from the normalized query
+         search_terms = _extract_search_terms(query_normalized)
+         if not search_terms:
+             return set()
+
+         logger.debug(f"Text filter search terms: {search_terms}")
+
+         # Build OR clause for LIKE matching on any term
+         # Use name_normalized for consistent matching (already lowercased, suffixes removed)
+         like_clauses = []
+         params: list = []
+         for term in search_terms:
+             like_clauses.append("name_normalized LIKE ?")
+             params.append(f"%{term}%")
+
+         where_clause = " OR ".join(like_clauses)
+
+         # Add source filter if specified
+         if source_filter:
+             query = f"""
+                 SELECT id FROM organizations
+                 WHERE ({where_clause}) AND source = ?
+                 LIMIT ?
+             """
+             params.append(source_filter)
+         else:
+             query = f"""
+                 SELECT id FROM organizations
+                 WHERE {where_clause}
+                 LIMIT ?
+             """
+
+         params.append(max_candidates)
+
+         cursor = conn.execute(query, params)
+         return set(row["id"] for row in cursor)
+
+     def _vector_search_filtered(
+         self,
+         query_blob: bytes,
+         candidate_ids: set[int],
+         top_k: int,
+         source_filter: Optional[str],
+     ) -> list[tuple[CompanyRecord, float]]:
+         """Vector search within a filtered set of candidates."""
+         conn = self._conn
+         assert conn is not None
+
+         if not candidate_ids:
+             return []
+
+         # Build IN clause for candidate IDs
+         placeholders = ",".join("?" * len(candidate_ids))
+
+         # Query sqlite-vec with KNN search, filtered by candidate IDs.
+         # vec_distance_cosine returns cosine distance, where lower is more similar.
+         query = f"""
+             SELECT
+                 e.org_id,
+                 vec_distance_cosine(e.embedding, ?) as distance
+             FROM organization_embeddings e
+             WHERE e.org_id IN ({placeholders})
+             ORDER BY distance
+             LIMIT ?
+         """
+
+         cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
+
+         results = []
+         for row in cursor:
+             org_id = row["org_id"]
+             distance = row["distance"]
+             # Convert cosine distance to similarity (1 - distance)
+             similarity = 1.0 - distance
+
+             # Fetch full record
+             record = self._get_record_by_id(org_id)
+             if record:
+                 # Apply source filter if specified
+                 if source_filter and record.source != source_filter:
+                     continue
+                 results.append((record, similarity))
+
+         return results
+
+     def _vector_search_full(
+         self,
+         query_blob: bytes,
+         top_k: int,
+         source_filter: Optional[str],
+     ) -> list[tuple[CompanyRecord, float]]:
+         """Full vector search without text pre-filtering."""
+         conn = self._conn
+         assert conn is not None
+
+         # KNN search with sqlite-vec
+         if source_filter:
+             # Need to join with organizations table for source filter
+             query = """
+                 SELECT
+                     e.org_id,
+                     vec_distance_cosine(e.embedding, ?) as distance
+                 FROM organization_embeddings e
+                 JOIN organizations c ON e.org_id = c.id
+                 WHERE c.source = ?
+                 ORDER BY distance
+                 LIMIT ?
+             """
+             cursor = conn.execute(query, (query_blob, source_filter, top_k))
+         else:
+             query = """
+                 SELECT
+                     org_id,
+                     vec_distance_cosine(embedding, ?) as distance
+                 FROM organization_embeddings
+                 ORDER BY distance
+                 LIMIT ?
+             """
+             cursor = conn.execute(query, (query_blob, top_k))
+
+         results = []
+         for row in cursor:
+             org_id = row["org_id"]
+             distance = row["distance"]
+             similarity = 1.0 - distance
+
+             record = self._get_record_by_id(org_id)
+             if record:
+                 results.append((record, similarity))
+
+         return results
+
+     def _get_record_by_id(self, org_id: int) -> Optional[CompanyRecord]:
+         """Get an organization record by ID."""
+         conn = self._conn
+         assert conn is not None
+
+         cursor = conn.execute("""
+             SELECT name, source, source_id, region, entity_type, record
+             FROM organizations WHERE id = ?
+         """, (org_id,))
+
+         row = cursor.fetchone()
+         if row:
+             return CompanyRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 region=row["region"] or "",
+                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                 record=json.loads(row["record"]),
+             )
+         return None
+
+     def get_by_source_id(self, source: str, source_id: str) -> Optional[CompanyRecord]:
+         """Get an organization record by source and source_id."""
+         conn = self._connect()
+
+         cursor = conn.execute("""
+             SELECT name, source, source_id, region, entity_type, record
+             FROM organizations
+             WHERE source = ? AND source_id = ?
+         """, (source, source_id))
+
+         row = cursor.fetchone()
+         if row:
+             return CompanyRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 region=row["region"] or "",
+                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                 record=json.loads(row["record"]),
+             )
+         return None
+
+     def get_stats(self) -> DatabaseStats:
+         """Get database statistics."""
+         conn = self._connect()
+
+         # Total count
+         cursor = conn.execute("SELECT COUNT(*) FROM organizations")
+         total = cursor.fetchone()[0]
+
+         # Count by source
+         cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM organizations GROUP BY source")
+         by_source = {row["source"]: row["cnt"] for row in cursor}
+
+         # Database file size
+         db_size = self._db_path.stat().st_size if self._db_path.exists() else 0
+
+         return DatabaseStats(
+             total_records=total,
+             by_source=by_source,
+             embedding_dimension=self._embedding_dim,
+             database_size_bytes=db_size,
+         )
+
+     def iter_records(self, source: Optional[str] = None) -> Iterator[CompanyRecord]:
+         """Iterate over all records, optionally filtered by source."""
+         conn = self._connect()
+
+         if source:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, region, entity_type, record
+                 FROM organizations
+                 WHERE source = ?
+             """, (source,))
+         else:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, region, entity_type, record
+                 FROM organizations
+             """)
+
+         for row in cursor:
+             yield CompanyRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 region=row["region"] or "",
+                 entity_type=EntityType(row["entity_type"]) if row["entity_type"] else EntityType.UNKNOWN,
+                 record=json.loads(row["record"]),
+             )
+
+     def migrate_name_normalized(self, batch_size: int = 50000) -> int:
+         """
+         Populate the name_normalized column for all records.
+
+         This is a one-time migration for databases that don't have
+         normalized names populated.
+
+         Args:
+             batch_size: Number of records to process per batch
+
+         Returns:
+             Number of records updated
+         """
+         conn = self._connect()
+
+         # Check how many need migration (empty, null, or placeholder "-")
+         cursor = conn.execute(
+             "SELECT COUNT(*) FROM organizations WHERE name_normalized = '' OR name_normalized IS NULL OR name_normalized = '-'"
+         )
+         empty_count = cursor.fetchone()[0]
+
+         if empty_count == 0:
+             logger.info("All records already have name_normalized populated")
+             return 0
+
+         logger.info(f"Populating name_normalized for {empty_count} records...")
+
+         updated = 0
+         last_id = 0
+
+         while True:
+             # Get batch of records that need normalization, ordered by ID
+             cursor = conn.execute("""
+                 SELECT id, name FROM organizations
+                 WHERE id > ? AND (name_normalized = '' OR name_normalized IS NULL OR name_normalized = '-')
+                 ORDER BY id
+                 LIMIT ?
+             """, (last_id, batch_size))
+
+             rows = cursor.fetchall()
+             if not rows:
+                 break
+
+             # Update each record
+             for row in rows:
+                 # _normalize_name now always returns non-empty for valid input
+                 normalized = _normalize_name(row["name"])
+                 conn.execute(
+                     "UPDATE organizations SET name_normalized = ? WHERE id = ?",
+                     (normalized, row["id"])
+                 )
+                 last_id = row["id"]
+
+             conn.commit()
+             updated += len(rows)
+             logger.info(f"  Updated {updated}/{empty_count} records...")
+
+         logger.info(f"Migration complete: {updated} name_normalized values populated")
+         return updated
+
+     def migrate_to_sqlite_vec(self, batch_size: int = 10000) -> int:
+         """
+         Migrate embeddings from BLOB column to sqlite-vec virtual table.
+
+         This is a one-time migration for databases created before sqlite-vec support.
+
+         Args:
+             batch_size: Number of records to process per batch
+
+         Returns:
+             Number of embeddings migrated
+         """
+         conn = self._connect()
+
+         # Check if migration is needed
+         cursor = conn.execute("SELECT COUNT(*) FROM organization_embeddings")
+         vec_count = cursor.fetchone()[0]
+
+         cursor = conn.execute("SELECT COUNT(*) FROM organizations WHERE embedding IS NOT NULL")
+         blob_count = cursor.fetchone()[0]
+
+         if vec_count >= blob_count:
+             logger.info(f"Migration not needed: sqlite-vec has {vec_count} embeddings, BLOB has {blob_count}")
+             return 0
+
+         logger.info(f"Migrating {blob_count} embeddings from BLOB to sqlite-vec...")
+
+         # Get rows that need migration (have a legacy BLOB embedding but no row
+         # in organization_embeddings yet)
+         cursor = conn.execute("""
+             SELECT c.id, c.embedding
+             FROM organizations c
+             LEFT JOIN organization_embeddings e ON c.id = e.org_id
+             WHERE c.embedding IS NOT NULL AND e.org_id IS NULL
+         """)
+
+         migrated = 0
+         batch = []
+
+         for row in cursor:
+             org_id = row["id"]
+             embedding_blob = row["embedding"]
+
+             if embedding_blob:
+                 batch.append((org_id, embedding_blob))
+
+                 if len(batch) >= batch_size:
+                     self._insert_vec_batch(batch)
+                     migrated += len(batch)
+                     logger.info(f"  Migrated {migrated}/{blob_count} embeddings...")
+                     batch = []
+
+         # Insert remaining batch
+         if batch:
+             self._insert_vec_batch(batch)
+             migrated += len(batch)
+
+         logger.info(f"Migration complete: {migrated} embeddings migrated to sqlite-vec")
+         return migrated
+
+     def _insert_vec_batch(self, batch: list[tuple[int, bytes]]) -> None:
+         """Insert a batch of embeddings into sqlite-vec table."""
+         conn = self._conn
+         assert conn is not None
+
+         for org_id, embedding_blob in batch:
+             conn.execute("""
+                 INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
+                 VALUES (?, ?)
+             """, (org_id, embedding_blob))
+
+         conn.commit()
+
+     def delete_source(self, source: str) -> int:
+         """Delete all records from a specific source."""
+         conn = self._connect()
+
+         # First get IDs to delete from vec table
+         cursor = conn.execute("SELECT id FROM organizations WHERE source = ?", (source,))
+         ids_to_delete = [row["id"] for row in cursor]
+
+         # Delete from vec table
+         if ids_to_delete:
+             placeholders = ",".join("?" * len(ids_to_delete))
+             conn.execute(f"DELETE FROM organization_embeddings WHERE org_id IN ({placeholders})", ids_to_delete)
+
+         # Delete from main table
+         cursor = conn.execute("DELETE FROM organizations WHERE source = ?", (source,))
+         deleted = cursor.rowcount
+
+         conn.commit()
+
+         logger.info(f"Deleted {deleted} records from source '{source}'")
+         return deleted
+
+     def migrate_from_legacy_schema(self) -> dict[str, str]:
+         """
+         Migrate database from legacy schema (companies/company_embeddings tables)
+         to new schema (organizations/organization_embeddings tables).
+
+         This handles:
+         - Renaming 'companies' table to 'organizations'
+         - Renaming 'company_embeddings' table to 'organization_embeddings'
+         - Renaming 'company_id' column to 'org_id' in embeddings table
+         - Updating indexes to use new naming
+
+         Returns:
+             Dict of migrations performed (table_name -> action)
+         """
+         conn = self._connect()
+         migrations = {}
+
+         # Check what tables exist
+         cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
+         existing_tables = {row[0] for row in cursor}
+
+         has_companies = "companies" in existing_tables
+         has_organizations = "organizations" in existing_tables
+         has_company_embeddings = "company_embeddings" in existing_tables
+         has_org_embeddings = "organization_embeddings" in existing_tables
+
+         if not has_companies and not has_company_embeddings:
+             if has_organizations and has_org_embeddings:
+                 logger.info("Database already uses new schema, no migration needed")
+                 return {}
+             else:
+                 logger.info("No legacy tables found, database will use new schema")
+                 return {}
+
+         logger.info("Migrating database from legacy schema...")
+         conn.execute("BEGIN")
+
+         try:
+             # Migrate companies -> organizations
+             if has_companies:
+                 if has_organizations:
+                     # Both exist - merge data from companies into organizations
+                     logger.info("Merging companies table into organizations...")
+                     conn.execute("""
+                         INSERT OR IGNORE INTO organizations
+                         (name, name_normalized, source, source_id, region, entity_type, record)
+                         SELECT name, name_normalized, source, source_id,
+                                COALESCE(region, ''), COALESCE(entity_type, 'unknown'), record
+                         FROM companies
+                     """)
+                     conn.execute("DROP TABLE companies")
+                     migrations["companies"] = "merged_into_organizations"
+                 else:
+                     # Just rename
+                     logger.info("Renaming companies table to organizations...")
+                     conn.execute("ALTER TABLE companies RENAME TO organizations")
+                     migrations["companies"] = "renamed_to_organizations"
+
+                 # Update indexes
+                 for old_idx in ["idx_companies_name", "idx_companies_name_normalized",
+                                 "idx_companies_source", "idx_companies_source_id",
+                                 "idx_companies_region", "idx_companies_entity_type",
+                                 "idx_companies_name_region_source"]:
+                     try:
+                         conn.execute(f"DROP INDEX IF EXISTS {old_idx}")
+                     except Exception:
+                         pass
+
+                 # Create new indexes
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name ON organizations(name)")
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_name_normalized ON organizations(name_normalized)")
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source ON organizations(source)")
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_source_id ON organizations(source, source_id)")
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_region ON organizations(region)")
+                 conn.execute("CREATE INDEX IF NOT EXISTS idx_orgs_entity_type ON organizations(entity_type)")
+                 conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_orgs_name_region_source ON organizations(name, region, source)")
+
+             # Migrate company_embeddings -> organization_embeddings
+             if has_company_embeddings:
+                 if has_org_embeddings:
+                     # Both exist - merge
+                     logger.info("Merging company_embeddings into organization_embeddings...")
+                     # Get column info to check for company_id vs org_id
+                     cursor = conn.execute("PRAGMA table_info(company_embeddings)")
+                     cols = {row[1] for row in cursor}
+                     id_col = "company_id" if "company_id" in cols else "org_id"
+
+                     conn.execute(f"""
+                         INSERT OR IGNORE INTO organization_embeddings (org_id, embedding)
+                         SELECT {id_col}, embedding FROM company_embeddings
+                     """)
+                     conn.execute("DROP TABLE company_embeddings")
+                     migrations["company_embeddings"] = "merged_into_organization_embeddings"
+                 else:
+                     # Need to recreate with new column name
+                     logger.info("Migrating company_embeddings to organization_embeddings...")
+
+                     # Check if it has company_id or org_id column
+                     cursor = conn.execute("PRAGMA table_info(company_embeddings)")
+                     cols = {row[1] for row in cursor}
+                     id_col = "company_id" if "company_id" in cols else "org_id"
+
+                     # Create new virtual table
+                     conn.execute(f"""
+                         CREATE VIRTUAL TABLE organization_embeddings USING vec0(
+                             org_id INTEGER PRIMARY KEY,
+                             embedding float[{self._embedding_dim}]
+                         )
+                     """)
+
+                     # Copy data
+                     conn.execute(f"""
+                         INSERT INTO organization_embeddings (org_id, embedding)
+                         SELECT {id_col}, embedding FROM company_embeddings
+                     """)
+
+                     # Drop old table
+                     conn.execute("DROP TABLE company_embeddings")
+                     migrations["company_embeddings"] = "renamed_to_organization_embeddings"
+
+             conn.execute("COMMIT")
+             logger.info(f"Migration complete: {migrations}")
+
+         except Exception as e:
+             conn.execute("ROLLBACK")
+             logger.error(f"Migration failed: {e}")
+             raise
+
+         # Vacuum to clean up - outside try block since COMMIT already succeeded
+         try:
+             conn.execute("VACUUM")
+         except Exception as e:
+             logger.warning(f"VACUUM failed (migration was successful): {e}")
+
+         return migrations
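
One plausible upgrade sequence for a database created by an older release, run once before serving queries; the ordering follows the dependencies between the three migrations above:

    from statement_extractor.database.store import get_database

    db = get_database()
    db.migrate_from_legacy_schema()  # companies/company_embeddings -> organizations/...
    db.migrate_name_normalized()     # backfill the name_normalized column
    db.migrate_to_sqlite_vec()       # move legacy BLOB embeddings into the vec0 table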
+
+     def get_missing_embedding_count(self) -> int:
+         """Get count of organizations without embeddings in organization_embeddings table."""
+         conn = self._connect()
+
+         cursor = conn.execute("""
+             SELECT COUNT(*) FROM organizations c
+             LEFT JOIN organization_embeddings e ON c.id = e.org_id
+             WHERE e.org_id IS NULL
+         """)
+         return cursor.fetchone()[0]
+
+     def get_organizations_without_embeddings(
+         self,
+         batch_size: int = 1000,
+         source: Optional[str] = None,
+     ) -> Iterator[tuple[int, str]]:
+         """
+         Iterate over organizations that don't have embeddings.
+
+         Args:
+             batch_size: Number of records per batch
+             source: Optional source filter
+
+         Yields:
+             Tuples of (org_id, name)
+         """
+         conn = self._connect()
+
+         last_id = 0
+         while True:
+             if source:
+                 cursor = conn.execute("""
+                     SELECT c.id, c.name FROM organizations c
+                     LEFT JOIN organization_embeddings e ON c.id = e.org_id
+                     WHERE e.org_id IS NULL AND c.id > ? AND c.source = ?
+                     ORDER BY c.id
+                     LIMIT ?
+                 """, (last_id, source, batch_size))
+             else:
+                 cursor = conn.execute("""
+                     SELECT c.id, c.name FROM organizations c
+                     LEFT JOIN organization_embeddings e ON c.id = e.org_id
+                     WHERE e.org_id IS NULL AND c.id > ?
+                     ORDER BY c.id
+                     LIMIT ?
+                 """, (last_id, batch_size))
+
+             rows = cursor.fetchall()
+             if not rows:
+                 break
+
+             for row in rows:
+                 yield (row[0], row[1])
+                 last_id = row[0]
+
+     def insert_embeddings_batch(
+         self,
+         org_ids: list[int],
+         embeddings: np.ndarray,
+     ) -> int:
+         """
+         Insert embeddings for existing organizations.
+
+         Args:
+             org_ids: List of organization IDs
+             embeddings: Matrix of embeddings (N x dim)
+
+         Returns:
+             Number of embeddings inserted
+         """
+         conn = self._connect()
+         count = 0
+
+         for org_id, embedding in zip(org_ids, embeddings):
+             embedding_blob = embedding.astype(np.float32).tobytes()
+             conn.execute("""
+                 INSERT OR REPLACE INTO organization_embeddings (org_id, embedding)
+                 VALUES (?, ?)
+             """, (org_id, embedding_blob))
+             count += 1
+
+         conn.commit()
+         return count
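
Sketch of an embedding backfill loop built from the two methods above, with a hypothetical embed_batch() mapping a list of names to an (N x dim) float32 matrix:

    from statement_extractor.database.store import get_database

    db = get_database()
    print(f"{db.get_missing_embedding_count()} rows need embeddings")

    ids, names = [], []
    for org_id, name in db.get_organizations_without_embeddings(batch_size=1000):
        ids.append(org_id)
        names.append(name)
        if len(ids) == 1000:
            db.insert_embeddings_batch(ids, embed_batch(names))  # embed_batch() is hypothetical
            ids, names = [], []
    if ids:
        db.insert_embeddings_batch(ids, embed_batch(names))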
+
+
+ def get_person_database(db_path: Optional[str | Path] = None, embedding_dim: int = 768) -> "PersonDatabase":
+     """
+     Get a singleton PersonDatabase instance for the given path.
+
+     Args:
+         db_path: Path to database file
+         embedding_dim: Dimension of embeddings
+
+     Returns:
+         Shared PersonDatabase instance
+     """
+     path_key = str(db_path or DEFAULT_DB_PATH)
+     if path_key not in _person_database_instances:
+         logger.debug(f"Creating new PersonDatabase instance for {path_key}")
+         _person_database_instances[path_key] = PersonDatabase(db_path=db_path, embedding_dim=embedding_dim)
+     return _person_database_instances[path_key]
+
+
+ class PersonDatabase:
+     """
+     SQLite database with sqlite-vec for person vector search.
+
+     Uses hybrid text + vector search:
+     1. Text filtering with LIKE to reduce candidates
+     2. sqlite-vec for semantic similarity ranking
+
+     Stores people from sources like Wikidata with role/org context.
+     """
+
+     def __init__(
+         self,
+         db_path: Optional[str | Path] = None,
+         embedding_dim: int = 768,  # Default for embeddinggemma-300m
+     ):
+         """
+         Initialize the person database.
+
+         Args:
+             db_path: Path to database file (creates if not exists)
+             embedding_dim: Dimension of embeddings to store
+         """
+         self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
+         self._embedding_dim = embedding_dim
+         self._conn: Optional[sqlite3.Connection] = None
+
+     def _ensure_dir(self) -> None:
+         """Ensure database directory exists."""
+         self._db_path.parent.mkdir(parents=True, exist_ok=True)
+
+     def _connect(self) -> sqlite3.Connection:
+         """Get or create database connection with sqlite-vec loaded."""
+         if self._conn is not None:
+             return self._conn
+
+         self._ensure_dir()
+         self._conn = sqlite3.connect(str(self._db_path))
+         self._conn.row_factory = sqlite3.Row
+
+         # Load sqlite-vec extension
+         self._conn.enable_load_extension(True)
+         sqlite_vec.load(self._conn)
+         self._conn.enable_load_extension(False)
+
+         # Create tables
+         self._create_tables()
+
+         return self._conn
+
+     def _create_tables(self) -> None:
+         """Create database tables including sqlite-vec virtual table."""
+         conn = self._conn
+         assert conn is not None
+
+         # Main people records table
+         conn.execute("""
+             CREATE TABLE IF NOT EXISTS people (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 name TEXT NOT NULL,
+                 name_normalized TEXT NOT NULL,
+                 source TEXT NOT NULL DEFAULT 'wikidata',
+                 source_id TEXT NOT NULL,
+                 country TEXT NOT NULL DEFAULT '',
+                 person_type TEXT NOT NULL DEFAULT 'unknown',
+                 known_for_role TEXT NOT NULL DEFAULT '',
+                 known_for_org TEXT NOT NULL DEFAULT '',
+                 record TEXT NOT NULL,
+                 UNIQUE(source, source_id)
+             )
+         """)
+
+         # Create indexes on main table
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name ON people(name)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_name_normalized ON people(name_normalized)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source ON people(source)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_source_id ON people(source, source_id)")
+         conn.execute("CREATE INDEX IF NOT EXISTS idx_people_known_for_org ON people(known_for_org)")
+
+         # Create sqlite-vec virtual table for embeddings
+         conn.execute(f"""
+             CREATE VIRTUAL TABLE IF NOT EXISTS person_embeddings USING vec0(
+                 person_id INTEGER PRIMARY KEY,
+                 embedding float[{self._embedding_dim}]
+             )
+         """)
+
+         conn.commit()
+
+     def close(self) -> None:
+         """Close database connection."""
+         if self._conn:
+             self._conn.close()
+             self._conn = None
+
+     def insert(self, record: PersonRecord, embedding: np.ndarray) -> int:
+         """
+         Insert a person record with its embedding.
+
+         Args:
+             record: Person record to insert
+             embedding: Embedding vector for the person name
+
+         Returns:
+             Row ID of inserted record
+         """
+         conn = self._connect()
+
+         # Serialize record
+         record_json = json.dumps(record.record)
+         name_normalized = _normalize_person_name(record.name)
+
+         cursor = conn.execute("""
+             INSERT OR REPLACE INTO people
+             (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
+             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+         """, (
+             record.name,
+             name_normalized,
+             record.source,
+             record.source_id,
+             record.country,
+             record.person_type.value,
+             record.known_for_role,
+             record.known_for_org,
+             record_json,
+         ))
+
+         row_id = cursor.lastrowid
+         assert row_id is not None
+
+         # Insert embedding into vec table
+         embedding_blob = embedding.astype(np.float32).tobytes()
+         conn.execute("""
+             INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
+             VALUES (?, ?)
+         """, (row_id, embedding_blob))
+
+         conn.commit()
+         return row_id
+
+     def insert_batch(
+         self,
+         records: list[PersonRecord],
+         embeddings: np.ndarray,
+         batch_size: int = 1000,
+     ) -> int:
+         """
+         Insert multiple person records with embeddings.
+
+         Args:
+             records: List of person records
+             embeddings: Matrix of embeddings (N x dim)
+             batch_size: Commit batch size
+
+         Returns:
+             Number of records inserted
+         """
+         conn = self._connect()
+         count = 0
+
+         for record, embedding in zip(records, embeddings):
+             record_json = json.dumps(record.record)
+             name_normalized = _normalize_person_name(record.name)
+
+             cursor = conn.execute("""
+                 INSERT OR REPLACE INTO people
+                 (name, name_normalized, source, source_id, country, person_type, known_for_role, known_for_org, record)
+                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+             """, (
+                 record.name,
+                 name_normalized,
+                 record.source,
+                 record.source_id,
+                 record.country,
+                 record.person_type.value,
+                 record.known_for_role,
+                 record.known_for_org,
+                 record_json,
+             ))
+
+             row_id = cursor.lastrowid
+             assert row_id is not None
+
+             # Insert embedding
+             embedding_blob = embedding.astype(np.float32).tobytes()
+             conn.execute("""
+                 INSERT OR REPLACE INTO person_embeddings (person_id, embedding)
+                 VALUES (?, ?)
+             """, (row_id, embedding_blob))
+
+             count += 1
+
+             if count % batch_size == 0:
+                 conn.commit()
+                 logger.info(f"Inserted {count} person records...")
+
+         conn.commit()
+         return count
+
+     def search(
+         self,
+         query_embedding: np.ndarray,
+         top_k: int = 20,
+         query_text: Optional[str] = None,
+         max_text_candidates: int = 5000,
+     ) -> list[tuple[PersonRecord, float]]:
+         """
+         Search for similar people using hybrid text + vector search.
+
+         Two-stage approach:
+         1. If query_text provided, use SQL LIKE to find candidates containing search terms
+         2. Use sqlite-vec for vector similarity ranking on filtered candidates
+
+         Args:
+             query_embedding: Query embedding vector
+             top_k: Number of results to return
+             query_text: Optional query text for text-based pre-filtering
+             max_text_candidates: Max candidates to keep after text filtering
+
+         Returns:
+             List of (PersonRecord, similarity_score) tuples
+         """
+         start = time.time()
+         self._connect()
+
+         # Normalize query embedding
+         query_norm = np.linalg.norm(query_embedding)
+         if query_norm == 0:
+             return []
+         query_normalized = query_embedding / query_norm
+         query_blob = query_normalized.astype(np.float32).tobytes()
+
+         # Stage 1: Text-based pre-filtering (if query_text provided)
+         candidate_ids: Optional[set[int]] = None
+         if query_text:
+             query_normalized_text = _normalize_person_name(query_text)
+             if query_normalized_text:
+                 candidate_ids = self._text_filter_candidates(
+                     query_normalized_text,
+                     max_candidates=max_text_candidates,
+                 )
+                 logger.info(f"Text filter: {len(candidate_ids)} candidates for '{query_text}'")
+
+         # Stage 2: Vector search
+         if candidate_ids is not None and len(candidate_ids) == 0:
+             # No text matches, return empty
+             return []
+
+         if candidate_ids is not None:
+             # Search within text-filtered candidates
+             results = self._vector_search_filtered(
+                 query_blob, candidate_ids, top_k
+             )
+         else:
+             # Full vector search
+             results = self._vector_search_full(query_blob, top_k)
+
+         elapsed = time.time() - start
+         logger.debug(f"Person search took {elapsed:.3f}s (results={len(results)})")
+         return results
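
The person search mirrors the organization flow; a brief sketch, with embed() hypothetical as before:

    from statement_extractor.database.store import get_person_database

    people = get_person_database()
    for record, similarity in people.search(
        query_embedding=embed("Dr. Jane Smith"),
        top_k=5,
        query_text="Dr. Jane Smith",  # honorific is stripped by the normalizer
    ):
        print(f"{similarity:.3f}  {record.name} - {record.known_for_role} at {record.known_for_org}")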
+
+     def _text_filter_candidates(
+         self,
+         query_normalized: str,
+         max_candidates: int,
+     ) -> set[int]:
+         """
+         Filter candidates using SQL LIKE for fast text matching.
+
+         Uses `name_normalized` column for consistent matching.
+         """
+         conn = self._conn
+         assert conn is not None
+
+         # Extract search terms from the normalized query
+         search_terms = _extract_search_terms(query_normalized)
+         if not search_terms:
+             return set()
+
+         logger.debug(f"Person text filter search terms: {search_terms}")
+
+         # Build OR clause for LIKE matching on any term
+         like_clauses = []
+         params: list = []
+         for term in search_terms:
+             like_clauses.append("name_normalized LIKE ?")
+             params.append(f"%{term}%")
+
+         where_clause = " OR ".join(like_clauses)
+
+         query = f"""
+             SELECT id FROM people
+             WHERE {where_clause}
+             LIMIT ?
+         """
+
+         params.append(max_candidates)
+
+         cursor = conn.execute(query, params)
+         return set(row["id"] for row in cursor)
+
+     def _vector_search_filtered(
+         self,
+         query_blob: bytes,
+         candidate_ids: set[int],
+         top_k: int,
+     ) -> list[tuple[PersonRecord, float]]:
+         """Vector search within a filtered set of candidates."""
+         conn = self._conn
+         assert conn is not None
+
+         if not candidate_ids:
+             return []
+
+         # Build IN clause for candidate IDs
+         placeholders = ",".join("?" * len(candidate_ids))
+
+         query = f"""
+             SELECT
+                 e.person_id,
+                 vec_distance_cosine(e.embedding, ?) as distance
+             FROM person_embeddings e
+             WHERE e.person_id IN ({placeholders})
+             ORDER BY distance
+             LIMIT ?
+         """
+
+         cursor = conn.execute(query, [query_blob] + list(candidate_ids) + [top_k])
+
+         results = []
+         for row in cursor:
+             person_id = row["person_id"]
+             distance = row["distance"]
+             # Convert cosine distance to similarity (1 - distance)
+             similarity = 1.0 - distance
+
+             # Fetch full record
+             record = self._get_record_by_id(person_id)
+             if record:
+                 results.append((record, similarity))
+
+         return results
+
+     def _vector_search_full(
+         self,
+         query_blob: bytes,
+         top_k: int,
+     ) -> list[tuple[PersonRecord, float]]:
+         """Full vector search without text pre-filtering."""
+         conn = self._conn
+         assert conn is not None
+
+         query = """
+             SELECT
+                 person_id,
+                 vec_distance_cosine(embedding, ?) as distance
+             FROM person_embeddings
+             ORDER BY distance
+             LIMIT ?
+         """
+         cursor = conn.execute(query, (query_blob, top_k))
+
+         results = []
+         for row in cursor:
+             person_id = row["person_id"]
+             distance = row["distance"]
+             similarity = 1.0 - distance
+
+             record = self._get_record_by_id(person_id)
+             if record:
+                 results.append((record, similarity))
+
+         return results
+
+     def _get_record_by_id(self, person_id: int) -> Optional[PersonRecord]:
+         """Get a person record by ID."""
+         conn = self._conn
+         assert conn is not None
+
+         cursor = conn.execute("""
+             SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+             FROM people WHERE id = ?
+         """, (person_id,))
+
+         row = cursor.fetchone()
+         if row:
+             return PersonRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 country=row["country"] or "",
+                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                 known_for_role=row["known_for_role"] or "",
+                 known_for_org=row["known_for_org"] or "",
+                 record=json.loads(row["record"]),
+             )
+         return None
+
+     def get_by_source_id(self, source: str, source_id: str) -> Optional[PersonRecord]:
+         """Get a person record by source and source_id."""
+         conn = self._connect()
+
+         cursor = conn.execute("""
+             SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+             FROM people
+             WHERE source = ? AND source_id = ?
+         """, (source, source_id))
+
+         row = cursor.fetchone()
+         if row:
+             return PersonRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 country=row["country"] or "",
+                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                 known_for_role=row["known_for_role"] or "",
+                 known_for_org=row["known_for_org"] or "",
+                 record=json.loads(row["record"]),
+             )
+         return None
+
+     def get_stats(self) -> dict:
+         """Get database statistics for people table."""
+         conn = self._connect()
+
+         # Total count
+         cursor = conn.execute("SELECT COUNT(*) FROM people")
+         total = cursor.fetchone()[0]
+
+         # Count by person_type
+         cursor = conn.execute("SELECT person_type, COUNT(*) as cnt FROM people GROUP BY person_type")
+         by_type = {row["person_type"]: row["cnt"] for row in cursor}
+
+         # Count by source
+         cursor = conn.execute("SELECT source, COUNT(*) as cnt FROM people GROUP BY source")
+         by_source = {row["source"]: row["cnt"] for row in cursor}
+
+         return {
+             "total_records": total,
+             "by_type": by_type,
+             "by_source": by_source,
+         }
+
+     def iter_records(self, source: Optional[str] = None) -> Iterator[PersonRecord]:
+         """Iterate over all person records, optionally filtered by source."""
+         conn = self._connect()
+
+         if source:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                 FROM people
+                 WHERE source = ?
+             """, (source,))
+         else:
+             cursor = conn.execute("""
+                 SELECT name, source, source_id, country, person_type, known_for_role, known_for_org, record
+                 FROM people
+             """)
+
+         for row in cursor:
+             yield PersonRecord(
+                 name=row["name"],
+                 source=row["source"],
+                 source_id=row["source_id"],
+                 country=row["country"] or "",
+                 person_type=PersonType(row["person_type"]) if row["person_type"] else PersonType.UNKNOWN,
+                 known_for_role=row["known_for_role"] or "",
+                 known_for_org=row["known_for_org"] or "",
+                 record=json.loads(row["record"]),
+             )