glinker 0.1.0 (glinker-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1220 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Dict, Any, Optional, Set, Union
+ from pathlib import Path
+ import json
+ import pickle
+
+ import redis
+ from elasticsearch import Elasticsearch
+ from elasticsearch.helpers import bulk as es_bulk
+ import psycopg2
+ from psycopg2.extras import RealDictCursor, execute_batch
+
+ from glinker.core.base import BaseComponent
+ from .models import L2Config, LayerConfig, FuzzyConfig, DatabaseRecord
+
+
+ class DatabaseLayer(ABC):
+     """Base class for all database layers"""
+
+     def __init__(self, config: LayerConfig):
+         self.config = config
+         self.priority = config.priority
+         self.ttl = config.ttl
+         self.write = config.write
+         self.cache_policy = config.cache_policy
+         self.field_mapping = config.field_mapping
+         self.fuzzy_config = config.fuzzy or FuzzyConfig()
+         self._setup()
+
+     @abstractmethod
+     def _setup(self):
+         """Initialize layer resources"""
+         pass
+
+     def normalize_query(self, query: str) -> str:
+         """Normalize query for search"""
+         return query.lower().strip()
+
+     @abstractmethod
+     def search(self, query: str) -> List[DatabaseRecord]:
+         """Exact search"""
+         pass
+
+     @abstractmethod
+     def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
+         """Fuzzy search"""
+         pass
+
+     def supports_fuzzy(self) -> bool:
+         """Check if layer supports fuzzy search"""
+         # self.fuzzy_config always falls back to FuzzyConfig(), so test the raw
+         # config value; otherwise this check could never return False.
+         return self.config.fuzzy is not None
+
+     @abstractmethod
+     def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
+         """Write records to cache"""
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         """Check if layer is available"""
+         pass
+
+     @abstractmethod
+     def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
+         """Bulk load entities"""
+         pass
+
+     def clear(self):
+         """Clear all data in layer"""
+         pass
+
+     def count(self) -> int:
+         """Count entities in layer"""
+         return 0
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from layer (for precompute)"""
+         return []
+
+     def update_embeddings(
+         self,
+         entity_ids: List[str],
+         embeddings: List[List[float]],
+         model_id: str
+     ) -> int:
+         """Update embeddings for entities"""
+         return 0
+
+     def map_to_record(self, raw_data: Dict[str, Any]) -> DatabaseRecord:
+         """Map raw data to DatabaseRecord using field_mapping"""
+         mapped = {}
+         for standard_field, db_field in self.field_mapping.items():
+             if db_field in raw_data:
+                 mapped[standard_field] = raw_data[db_field]
+
+         # Handle embedding fields directly (not in field_mapping)
+         if 'embedding' in raw_data:
+             mapped['embedding'] = raw_data['embedding']
+         if 'embedding_model_id' in raw_data:
+             mapped['embedding_model_id'] = raw_data['embedding_model_id']
+
+         mapped['source'] = self.config.type
+         return DatabaseRecord(**mapped)
+
+
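For orientation, `map_to_record` is the bridge between backend-specific documents and the standard `DatabaseRecord` schema: `field_mapping` goes from standard field to backend field. A minimal sketch (the field names below are illustrative, not taken from the package):

    # Hypothetical mapping (standard field -> backend field) and raw document:
    field_mapping = {"entity_id": "id", "label": "name", "aliases": "aka"}
    raw = {"id": "Q2", "name": "Earth", "aka": ["Terra"]}
    # A layer configured with this mapping would return roughly:
    # DatabaseRecord(entity_id="Q2", label="Earth", aliases=["Terra"], source="<layer type>")
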
+ class DictLayer(DatabaseLayer):
+     """Simple dict-based storage for small entity sets (<5000)"""
+
+     def _setup(self):
+         self._storage: Dict[str, DatabaseRecord] = {}
+         self._label_index: Dict[str, str] = {}
+         self._alias_index: Dict[str, Set[str]] = {}
+
+     def search(self, query: str) -> List[DatabaseRecord]:
+         """Fast O(1) exact search using indexes"""
+         query_key = self.normalize_query(query)
+         results = []
+         seen = set()
+
+         # Label lookup
+         if query_key in self._label_index:
+             eid = self._label_index[query_key]
+             results.append(self._storage[eid])
+             seen.add(eid)
+
+         # Alias lookup
+         if query_key in self._alias_index:
+             for eid in self._alias_index[query_key]:
+                 if eid not in seen:
+                     results.append(self._storage[eid])
+                     seen.add(eid)
+
+         return results
+
+     def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
+         """Simple fuzzy search for small datasets (O(n) is fine for <5000 entities)"""
+         try:
+             from rapidfuzz import fuzz
+         except ImportError:
+             print("[WARN DictLayer] rapidfuzz not installed, fuzzy search disabled")
+             return []
+
+         query_key = self.normalize_query(query)
+         results = []
+
+         # Check prefix requirement
+         if self.fuzzy_config.prefix_length > 0:
+             prefix = query_key[:self.fuzzy_config.prefix_length]
+
+         for entity in self._storage.values():
+             # Check label
+             label_key = entity.label.lower()
+
+             if self.fuzzy_config.prefix_length > 0:
+                 if not label_key.startswith(prefix):
+                     continue
+
+             similarity = fuzz.ratio(query_key, label_key) / 100.0
+             if similarity >= self.fuzzy_config.min_similarity:
+                 results.append((entity, similarity))
+                 continue
+
+             # Check aliases
+             for alias in entity.aliases:
+                 alias_key = alias.lower()
+                 if self.fuzzy_config.prefix_length > 0:
+                     if not alias_key.startswith(prefix):
+                         continue
+
+                 sim = fuzz.ratio(query_key, alias_key) / 100.0
+                 if sim >= self.fuzzy_config.min_similarity:
+                     results.append((entity, sim))
+                     break
+
+         # Sort by similarity
+         results.sort(key=lambda x: x[1], reverse=True)
+         return [r[0] for r in results]
+
+     def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
+         """Write is same as load_bulk for dict layer"""
+         self.load_bulk(records, overwrite=True)
+
+     def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
+         """Bulk load entities with indexing"""
+         count = 0
+         for entity in entities:
+             entity_id = entity.entity_id
+
+             if not overwrite and entity_id in self._storage:
+                 continue
+
+             # Store entity
+             self._storage[entity_id] = entity
+
+             # Index by label
+             label_key = entity.label.lower()
+             self._label_index[label_key] = entity_id
+
+             # Index by aliases
+             for alias in entity.aliases:
+                 alias_key = alias.lower()
+                 if alias_key not in self._alias_index:
+                     self._alias_index[alias_key] = set()
+                 self._alias_index[alias_key].add(entity_id)
+
+             count += 1
+         return count
+
+     def clear(self):
+         """Clear all data"""
+         self._storage.clear()
+         self._label_index.clear()
+         self._alias_index.clear()
+
+     def count(self) -> int:
+         """Count entities"""
+         return len(self._storage)
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from storage"""
+         return list(self._storage.values())
+
+     def update_embeddings(
+         self,
+         entity_ids: List[str],
+         embeddings: List[List[float]],
+         model_id: str
+     ) -> int:
+         """Update embeddings for entities"""
+         count = 0
+         for eid, emb in zip(entity_ids, embeddings):
+             if eid in self._storage:
+                 self._storage[eid].embedding = emb
+                 self._storage[eid].embedding_model_id = model_id
+                 count += 1
+         return count
+
+     def is_available(self) -> bool:
+         """Dict layer is always available"""
+         return True
+
+
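To make the fuzzy threshold concrete: rapidfuzz's `fuzz.ratio` is a 0-100 similarity, scaled here to 0-1, so with a hypothetical `min_similarity=0.9` a one-character slip still matches:

    from rapidfuzz import fuzz

    fuzz.ratio("barak obama", "barack obama") / 100.0  # ~0.96, accepted at min_similarity=0.9
    fuzz.ratio("obama", "biden") / 100.0               # ~0.2, rejected
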
+ class RedisLayer(DatabaseLayer):
+     """Redis cache layer"""
+
+     def _setup(self):
+         self.client = redis.Redis(
+             host=self.config.config.get('host', 'localhost'),
+             port=self.config.config.get('port', 6379),
+             db=self.config.config.get('db', 0),
+             password=self.config.config.get('password'),
+             decode_responses=False
+         )
+
+     def supports_fuzzy(self) -> bool:
+         return False
+
+     def search(self, query: str) -> List[DatabaseRecord]:
+         query = self.normalize_query(query)
+         key = f"entity:{query}"
+
+         try:
+             data = self.client.get(key)
+             if data:
+                 if isinstance(data, bytes):
+                     data = data.decode('utf-8')
+
+                 records_data = json.loads(data)
+
+                 if isinstance(records_data, list):
+                     results = []
+                     for r in records_data:
+                         if isinstance(r, dict):
+                             r['source'] = 'redis'
+                             results.append(DatabaseRecord(**r))
+                         else:
+                             results.append(r)
+                     return results
+
+                 elif isinstance(records_data, dict):
+                     records_data['source'] = 'redis'
+                     return [DatabaseRecord(**records_data)]
+
+         except Exception as e:
+             print(f"[ERROR Redis] Search error: {e}")
+
+         return []
+
+     def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
+         return []
+
+     def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
+         key = self.normalize_query(key)
+         cache_key = f"entity:{key}"
+
+         try:
+             data = json.dumps([r.dict() for r in records])
+             self.client.setex(cache_key, ttl, data)
+         except Exception as e:
+             print(f"[ERROR Redis] Write error: {e}")
+
+     def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
+         """Bulk load to Redis"""
+         count = 0
+         pipe = self.client.pipeline()
+
+         for entity in entities:
+             # Prepare data
+             entity_data = entity.dict()
+             data_json = json.dumps(entity_data)
+
+             # Store by label
+             label_key = f"entity:{entity.label.lower()}"
+             if overwrite or not self.client.exists(label_key):
+                 pipe.setex(label_key, self.ttl, data_json)
+                 count += 1
+
+             # Store by aliases
+             for alias in entity.aliases:
+                 alias_key = f"entity:{alias.lower()}"
+                 if overwrite or not self.client.exists(alias_key):
+                     pipe.setex(alias_key, self.ttl, data_json)
+
+             # Execute in batches
+             if len(pipe) >= batch_size:
+                 pipe.execute()
+                 pipe = self.client.pipeline()
+
+         # Execute remaining
+         if len(pipe) > 0:
+             pipe.execute()
+
+         return count
+
+     def clear(self):
+         """Clear all entity keys"""
+         for key in self.client.scan_iter(match="entity:*"):
+             self.client.delete(key)
+
+     def count(self) -> int:
+         """Count entity keys"""
+         return sum(1 for _ in self.client.scan_iter(match="entity:*"))
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from Redis (scans all entity:* keys)"""
+         entities = []
+         seen_ids = set()
+
+         for key in self.client.scan_iter(match="entity:*"):
+             try:
+                 data = self.client.get(key)
+                 if data:
+                     if isinstance(data, bytes):
+                         data = data.decode('utf-8')
+                     record_data = json.loads(data)
+
+                     if isinstance(record_data, dict):
+                         if record_data.get('entity_id') not in seen_ids:
+                             record_data['source'] = 'redis'
+                             entities.append(DatabaseRecord(**record_data))
+                             seen_ids.add(record_data.get('entity_id'))
+                     elif isinstance(record_data, list):
+                         for r in record_data:
+                             if r.get('entity_id') not in seen_ids:
+                                 r['source'] = 'redis'
+                                 entities.append(DatabaseRecord(**r))
+                                 seen_ids.add(r.get('entity_id'))
+             except Exception:
+                 continue
+
+         return entities
+
+     def update_embeddings(
+         self,
+         entity_ids: List[str],
+         embeddings: List[List[float]],
+         model_id: str
+     ) -> int:
+         """Update embeddings in Redis entities"""
+         count = 0
+         id_to_embedding = dict(zip(entity_ids, embeddings))
+
+         for key in self.client.scan_iter(match="entity:*"):
+             try:
+                 data = self.client.get(key)
+                 if not data:
+                     continue
+
+                 if isinstance(data, bytes):
+                     data = data.decode('utf-8')
+
+                 record_data = json.loads(data)
+                 updated = False
+
+                 if isinstance(record_data, dict):
+                     if record_data.get('entity_id') in id_to_embedding:
+                         record_data['embedding'] = id_to_embedding[record_data['entity_id']]
+                         record_data['embedding_model_id'] = model_id
+                         updated = True
+                 elif isinstance(record_data, list):
+                     for r in record_data:
+                         if r.get('entity_id') in id_to_embedding:
+                             r['embedding'] = id_to_embedding[r['entity_id']]
+                             r['embedding_model_id'] = model_id
+                             updated = True
+
+                 if updated:
+                     self.client.setex(key, self.ttl, json.dumps(record_data))
+                     count += 1
+
+             except Exception:
+                 continue
+
+         return count
+
+     def is_available(self) -> bool:
+         try:
+             self.client.ping()
+             return True
+         except Exception:
+             return False
+
+
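For reference, the Redis layer keys one JSON blob per normalized surface form ("entity:<label-or-alias>"), so every read is a single GET. A round-trip sketch, assuming a running Redis and a configured layer:

    layer.write_cache("Berlin", records, ttl=3600)  # SETEX "entity:berlin" with the JSON list
    layer.search("  BERLIN ")                       # normalizes to "berlin" and hits the same key
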
+ class ElasticsearchLayer(DatabaseLayer):
+     """Elasticsearch full-text search layer"""
+
+     def _setup(self):
+         self.client = Elasticsearch(
+             self.config.config['hosts'],
+             api_key=self.config.config.get('api_key')
+         )
+         self.index_name = self.config.config['index_name']
+
+     def search(self, query: str) -> List[DatabaseRecord]:
+         query = self.normalize_query(query)
+
+         try:
+             body = {
+                 "query": {
+                     "multi_match": {
+                         "query": query,
+                         "fields": ["label^2", "aliases^1.5", "description"],
+                         "type": "best_fields"
+                     }
+                 },
+                 "size": 50
+             }
+             response = self.client.search(index=self.index_name, body=body)
+             return self._process_hits(response['hits']['hits'])
+         except Exception as e:
+             print(f"[ERROR ES] Search error: {e}")
+             return []
+
+     def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
+         query = self.normalize_query(query)
+         fuzzy_distance = self.fuzzy_config.max_distance
+
+         try:
+             body = {
+                 "query": {
+                     "multi_match": {
+                         "query": query,
+                         "fields": ["label^2", "aliases^1.5", "description"],
+                         "fuzziness": fuzzy_distance,
+                         "prefix_length": self.fuzzy_config.prefix_length,
+                         "max_expansions": 50
+                     }
+                 },
+                 "size": 50
+             }
+             response = self.client.search(index=self.index_name, body=body)
+             return self._process_hits(response['hits']['hits'])
+         except Exception as e:
+             print(f"[ERROR ES] Fuzzy error: {e}")
+             return []
+
+     def _process_hits(self, hits: List[Dict]) -> List[DatabaseRecord]:
+         records = []
+         for hit in hits:
+             source = hit['_source']
+             source['_id'] = hit['_id']
+             source['source'] = 'elasticsearch'
+             record = self.map_to_record(source)
+             records.append(record)
+         return records
+
+     def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
+         if not records:
+             return
+
+         try:
+             actions = []
+             for record in records:
+                 doc = self._map_from_record(record)
+                 actions.append({
+                     "_index": self.index_name,
+                     "_id": record.entity_id,
+                     "_source": doc
+                 })
+
+             if actions:
+                 es_bulk(self.client, actions)
+                 self.client.indices.refresh(index=self.index_name)
+         except Exception as e:
+             print(f"[ERROR ES] Write error: {e}")
+
+     def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
+         """Bulk load to Elasticsearch"""
+         actions = []
+         for entity in entities:
+             doc = self._map_from_record(entity)
+
+             action = {
+                 '_index': self.index_name,
+                 '_id': entity.entity_id,
+                 '_source': doc
+             }
+
+             if overwrite:
+                 action['_op_type'] = 'index'
+             else:
+                 action['_op_type'] = 'create'
+
+             actions.append(action)
+
+         success, failed = es_bulk(
+             self.client,
+             actions,
+             raise_on_error=False,
+             chunk_size=batch_size
+         )
+
+         self.client.indices.refresh(index=self.index_name)
+         return success
+
+     def _map_from_record(self, record: DatabaseRecord) -> dict:
+         """Map DatabaseRecord -> ES document using field_mapping"""
+         # field_mapping maps standard field -> ES field (see map_to_record),
+         # so it is applied directly here rather than inverted.
+         doc = {}
+         for standard_field, value in record.dict().items():
+             if standard_field == 'source':
+                 continue
+
+             es_field = self.field_mapping.get(standard_field, standard_field)
+             doc[es_field] = value
+
+         return doc
+
+     def clear(self):
+         """Delete all documents in index"""
+         try:
+             self.client.delete_by_query(
+                 index=self.index_name,
+                 body={"query": {"match_all": {}}}
+             )
+             self.client.indices.refresh(index=self.index_name)
+         except Exception as e:
+             print(f"[ERROR ES] Clear error: {e}")
+
+     def count(self) -> int:
+         """Count documents in index"""
+         try:
+             result = self.client.count(index=self.index_name)
+             return result['count']
+         except Exception:
+             return 0
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from Elasticsearch using scroll"""
+         entities = []
+
+         try:
+             # Use scroll API for large datasets
+             response = self.client.search(
+                 index=self.index_name,
+                 body={"query": {"match_all": {}}, "size": 1000},
+                 scroll='2m'
+             )
+
+             scroll_id = response['_scroll_id']
+             hits = response['hits']['hits']
+
+             while hits:
+                 entities.extend(self._process_hits(hits))
+
+                 response = self.client.scroll(scroll_id=scroll_id, scroll='2m')
+                 scroll_id = response['_scroll_id']
+                 hits = response['hits']['hits']
+
+             # Clear scroll
+             self.client.clear_scroll(scroll_id=scroll_id)
+
+         except Exception as e:
+             print(f"[ERROR ES] get_all_entities error: {e}")
+
+         return entities
+
+     def update_embeddings(
+         self,
+         entity_ids: List[str],
+         embeddings: List[List[float]],
+         model_id: str
+     ) -> int:
+         """Update embeddings in Elasticsearch"""
+         try:
+             actions = []
+             for eid, emb in zip(entity_ids, embeddings):
+                 actions.append({
+                     "_op_type": "update",
+                     "_index": self.index_name,
+                     "_id": eid,
+                     "doc": {
+                         "embedding": emb,
+                         "embedding_model_id": model_id
+                     }
+                 })
+
+             success, failed = es_bulk(
+                 self.client,
+                 actions,
+                 raise_on_error=False,
+                 chunk_size=500
+             )
+
+             self.client.indices.refresh(index=self.index_name)
+             return success
+
+         except Exception as e:
+             print(f"[ERROR ES] update_embeddings error: {e}")
+             return 0
+
+     def is_available(self) -> bool:
+         try:
+             return self.client.ping()
+         except Exception:
+             return False
+
+
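To make the fuzzy request concrete: with a hypothetical max_distance=1 and prefix_length=1, search_fuzzy("berln") asks Elasticsearch for terms within one edit of the query that share the first character, so the body it builds is roughly:

    {"query": {"multi_match": {"query": "berln",
                               "fields": ["label^2", "aliases^1.5", "description"],
                               "fuzziness": 1, "prefix_length": 1, "max_expansions": 50}},
     "size": 50}

which would match a document whose label is "berlin" (one insertion, same leading "b").
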
+ class PostgresLayer(DatabaseLayer):
+     """PostgreSQL database layer"""
+
+     def _setup(self):
+         self.conn = psycopg2.connect(
+             host=self.config.config['host'],
+             port=self.config.config.get('port', 5432),
+             database=self.config.config['database'],
+             user=self.config.config['user'],
+             password=self.config.config['password']
+         )
+
+         cursor = self.conn.cursor()
+         try:
+             cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+             self.conn.commit()
+         except Exception as e:
+             self.conn.rollback()  # keep the connection usable if the extension cannot be created
+             print(f"[WARN Postgres] pg_trgm: {e}")
+         finally:
+             cursor.close()
+
+     def search(self, query: str) -> List[DatabaseRecord]:
+         query = self.normalize_query(query)
+
+         try:
+             cursor = self.conn.cursor(cursor_factory=RealDictCursor)
+             sql = """
+                 SELECT
+                     e.entity_id,
+                     e.label,
+                     e.description,
+                     e.entity_type,
+                     e.popularity,
+                     COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases
+                 FROM entities e
+                 LEFT JOIN aliases a ON e.entity_id = a.entity_id
+                 WHERE LOWER(e.label) LIKE %s
+                    OR EXISTS (
+                        SELECT 1 FROM aliases a2
+                        WHERE a2.entity_id = e.entity_id
+                          AND LOWER(a2.alias) LIKE %s
+                    )
+                 GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity
+                 ORDER BY e.popularity DESC
+                 LIMIT 50
+             """
+             cursor.execute(sql, (f"%{query}%", f"%{query}%"))
+             records = self._process_rows(cursor.fetchall())
+             cursor.close()
+             return records
+         except Exception as e:
+             self.conn.rollback()  # reset the aborted transaction so later queries still work
+             print(f"[ERROR Postgres] Search error: {e}")
+             return []
+
+     def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
+         query = self.normalize_query(query)
+         threshold = self.fuzzy_config.min_similarity
+
+         try:
+             cursor = self.conn.cursor(cursor_factory=RealDictCursor)
+             sql = """
+                 SELECT
+                     e.entity_id,
+                     e.label,
+                     e.description,
+                     e.entity_type,
+                     e.popularity,
+                     COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases,
+                     similarity(LOWER(e.label), %s) AS sim_score
+                 FROM entities e
+                 LEFT JOIN aliases a ON e.entity_id = a.entity_id
+                 WHERE similarity(LOWER(e.label), %s) >= %s
+                 GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity
+                 ORDER BY sim_score DESC, e.popularity DESC
+                 LIMIT 50
+             """
+             cursor.execute(sql, (query, query, threshold))
+             records = self._process_rows(cursor.fetchall())
+             cursor.close()
+             return records
+         except Exception as e:
+             self.conn.rollback()  # reset the aborted transaction before falling back
+             print(f"[ERROR Postgres] Fuzzy error: {e}")
+             return self.search(query)
+
+     def _process_rows(self, rows: List[Dict]) -> List[DatabaseRecord]:
+         records = []
+         for row in rows:
+             row_dict = dict(row)
+             row_dict['source'] = 'postgres'
+             record = self.map_to_record(row_dict)
+             records.append(record)
+         return records
+
+     def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
+         pass
+
+     def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
+         """Bulk load to Postgres"""
+         cursor = self.conn.cursor()
+
+         try:
+             # Prepare entity data
+             entity_values = [
+                 (e.entity_id, e.label, e.description, e.entity_type, e.popularity)
+                 for e in entities
+             ]
+
+             # Insert entities
+             if overwrite:
+                 entity_query = """
+                     INSERT INTO entities (entity_id, label, description, entity_type, popularity)
+                     VALUES (%s, %s, %s, %s, %s)
+                     ON CONFLICT (entity_id) DO UPDATE SET
+                         label = EXCLUDED.label,
+                         description = EXCLUDED.description,
+                         entity_type = EXCLUDED.entity_type,
+                         popularity = EXCLUDED.popularity
+                 """
+             else:
+                 entity_query = """
+                     INSERT INTO entities (entity_id, label, description, entity_type, popularity)
+                     VALUES (%s, %s, %s, %s, %s)
+                     ON CONFLICT (entity_id) DO NOTHING
+                 """
+
+             execute_batch(cursor, entity_query, entity_values, page_size=batch_size)
+
+             # Prepare alias data
+             alias_values = []
+             for entity in entities:
+                 for alias in entity.aliases:
+                     alias_values.append((entity.entity_id, alias))
+
+             # Delete old aliases if overwrite
+             if overwrite and alias_values:
+                 entity_ids = [e.entity_id for e in entities]
+                 cursor.execute(
+                     "DELETE FROM aliases WHERE entity_id = ANY(%s)",
+                     (entity_ids,)
+                 )
+
+             # Insert aliases
+             if alias_values:
+                 execute_batch(
+                     cursor,
+                     "INSERT INTO aliases (entity_id, alias) VALUES (%s, %s) ON CONFLICT DO NOTHING",
+                     alias_values,
+                     page_size=batch_size
+                 )
+
+             self.conn.commit()
+             return len(entities)
+
+         except Exception as e:
+             self.conn.rollback()
+             print(f"[ERROR Postgres] Load bulk failed: {e}")
+             raise
+         finally:
+             cursor.close()
+
+     def clear(self):
+         """Clear all data"""
+         cursor = self.conn.cursor()
+         try:
+             cursor.execute("TRUNCATE entities, aliases CASCADE")
+             self.conn.commit()
+         except Exception as e:
+             self.conn.rollback()
+             print(f"[ERROR Postgres] Clear error: {e}")
+         finally:
+             cursor.close()
+
+     def count(self) -> int:
+         """Count entities"""
+         cursor = self.conn.cursor()
+         try:
+             cursor.execute("SELECT COUNT(*) FROM entities")
+             return cursor.fetchone()[0]
+         except Exception:
+             return 0
+         finally:
+             cursor.close()
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from PostgreSQL"""
+         entities = []
+
+         try:
+             cursor = self.conn.cursor(cursor_factory=RealDictCursor)
+             sql = """
+                 SELECT
+                     e.entity_id,
+                     e.label,
+                     e.description,
+                     e.entity_type,
+                     e.popularity,
+                     e.embedding,
+                     e.embedding_model_id,
+                     COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases
+                 FROM entities e
+                 LEFT JOIN aliases a ON e.entity_id = a.entity_id
+                 GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity, e.embedding, e.embedding_model_id
+             """
+             cursor.execute(sql)
+
+             for row in cursor.fetchall():
+                 row_dict = dict(row)
+                 row_dict['source'] = 'postgres'
+
+                 # Deserialize embedding from bytes if needed
+                 if row_dict.get('embedding'):
+                     if isinstance(row_dict['embedding'], (bytes, memoryview)):
+                         row_dict['embedding'] = pickle.loads(bytes(row_dict['embedding']))
+
+                 record = self.map_to_record(row_dict)
+                 entities.append(record)
+
+             cursor.close()
+
+         except Exception as e:
+             self.conn.rollback()  # reset the aborted transaction so later queries still work
+             print(f"[ERROR Postgres] get_all_entities error: {e}")
+
+         return entities
+
+     def update_embeddings(
+         self,
+         entity_ids: List[str],
+         embeddings: List[List[float]],
+         model_id: str
+     ) -> int:
+         """Update embeddings in PostgreSQL"""
+         cursor = self.conn.cursor()
+
+         try:
+             # Prepare batch data
+             batch_data = []
+             for eid, emb in zip(entity_ids, embeddings):
+                 emb_bytes = pickle.dumps(emb)
+                 batch_data.append((emb_bytes, model_id, eid))
+
+             # Batch update
+             execute_batch(
+                 cursor,
+                 "UPDATE entities SET embedding = %s, embedding_model_id = %s WHERE entity_id = %s",
+                 batch_data,
+                 page_size=500
+             )
+
+             self.conn.commit()
+             return len(batch_data)
+
+         except Exception as e:
+             self.conn.rollback()
+             print(f"[ERROR Postgres] update_embeddings error: {e}")
+             return 0
+         finally:
+             cursor.close()
+
+     def is_available(self) -> bool:
+         try:
+             cursor = self.conn.cursor()
+             cursor.execute("SELECT 1")
+             cursor.close()
+             return True
+         except Exception:
+             return False
+
+
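The queries above presuppose an `entities` table and an `aliases` table, but the package ships no DDL, so the following is a sketch inferred from the columns and conflict targets used in this module (the BYTEA embedding column follows the pickle round trip above):

    cur = conn.cursor()
    cur.execute("""
        CREATE TABLE IF NOT EXISTS entities (
            entity_id TEXT PRIMARY KEY,
            label TEXT NOT NULL,
            description TEXT,
            entity_type TEXT,
            popularity INTEGER DEFAULT 0,
            embedding BYTEA,
            embedding_model_id TEXT
        );
        -- ON CONFLICT DO NOTHING in load_bulk needs a unique constraint here
        CREATE TABLE IF NOT EXISTS aliases (
            entity_id TEXT REFERENCES entities(entity_id),
            alias TEXT,
            UNIQUE (entity_id, alias)
        );
    """)
    conn.commit()
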
+ class DatabaseChainComponent(BaseComponent[L2Config]):
+     """Multi-layer database chain component"""
+
+     def _setup(self):
+         self.layers: List[DatabaseLayer] = []
+
+         for layer_config in self.config.layers:
+             if isinstance(layer_config, dict):
+                 layer_config = LayerConfig(**layer_config)
+
+             if layer_config.type == "dict":
+                 layer = DictLayer(layer_config)
+             elif layer_config.type == "redis":
+                 layer = RedisLayer(layer_config)
+             elif layer_config.type == "elasticsearch":
+                 layer = ElasticsearchLayer(layer_config)
+             elif layer_config.type == "postgres":
+                 layer = PostgresLayer(layer_config)
+             else:
+                 raise ValueError(f"Unknown layer type: {layer_config.type}")
+
+             self.layers.append(layer)
+
+         self.layers.sort(key=lambda x: x.priority, reverse=True)  # Higher priority checked first
+
+     def get_available_methods(self) -> List[str]:
+         return [
+             "search",
+             "filter_by_popularity",
+             "deduplicate_candidates",
+             "limit_candidates",
+             "sort_by_popularity"
+         ]
+
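Since `_setup` accepts plain dicts for layers, a two-layer chain can be sketched as config data. The `LayerConfig` fields shown are the ones this module reads; any defaults beyond that, and how `BaseComponent` takes its config, are assumptions here:

    config = L2Config(layers=[
        {"type": "redis", "priority": 100, "ttl": 3600, "write": True,
         "cache_policy": "miss", "search_mode": ["exact"],
         "config": {"host": "localhost"}},
        {"type": "postgres", "priority": 10, "write": False,
         "search_mode": ["exact", "fuzzy"],
         "config": {"host": "localhost", "database": "kb",
                    "user": "kb", "password": "secret"}},
    ])
    chain = DatabaseChainComponent(config)  # assuming BaseComponent takes the config directly
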
+     def search(self, mention: str) -> List[DatabaseRecord]:
+         """Search through layers with fallback"""
+         found_in_layer = None
+         results = []
+
+         for layer in self.layers:
+             if not layer.is_available():
+                 continue
+
+             layer_results = []
+
+             for mode in layer.config.search_mode:
+                 if mode == "exact":
+                     layer_results.extend(layer.search(mention))
+                 elif mode == "fuzzy":
+                     if layer.supports_fuzzy():
+                         layer_results.extend(layer.search_fuzzy(mention))
+
+             if layer_results:
+                 layer_results = self.deduplicate_candidates(layer_results)
+                 results = layer_results
+                 found_in_layer = layer
+                 break
+
+         if results and found_in_layer:
+             self._cache_write(mention, results, found_in_layer)
+
+         return results
+
+     def _cache_write(self, query: str, results: List[DatabaseRecord], source_layer: DatabaseLayer):
+         """Write results to upper layers (higher priority = checked earlier)"""
+         for layer in self.layers:
+             # Skip layers at or below the source layer's priority (including the source itself)
+             if layer.priority <= source_layer.priority:
+                 continue
+             if not layer.write:
+                 continue
+
+             if layer.cache_policy == "always":
+                 layer.write_cache(query, results, layer.ttl)
+             elif layer.cache_policy == "miss":
+                 existing = layer.search(query)
+                 if not existing:
+                     layer.write_cache(query, results, layer.ttl)
+             elif layer.cache_policy == "hit":
+                 existing = layer.search(query)
+                 if existing:
+                     layer.write_cache(query, results, layer.ttl)
+
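Concretely, with the sketch config above: a mention that misses Redis but hits Postgres is served from Postgres, and because Redis has higher priority, write=True, and cache_policy="miss", the results are written back under the mention's key. "always" would refresh the cached value on every fallback hit, and "hit" only rewrites keys that already exist:

    chain.search("Berlin")  # miss in redis -> hit in postgres -> written back to redis
    chain.search("Berlin")  # now served from redis until the 3600 s TTL expires
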
+     def filter_by_popularity(self, records: List[DatabaseRecord], min_popularity: Optional[int] = None) -> List[DatabaseRecord]:
+         threshold = min_popularity if min_popularity is not None else self.config.min_popularity
+         return [r for r in records if r.popularity >= threshold]
+
+     def deduplicate_candidates(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]:
+         seen = set()
+         unique = []
+         for record in records:
+             if record.entity_id not in seen:
+                 unique.append(record)
+                 seen.add(record.entity_id)
+         return unique
+
+     def limit_candidates(self, records: List[DatabaseRecord], limit: Optional[int] = None) -> List[DatabaseRecord]:
+         max_cands = limit if limit is not None else self.config.max_candidates
+         return records[:max_cands]
+
+     def sort_by_popularity(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]:
+         return sorted(records, key=lambda x: x.popularity, reverse=True)
+
+     def load_entities(
+         self,
+         source: Union[str, Path, List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
+         target_layers: Optional[List[str]] = None,
+         batch_size: int = 1000,
+         overwrite: bool = False
+     ) -> Dict[str, int]:
+         """
+         Load entities from a JSONL file, a list of dicts, or a dict.
+
+         Accepts:
+             - ``str`` / ``Path``: path to a JSONL file (one DatabaseRecord per line)
+             - ``list[dict]``: each dict has at least ``entity_id`` and ``label``
+             - ``dict[str, dict]``: keys are entity_ids, values are entity data
+
+         Args:
+             source: entity data (file path, list, or dict)
+             target_layers: ['dict', 'redis', 'elasticsearch', 'postgres'] or None (all writable)
+             batch_size: batch size for bulk operations
+             overwrite: overwrite existing entities
+
+         Returns:
+             Dict of layer_type -> count loaded, e.g. {'redis': 1500, 'elasticsearch': 1500}
+         """
+         if isinstance(source, (str, Path)):
+             entities = self._parse_jsonl(source)
+         elif isinstance(source, dict):
+             entities = [
+                 DatabaseRecord(entity_id=eid, **data)
+                 if "entity_id" not in data
+                 else DatabaseRecord(**data)
+                 for eid, data in source.items()
+             ]
+         elif isinstance(source, list):
+             entities = [
+                 DatabaseRecord(**e) if isinstance(e, dict) else e
+                 for e in source
+             ]
+         else:
+             raise TypeError(f"Expected file path, list, or dict; got {type(source)}")
+
+         return self.load_records(
+             entities,
+             target_layers=target_layers,
+             batch_size=batch_size,
+             overwrite=overwrite,
+         )
+
+     def load_records(
+         self,
+         entities: List[DatabaseRecord],
+         target_layers: Optional[List[str]] = None,
+         batch_size: int = 1000,
+         overwrite: bool = False,
+     ) -> Dict[str, int]:
+         """
+         Load pre-built DatabaseRecord objects into layers.
+
+         Args:
+             entities: list of DatabaseRecord instances
+             target_layers: layer types to target (None = all writable)
+             batch_size: batch size for bulk operations
+             overwrite: overwrite existing entities
+
+         Returns:
+             Dict of layer_type -> count loaded
+         """
+         # Determine target layers
+         if target_layers is None:
+             target_layers = [layer.config.type for layer in self.layers if layer.write]
+
+         # Load to each layer
+         results = {}
+         for layer in self.layers:
+             if layer.config.type not in target_layers:
+                 continue
+
+             if not layer.is_available():
+                 continue
+
+             count = layer.load_bulk(entities, overwrite=overwrite, batch_size=batch_size)
+             results[layer.config.type] = count
+
+         return results
+
+     @staticmethod
+     def _parse_jsonl(filepath: Union[str, Path]) -> List[DatabaseRecord]:
+         """Parse JSONL file into DatabaseRecord list."""
+         entities = []
+         with open(filepath, 'r', encoding='utf-8') as f:
+             for line_num, line in enumerate(f, 1):
+                 line = line.strip()
+                 if not line:
+                     continue
+                 try:
+                     data = json.loads(line)
+                     entities.append(DatabaseRecord(**data))
+                 except Exception as e:
+                     print(f"[WARN] Line {line_num} parse error: {e}")
+                     continue
+         return entities
+
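The expected JSONL shape is one DatabaseRecord-compatible object per line; the values below are illustrative, mirroring the fields this module reads:

    {"entity_id": "Q64", "label": "Berlin", "aliases": ["Berlin, Germany"], "entity_type": "CITY", "popularity": 9000}
    {"entity_id": "Q84", "label": "London", "aliases": [], "entity_type": "CITY", "popularity": 9500}

loaded with, for example:

    counts = chain.load_entities("entities.jsonl", target_layers=["redis", "postgres"])
    # -> roughly {"redis": 2, "postgres": 2}
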
+     def clear_layers(self, layer_names: Optional[List[str]] = None):
+         """Clear all entities in the specified layers (None = all layers)"""
+         for layer in self.layers:
+             if layer_names and layer.config.type not in layer_names:
+                 continue
+
+             print(f"Clearing {layer.config.type}...")
+             layer.clear()
+             print("✓ Cleared")
+
+     def get_all_entities(self) -> List[DatabaseRecord]:
+         """Get all entities from all available layers (deduplicated)"""
+         all_entities = []
+         for layer in self.layers:
+             if layer.is_available():
+                 all_entities.extend(layer.get_all_entities())
+         return self.deduplicate_candidates(all_entities)
+
+     def count_entities(self) -> Dict[str, int]:
+         """Count entities in each layer"""
+         counts = {}
+         for layer in self.layers:
+             counts[layer.config.type] = layer.count()
+         return counts
+
+     def precompute_embeddings(
+         self,
+         encoder_fn,
+         template: str,
+         model_id: str,
+         target_layers: Optional[List[str]] = None,
+         batch_size: int = 32
+     ) -> Dict[str, int]:
+         """
+         Precompute embeddings for all entities in the specified layers.
+
+         Args:
+             encoder_fn: callable that takes a List[str] and returns an embeddings tensor
+             template: template string for formatting labels (e.g., "{label}: {description}")
+             model_id: model identifier to store with the embeddings
+             target_layers: layer types to update (None = all)
+             batch_size: batch size for encoding
+
+         Returns:
+             Dict with the count of updated entities per layer
+         """
+         from tqdm import tqdm
+
+         results = {}
+
+         for layer in self.layers:
+             if target_layers and layer.config.type not in target_layers:
+                 continue
+
+             if not layer.is_available():
+                 print(f"[WARN] {layer.config.type} unavailable, skipping")
+                 continue
+
+             print(f"\nPrecomputing embeddings for {layer.config.type}...")
+
+             # Get all entities
+             entities = layer.get_all_entities()
+             if not entities:
+                 print(f"  No entities found in {layer.config.type}")
+                 continue
+
+             print(f"  Found {len(entities)} entities")
+
+             # Format labels using the template
+             labels = []
+             entity_ids = []
+             for entity in entities:
+                 try:
+                     formatted = template.format(**entity.dict())
+                     labels.append(formatted)
+                     entity_ids.append(entity.entity_id)
+                 except KeyError as e:
+                     print(f"  [WARN] Template error for {entity.entity_id}: {e}")
+                     continue
+
+             # Encode in batches
+             all_embeddings = []
+             for i in tqdm(range(0, len(labels), batch_size), desc="Encoding"):
+                 batch_labels = labels[i:i + batch_size]
+                 batch_embeddings = encoder_fn(batch_labels)
+
+                 # Convert to a plain list if it is a tensor/array
+                 if hasattr(batch_embeddings, 'tolist'):
+                     batch_embeddings = batch_embeddings.tolist()
+                 elif hasattr(batch_embeddings, 'cpu'):
+                     batch_embeddings = batch_embeddings.cpu().numpy().tolist()
+
+                 all_embeddings.extend(batch_embeddings)
+
+             # Update the layer
+             updated = layer.update_embeddings(entity_ids, all_embeddings, model_id)
+             results[layer.config.type] = updated
+             print(f"  Updated {updated} entities with embeddings")
+
+         return results
+
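A usage sketch for precompute_embeddings; the sentence-transformers model is an assumption, any callable from List[str] to an array works:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed encoder, not bundled with glinker
    chain.precompute_embeddings(
        encoder_fn=model.encode,            # returns a numpy array; .tolist() handles it above
        template="{label}: {description}",
        model_id="all-MiniLM-L6-v2",
        target_layers=["postgres"],
    )
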
+     def get_layer(self, layer_type: str) -> Optional[DatabaseLayer]:
+         """Get a layer by type (None if no such layer is configured)"""
+         for layer in self.layers:
+             if layer.config.type == layer_type:
+                 return layer
+         return None
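
Putting the pieces together, a minimal end-to-end flow under the assumptions flagged in the sketches above:

    chain = DatabaseChainComponent(config)       # layers sorted by priority in _setup
    chain.load_entities("entities.jsonl")        # fill every writable layer
    candidates = chain.search("berlin")          # walk layers until one answers, then cache upward
    candidates = chain.sort_by_popularity(candidates)
    candidates = chain.limit_candidates(candidates, limit=10)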