glinker-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glinker/__init__.py +54 -0
- glinker/core/__init__.py +56 -0
- glinker/core/base.py +103 -0
- glinker/core/builders.py +547 -0
- glinker/core/dag.py +898 -0
- glinker/core/factory.py +261 -0
- glinker/core/registry.py +31 -0
- glinker/l0/__init__.py +21 -0
- glinker/l0/component.py +472 -0
- glinker/l0/models.py +90 -0
- glinker/l0/processor.py +108 -0
- glinker/l1/__init__.py +15 -0
- glinker/l1/component.py +284 -0
- glinker/l1/models.py +47 -0
- glinker/l1/processor.py +152 -0
- glinker/l2/__init__.py +19 -0
- glinker/l2/component.py +1220 -0
- glinker/l2/models.py +99 -0
- glinker/l2/processor.py +170 -0
- glinker/l3/__init__.py +12 -0
- glinker/l3/component.py +184 -0
- glinker/l3/models.py +48 -0
- glinker/l3/processor.py +350 -0
- glinker/l4/__init__.py +9 -0
- glinker/l4/component.py +121 -0
- glinker/l4/models.py +21 -0
- glinker/l4/processor.py +156 -0
- glinker/py.typed +1 -0
- glinker-0.1.0.dist-info/METADATA +994 -0
- glinker-0.1.0.dist-info/RECORD +33 -0
- glinker-0.1.0.dist-info/WHEEL +5 -0
- glinker-0.1.0.dist-info/licenses/LICENSE +201 -0
- glinker-0.1.0.dist-info/top_level.txt +1 -0
glinker/l2/component.py
ADDED
@@ -0,0 +1,1220 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Set, Union
from pathlib import Path
import redis
import json
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk as es_bulk
import psycopg2
from psycopg2.extras import RealDictCursor, execute_batch

from glinker.core.base import BaseComponent
from .models import L2Config, LayerConfig, FuzzyConfig, DatabaseRecord


class DatabaseLayer(ABC):
    """Base class for all database layers"""

    def __init__(self, config: LayerConfig):
        self.config = config
        self.priority = config.priority
        self.ttl = config.ttl
        self.write = config.write
        self.cache_policy = config.cache_policy
        self.field_mapping = config.field_mapping
        self.fuzzy_config = config.fuzzy or FuzzyConfig()
        self._setup()

    @abstractmethod
    def _setup(self):
        """Initialize layer resources"""
        pass

    def normalize_query(self, query: str) -> str:
        """Normalize query for search"""
        return query.lower().strip()

    @abstractmethod
    def search(self, query: str) -> List[DatabaseRecord]:
        """Exact search"""
        pass

    @abstractmethod
    def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
        """Fuzzy search"""
        pass

    def supports_fuzzy(self) -> bool:
        """Check if layer supports fuzzy search"""
        # Note: __init__ falls back to FuzzyConfig(), so fuzzy_config is never
        # None here and this base check always passes; layers that cannot
        # fuzzy-match override it (see RedisLayer).
        return self.fuzzy_config is not None

    @abstractmethod
    def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
        """Write records to cache"""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if layer is available"""
        pass

    @abstractmethod
    def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
        """Bulk load entities"""
        pass

    def clear(self):
        """Clear all data in layer"""
        pass

    def count(self) -> int:
        """Count entities in layer"""
        return 0

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from layer (for precompute)"""
        return []

    def update_embeddings(
        self,
        entity_ids: List[str],
        embeddings: List[List[float]],
        model_id: str
    ) -> int:
        """Update embeddings for entities"""
        return 0

    def map_to_record(self, raw_data: Dict[str, Any]) -> DatabaseRecord:
        """Map raw data to DatabaseRecord using field_mapping"""
        mapped = {}
        for standard_field, db_field in self.field_mapping.items():
            if db_field in raw_data:
                mapped[standard_field] = raw_data[db_field]

        # Handle embedding fields directly (not in field_mapping)
        if 'embedding' in raw_data:
            mapped['embedding'] = raw_data['embedding']
        if 'embedding_model_id' in raw_data:
            mapped['embedding_model_id'] = raw_data['embedding_model_id']

        mapped['source'] = self.config.type
        return DatabaseRecord(**mapped)


class DictLayer(DatabaseLayer):
    """Simple dict-based storage for small entity sets (<5000)"""

    def _setup(self):
        self._storage: Dict[str, DatabaseRecord] = {}
        self._label_index: Dict[str, str] = {}
        self._alias_index: Dict[str, Set[str]] = {}

    def search(self, query: str) -> List[DatabaseRecord]:
        """Fast O(1) exact search using indexes"""
        query_key = self.normalize_query(query)
        results = []
        seen = set()

        # Label lookup
        if query_key in self._label_index:
            eid = self._label_index[query_key]
            results.append(self._storage[eid])
            seen.add(eid)

        # Alias lookup
        if query_key in self._alias_index:
            for eid in self._alias_index[query_key]:
                if eid not in seen:
                    results.append(self._storage[eid])
                    seen.add(eid)

        return results

    def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
        """Simple fuzzy search for small datasets (O(n) is fine for <5000 entities)"""
        try:
            from rapidfuzz import fuzz
        except ImportError:
            print("[WARN DictLayer] rapidfuzz not installed, fuzzy search disabled")
            return []

        query_key = self.normalize_query(query)
        results = []

        # Check prefix requirement
        if self.fuzzy_config.prefix_length > 0:
            prefix = query_key[:self.fuzzy_config.prefix_length]

        for entity in self._storage.values():
            # Check label
            label_key = entity.label.lower()

            if self.fuzzy_config.prefix_length > 0:
                if not label_key.startswith(prefix):
                    continue

            similarity = fuzz.ratio(query_key, label_key) / 100.0
            if similarity >= self.fuzzy_config.min_similarity:
                results.append((entity, similarity))
                continue

            # Check aliases
            for alias in entity.aliases:
                alias_key = alias.lower()
                if self.fuzzy_config.prefix_length > 0:
                    if not alias_key.startswith(prefix):
                        continue

                sim = fuzz.ratio(query_key, alias_key) / 100.0
                if sim >= self.fuzzy_config.min_similarity:
                    results.append((entity, sim))
                    break

        # Sort by similarity
        results.sort(key=lambda x: x[1], reverse=True)
        return [r[0] for r in results]

    def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
        """Write is same as load_bulk for dict layer"""
        self.load_bulk(records, overwrite=True)

    def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
        """Bulk load entities with indexing"""
        count = 0
        for entity in entities:
            entity_id = entity.entity_id

            if not overwrite and entity_id in self._storage:
                continue

            # Store entity
            self._storage[entity_id] = entity

            # Index by label
            label_key = entity.label.lower()
            self._label_index[label_key] = entity_id

            # Index by aliases
            for alias in entity.aliases:
                alias_key = alias.lower()
                if alias_key not in self._alias_index:
                    self._alias_index[alias_key] = set()
                self._alias_index[alias_key].add(entity_id)

            count += 1
        return count

    def clear(self):
        """Clear all data"""
        self._storage.clear()
        self._label_index.clear()
        self._alias_index.clear()

    def count(self) -> int:
        """Count entities"""
        return len(self._storage)

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from storage"""
        return list(self._storage.values())

    def update_embeddings(
        self,
        entity_ids: List[str],
        embeddings: List[List[float]],
        model_id: str
    ) -> int:
        """Update embeddings for entities"""
        count = 0
        for eid, emb in zip(entity_ids, embeddings):
            if eid in self._storage:
                self._storage[eid].embedding = emb
                self._storage[eid].embedding_model_id = model_id
                count += 1
        return count

    def is_available(self) -> bool:
        """Dict layer is always available"""
        return True


class RedisLayer(DatabaseLayer):
    """Redis cache layer"""

    def _setup(self):
        self.client = redis.Redis(
            host=self.config.config.get('host', 'localhost'),
            port=self.config.config.get('port', 6379),
            db=self.config.config.get('db', 0),
            password=self.config.config.get('password'),
            decode_responses=False
        )

    def supports_fuzzy(self) -> bool:
        return False

    def search(self, query: str) -> List[DatabaseRecord]:
        query = self.normalize_query(query)
        key = f"entity:{query}"

        try:
            data = self.client.get(key)
            if data:
                if isinstance(data, bytes):
                    data = data.decode('utf-8')

                records_data = json.loads(data)

                if isinstance(records_data, list):
                    results = []
                    for r in records_data:
                        if isinstance(r, dict):
                            r['source'] = 'redis'
                            results.append(DatabaseRecord(**r))
                        else:
                            results.append(r)
                    return results

                elif isinstance(records_data, dict):
                    records_data['source'] = 'redis'
                    return [DatabaseRecord(**records_data)]

        except Exception as e:
            print(f"[ERROR Redis] Search error: {e}")

        return []

    def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
        return []

    def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
        key = self.normalize_query(key)
        cache_key = f"entity:{key}"

        try:
            data = json.dumps([r.dict() for r in records])
            self.client.setex(cache_key, ttl, data)
        except Exception as e:
            print(f"[ERROR Redis] Write error: {e}")

    def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
        """Bulk load to Redis"""
        count = 0
        pipe = self.client.pipeline()

        for entity in entities:
            # Prepare data
            entity_data = entity.dict()
            data_json = json.dumps(entity_data)

            # Store by label
            label_key = f"entity:{entity.label.lower()}"
            if overwrite or not self.client.exists(label_key):
                pipe.setex(label_key, self.ttl, data_json)
                count += 1

            # Store by aliases
            for alias in entity.aliases:
                alias_key = f"entity:{alias.lower()}"
                if overwrite or not self.client.exists(alias_key):
                    pipe.setex(alias_key, self.ttl, data_json)

            # Execute in batches
            if len(pipe) >= batch_size:
                pipe.execute()
                pipe = self.client.pipeline()

        # Execute remaining
        if len(pipe) > 0:
            pipe.execute()

        return count

    def clear(self):
        """Clear all entity keys"""
        for key in self.client.scan_iter(match="entity:*"):
            self.client.delete(key)

    def count(self) -> int:
        """Count entity keys"""
        return sum(1 for _ in self.client.scan_iter(match="entity:*"))

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from Redis (scans all entity:* keys)"""
        entities = []
        seen_ids = set()

        for key in self.client.scan_iter(match="entity:*"):
            try:
                data = self.client.get(key)
                if data:
                    if isinstance(data, bytes):
                        data = data.decode('utf-8')
                    record_data = json.loads(data)

                    if isinstance(record_data, dict):
                        if record_data.get('entity_id') not in seen_ids:
                            record_data['source'] = 'redis'
                            entities.append(DatabaseRecord(**record_data))
                            seen_ids.add(record_data.get('entity_id'))
                    elif isinstance(record_data, list):
                        for r in record_data:
                            if r.get('entity_id') not in seen_ids:
                                r['source'] = 'redis'
                                entities.append(DatabaseRecord(**r))
                                seen_ids.add(r.get('entity_id'))
            except Exception:
                continue

        return entities

    def update_embeddings(
        self,
        entity_ids: List[str],
        embeddings: List[List[float]],
        model_id: str
    ) -> int:
        """Update embeddings in Redis entities"""
        count = 0
        id_to_embedding = dict(zip(entity_ids, embeddings))

        for key in self.client.scan_iter(match="entity:*"):
            try:
                data = self.client.get(key)
                if not data:
                    continue

                if isinstance(data, bytes):
                    data = data.decode('utf-8')

                record_data = json.loads(data)
                updated = False

                if isinstance(record_data, dict):
                    if record_data.get('entity_id') in id_to_embedding:
                        record_data['embedding'] = id_to_embedding[record_data['entity_id']]
                        record_data['embedding_model_id'] = model_id
                        updated = True
                elif isinstance(record_data, list):
                    for r in record_data:
                        if r.get('entity_id') in id_to_embedding:
                            r['embedding'] = id_to_embedding[r['entity_id']]
                            r['embedding_model_id'] = model_id
                            updated = True

                if updated:
                    self.client.setex(key, self.ttl, json.dumps(record_data))
                    count += 1

            except Exception:
                continue

        return count

    def is_available(self) -> bool:
        try:
            self.client.ping()
            return True
        except Exception:
            return False


class ElasticsearchLayer(DatabaseLayer):
    """Elasticsearch full-text search layer"""

    def _setup(self):
        self.client = Elasticsearch(
            self.config.config['hosts'],
            api_key=self.config.config.get('api_key')
        )
        self.index_name = self.config.config['index_name']

    def search(self, query: str) -> List[DatabaseRecord]:
        query = self.normalize_query(query)

        try:
            body = {
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": ["label^2", "aliases^1.5", "description"],
                        "type": "best_fields"
                    }
                },
                "size": 50
            }
            response = self.client.search(index=self.index_name, body=body)
            return self._process_hits(response['hits']['hits'])
        except Exception as e:
            print(f"[ERROR ES] Search error: {e}")
            return []

    def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
        query = self.normalize_query(query)
        fuzzy_distance = self.fuzzy_config.max_distance

        try:
            body = {
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": ["label^2", "aliases^1.5", "description"],
                        "fuzziness": fuzzy_distance,
                        "prefix_length": self.fuzzy_config.prefix_length,
                        "max_expansions": 50
                    }
                },
                "size": 50
            }
            response = self.client.search(index=self.index_name, body=body)
            return self._process_hits(response['hits']['hits'])
        except Exception as e:
            print(f"[ERROR ES] Fuzzy error: {e}")
            return []

    def _process_hits(self, hits: List[Dict]) -> List[DatabaseRecord]:
        records = []
        for hit in hits:
            source = hit['_source']
            source['_id'] = hit['_id']
            source['source'] = 'elasticsearch'
            record = self.map_to_record(source)
            records.append(record)
        return records

    def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
        if not records:
            return

        try:
            actions = []
            for record in records:
                doc = self._map_from_record(record)
                actions.append({
                    "_index": self.index_name,
                    "_id": record.entity_id,
                    "_source": doc
                })

            if actions:
                es_bulk(self.client, actions)
                self.client.indices.refresh(index=self.index_name)
        except Exception as e:
            print(f"[ERROR ES] Write error: {e}")

    def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
        """Bulk load to Elasticsearch"""
        actions = []
        for entity in entities:
            doc = self._map_from_record(entity)

            action = {
                '_index': self.index_name,
                '_id': entity.entity_id,
                '_source': doc
            }

            if overwrite:
                action['_op_type'] = 'index'
            else:
                action['_op_type'] = 'create'

            actions.append(action)

        success, failed = es_bulk(
            self.client,
            actions,
            raise_on_error=False,
            chunk_size=batch_size
        )

        self.client.indices.refresh(index=self.index_name)
        return success

    def _map_from_record(self, record: DatabaseRecord) -> dict:
        """Map DatabaseRecord -> ES document using field_mapping"""
        reverse_mapping = {v: k for k, v in self.field_mapping.items()}

        doc = {}
        for standard_field, value in record.dict().items():
            if standard_field == 'source':
                continue

            es_field = reverse_mapping.get(standard_field, standard_field)
            doc[es_field] = value

        return doc

    def clear(self):
        """Delete all documents in index"""
        try:
            self.client.delete_by_query(
                index=self.index_name,
                body={"query": {"match_all": {}}}
            )
            self.client.indices.refresh(index=self.index_name)
        except Exception as e:
            print(f"[ERROR ES] Clear error: {e}")

    def count(self) -> int:
        """Count documents in index"""
        try:
            result = self.client.count(index=self.index_name)
            return result['count']
        except Exception:
            return 0

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from Elasticsearch using scroll"""
        entities = []

        try:
            # Use scroll API for large datasets
            response = self.client.search(
                index=self.index_name,
                body={"query": {"match_all": {}}, "size": 1000},
                scroll='2m'
            )

            scroll_id = response['_scroll_id']
            hits = response['hits']['hits']

            while hits:
                entities.extend(self._process_hits(hits))

                response = self.client.scroll(scroll_id=scroll_id, scroll='2m')
                scroll_id = response['_scroll_id']
                hits = response['hits']['hits']

            # Clear scroll
            self.client.clear_scroll(scroll_id=scroll_id)

        except Exception as e:
            print(f"[ERROR ES] get_all_entities error: {e}")

        return entities

    def update_embeddings(
        self,
        entity_ids: List[str],
        embeddings: List[List[float]],
        model_id: str
    ) -> int:
        """Update embeddings in Elasticsearch"""
        try:
            actions = []
            for eid, emb in zip(entity_ids, embeddings):
                actions.append({
                    "_op_type": "update",
                    "_index": self.index_name,
                    "_id": eid,
                    "doc": {
                        "embedding": emb,
                        "embedding_model_id": model_id
                    }
                })

            success, failed = es_bulk(
                self.client,
                actions,
                raise_on_error=False,
                chunk_size=500
            )

            self.client.indices.refresh(index=self.index_name)
            return success

        except Exception as e:
            print(f"[ERROR ES] update_embeddings error: {e}")
            return 0

    def is_available(self) -> bool:
        try:
            return self.client.ping()
        except Exception:
            return False


class PostgresLayer(DatabaseLayer):
    """PostgreSQL database layer"""

    def _setup(self):
        self.conn = psycopg2.connect(
            host=self.config.config['host'],
            port=self.config.config.get('port', 5432),
            database=self.config.config['database'],
            user=self.config.config['user'],
            password=self.config.config['password']
        )

        cursor = self.conn.cursor()
        try:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
            self.conn.commit()
        except Exception as e:
            print(f"[WARN Postgres] pg_trgm: {e}")
        finally:
            cursor.close()

    def search(self, query: str) -> List[DatabaseRecord]:
        query = self.normalize_query(query)

        try:
            cursor = self.conn.cursor(cursor_factory=RealDictCursor)
            sql = """
                SELECT
                    e.entity_id,
                    e.label,
                    e.description,
                    e.entity_type,
                    e.popularity,
                    COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases
                FROM entities e
                LEFT JOIN aliases a ON e.entity_id = a.entity_id
                WHERE LOWER(e.label) LIKE %s
                   OR EXISTS (
                       SELECT 1 FROM aliases a2
                       WHERE a2.entity_id = e.entity_id
                       AND LOWER(a2.alias) LIKE %s
                   )
                GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity
                ORDER BY e.popularity DESC
                LIMIT 50
            """
            cursor.execute(sql, (f"%{query}%", f"%{query}%"))
            records = self._process_rows(cursor.fetchall())
            cursor.close()
            return records
        except Exception as e:
            print(f"[ERROR Postgres] Search error: {e}")
            return []

    def search_fuzzy(self, query: str) -> List[DatabaseRecord]:
        query = self.normalize_query(query)
        threshold = self.fuzzy_config.min_similarity

        try:
            cursor = self.conn.cursor(cursor_factory=RealDictCursor)
            sql = """
                SELECT
                    e.entity_id,
                    e.label,
                    e.description,
                    e.entity_type,
                    e.popularity,
                    COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases,
                    similarity(LOWER(e.label), %s) AS sim_score
                FROM entities e
                LEFT JOIN aliases a ON e.entity_id = a.entity_id
                WHERE similarity(LOWER(e.label), %s) >= %s
                GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity
                ORDER BY sim_score DESC, e.popularity DESC
                LIMIT 50
            """
            cursor.execute(sql, (query, query, threshold))
            records = self._process_rows(cursor.fetchall())
            cursor.close()
            return records
        except Exception as e:
            print(f"[ERROR Postgres] Fuzzy error: {e}")
            return self.search(query)

    def _process_rows(self, rows: List[Dict]) -> List[DatabaseRecord]:
        records = []
        for row in rows:
            row_dict = dict(row)
            row_dict['source'] = 'postgres'
            record = self.map_to_record(row_dict)
            records.append(record)
        return records

    def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int):
        pass

    def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int:
        """Bulk load to Postgres"""
        cursor = self.conn.cursor()

        try:
            # Prepare entity data
            entity_values = [
                (e.entity_id, e.label, e.description, e.entity_type, e.popularity)
                for e in entities
            ]

            # Insert entities
            if overwrite:
                entity_query = """
                    INSERT INTO entities (entity_id, label, description, entity_type, popularity)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (entity_id) DO UPDATE SET
                        label = EXCLUDED.label,
                        description = EXCLUDED.description,
                        entity_type = EXCLUDED.entity_type,
                        popularity = EXCLUDED.popularity
                """
            else:
                entity_query = """
                    INSERT INTO entities (entity_id, label, description, entity_type, popularity)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (entity_id) DO NOTHING
                """

            execute_batch(cursor, entity_query, entity_values, page_size=batch_size)

            # Prepare alias data
            alias_values = []
            for entity in entities:
                for alias in entity.aliases:
                    alias_values.append((entity.entity_id, alias))

            # Delete old aliases if overwrite
            if overwrite and alias_values:
                entity_ids = [e.entity_id for e in entities]
                cursor.execute(
                    "DELETE FROM aliases WHERE entity_id = ANY(%s)",
                    (entity_ids,)
                )

            # Insert aliases
            if alias_values:
                execute_batch(
                    cursor,
                    "INSERT INTO aliases (entity_id, alias) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                    alias_values,
                    page_size=batch_size
                )

            self.conn.commit()
            return len(entities)

        except Exception as e:
            self.conn.rollback()
            print(f"[ERROR Postgres] Load bulk failed: {e}")
            raise
        finally:
            cursor.close()

    def clear(self):
        """Clear all data"""
        cursor = self.conn.cursor()
        try:
            cursor.execute("TRUNCATE entities, aliases CASCADE")
            self.conn.commit()
        except Exception as e:
            self.conn.rollback()
            print(f"[ERROR Postgres] Clear error: {e}")
        finally:
            cursor.close()

    def count(self) -> int:
        """Count entities"""
        cursor = self.conn.cursor()
        try:
            cursor.execute("SELECT COUNT(*) FROM entities")
            return cursor.fetchone()[0]
        except Exception:
            return 0
        finally:
            cursor.close()

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from PostgreSQL"""
        entities = []

        try:
            cursor = self.conn.cursor(cursor_factory=RealDictCursor)
            sql = """
                SELECT
                    e.entity_id,
                    e.label,
                    e.description,
                    e.entity_type,
                    e.popularity,
                    e.embedding,
                    e.embedding_model_id,
                    COALESCE(array_agg(a.alias) FILTER (WHERE a.alias IS NOT NULL), ARRAY[]::text[]) as aliases
                FROM entities e
                LEFT JOIN aliases a ON e.entity_id = a.entity_id
                GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity, e.embedding, e.embedding_model_id
            """
            cursor.execute(sql)

            for row in cursor.fetchall():
                row_dict = dict(row)
                row_dict['source'] = 'postgres'

                # Deserialize embedding from bytes if needed
                if row_dict.get('embedding'):
                    import pickle
                    if isinstance(row_dict['embedding'], (bytes, memoryview)):
                        row_dict['embedding'] = pickle.loads(bytes(row_dict['embedding']))

                record = self.map_to_record(row_dict)
                entities.append(record)

            cursor.close()

        except Exception as e:
            print(f"[ERROR Postgres] get_all_entities error: {e}")

        return entities

    def update_embeddings(
        self,
        entity_ids: List[str],
        embeddings: List[List[float]],
        model_id: str
    ) -> int:
        """Update embeddings in PostgreSQL"""
        cursor = self.conn.cursor()

        try:
            import pickle

            # Prepare batch data
            batch_data = []
            for eid, emb in zip(entity_ids, embeddings):
                emb_bytes = pickle.dumps(emb)
                batch_data.append((emb_bytes, model_id, eid))

            # Batch update
            execute_batch(
                cursor,
                "UPDATE entities SET embedding = %s, embedding_model_id = %s WHERE entity_id = %s",
                batch_data,
                page_size=500
            )

            self.conn.commit()
            return len(batch_data)

        except Exception as e:
            self.conn.rollback()
            print(f"[ERROR Postgres] update_embeddings error: {e}")
            return 0
        finally:
            cursor.close()

    def is_available(self) -> bool:
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT 1")
            cursor.close()
            return True
        except Exception:
            return False


class DatabaseChainComponent(BaseComponent[L2Config]):
    """Multi-layer database chain component"""

    def _setup(self):
        self.layers: List[DatabaseLayer] = []

        for layer_config in self.config.layers:
            if isinstance(layer_config, dict):
                layer_config = LayerConfig(**layer_config)

            if layer_config.type == "dict":
                layer = DictLayer(layer_config)
            elif layer_config.type == "redis":
                layer = RedisLayer(layer_config)
            elif layer_config.type == "elasticsearch":
                layer = ElasticsearchLayer(layer_config)
            elif layer_config.type == "postgres":
                layer = PostgresLayer(layer_config)
            else:
                raise ValueError(f"Unknown layer type: {layer_config.type}")

            self.layers.append(layer)

        self.layers.sort(key=lambda x: x.priority, reverse=True)  # Higher priority checked first

    def get_available_methods(self) -> List[str]:
        return [
            "search",
            "filter_by_popularity",
            "deduplicate_candidates",
            "limit_candidates",
            "sort_by_popularity"
        ]

    def search(self, mention: str) -> List[DatabaseRecord]:
        """Search through layers with fallback"""
        found_in_layer = None
        results = []

        for layer in self.layers:
            if not layer.is_available():
                continue

            layer_results = []

            for mode in layer.config.search_mode:
                if mode == "exact":
                    layer_results.extend(layer.search(mention))
                elif mode == "fuzzy":
                    if layer.supports_fuzzy():
                        layer_results.extend(layer.search_fuzzy(mention))

            if layer_results:
                layer_results = self.deduplicate_candidates(layer_results)
                results = layer_results
                found_in_layer = layer
                break

        if results and found_in_layer:
            self._cache_write(mention, results, found_in_layer)

        return results

    def _cache_write(self, query: str, results: List[DatabaseRecord], source_layer: DatabaseLayer):
        """Write results to upper layers (higher priority = checked earlier)"""
        for layer in self.layers:
            # Skip source layer and all layers with lower priority
            if layer.priority <= source_layer.priority:
                continue
            if not layer.write:
                continue

            if layer.cache_policy == "always":
                layer.write_cache(query, results, layer.ttl)
            elif layer.cache_policy == "miss":
                existing = layer.search(query)
                if not existing:
                    layer.write_cache(query, results, layer.ttl)
            elif layer.cache_policy == "hit":
                existing = layer.search(query)
                if existing:
                    layer.write_cache(query, results, layer.ttl)

    def filter_by_popularity(self, records: List[DatabaseRecord], min_popularity: int = None) -> List[DatabaseRecord]:
        threshold = min_popularity if min_popularity is not None else self.config.min_popularity
        return [r for r in records if r.popularity >= threshold]

    def deduplicate_candidates(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]:
        seen = set()
        unique = []
        for record in records:
            if record.entity_id not in seen:
                unique.append(record)
                seen.add(record.entity_id)
        return unique

    def limit_candidates(self, records: List[DatabaseRecord], limit: int = None) -> List[DatabaseRecord]:
        max_cands = limit if limit is not None else self.config.max_candidates
        return records[:max_cands]

    def sort_by_popularity(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]:
        return sorted(records, key=lambda x: x.popularity, reverse=True)

    def load_entities(
        self,
        source: Union[str, Path, List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        target_layers: List[str] = None,
        batch_size: int = 1000,
        overwrite: bool = False
    ) -> Dict[str, int]:
        """
        Load entities from JSONL file, list of dicts, or dict.

        Accepts:
        - ``str`` / ``Path``: path to a JSONL file (one DatabaseRecord per line)
        - ``list[dict]``: each dict has at least ``entity_id`` and ``label``
        - ``dict[str, dict]``: keys are entity_ids, values are entity data

        Args:
            source: entity data (file path, list, or dict)
            target_layers: ['dict', 'redis', 'elasticsearch', 'postgres'] or None (all writable)
            batch_size: batch size for bulk operations
            overwrite: overwrite existing entities

        Returns:
            {'redis': 1500, 'elasticsearch': 1500}
        """
        if isinstance(source, (str, Path)):
            entities = self._parse_jsonl(source)
        elif isinstance(source, dict):
            entities = [
                DatabaseRecord(entity_id=eid, **data)
                if "entity_id" not in data
                else DatabaseRecord(**data)
                for eid, data in source.items()
            ]
        elif isinstance(source, list):
            entities = [
                DatabaseRecord(**e) if isinstance(e, dict) else e
                for e in source
            ]
        else:
            raise TypeError(f"Expected file path, list, or dict; got {type(source)}")

        return self.load_records(
            entities,
            target_layers=target_layers,
            batch_size=batch_size,
            overwrite=overwrite,
        )

    def load_records(
        self,
        entities: List[DatabaseRecord],
        target_layers: List[str] = None,
        batch_size: int = 1000,
        overwrite: bool = False,
    ) -> Dict[str, int]:
        """
        Load pre-built DatabaseRecord objects into layers.

        Args:
            entities: list of DatabaseRecord instances
            target_layers: layer types to target (None = all writable)
            batch_size: batch size for bulk operations
            overwrite: overwrite existing entities

        Returns:
            Dict of layer_type -> count loaded
        """
        # Determine target layers
        if target_layers is None:
            target_layers = [l.config.type for l in self.layers if l.write]

        # Load to each layer
        results = {}
        for layer in self.layers:
            if layer.config.type not in target_layers:
                continue

            if not layer.is_available():
                continue

            count = layer.load_bulk(entities, overwrite=overwrite, batch_size=batch_size)
            results[layer.config.type] = count

        return results

    @staticmethod
    def _parse_jsonl(filepath: Union[str, Path]) -> List[DatabaseRecord]:
        """Parse JSONL file into DatabaseRecord list."""
        entities = []
        with open(filepath, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    entities.append(DatabaseRecord(**data))
                except Exception as e:
                    print(f"[WARN] Line {line_num} parse error: {e}")
                    continue
        return entities

    def clear_layers(self, layer_names: List[str] = None):
        """Clear all entities in specified layers"""
        for layer in self.layers:
            if layer_names and layer.config.type not in layer_names:
                continue

            print(f"Clearing {layer.config.type}...")
            layer.clear()
            print("✓ Cleared")

    def get_all_entities(self) -> List[DatabaseRecord]:
        """Get all entities from all available layers (deduplicated)"""
        all_entities = []
        for layer in self.layers:
            if layer.is_available():
                all_entities.extend(layer.get_all_entities())
        return self.deduplicate_candidates(all_entities)

    def count_entities(self) -> Dict[str, int]:
        """Count entities in each layer"""
        counts = {}
        for layer in self.layers:
            counts[layer.config.type] = layer.count()
        return counts

    def precompute_embeddings(
        self,
        encoder_fn,
        template: str,
        model_id: str,
        target_layers: List[str] = None,
        batch_size: int = 32
    ) -> Dict[str, int]:
        """
        Precompute embeddings for all entities in specified layers.

        Args:
            encoder_fn: Callable that takes List[str] and returns embeddings tensor
            template: Template string for formatting labels (e.g., "{label}: {description}")
            model_id: Model identifier to store with embeddings
            target_layers: Layer types to update (None = all)
            batch_size: Batch size for encoding

        Returns:
            Dict with count of updated entities per layer
        """
        from tqdm import tqdm

        results = {}

        for layer in self.layers:
            if target_layers and layer.config.type not in target_layers:
                continue

            if not layer.is_available():
                print(f"[WARN] {layer.config.type} unavailable, skipping")
                continue

            print(f"\nPrecomputing embeddings for {layer.config.type}...")

            # Get all entities
            entities = layer.get_all_entities()
            if not entities:
                print(f" No entities found in {layer.config.type}")
                continue

            print(f" Found {len(entities)} entities")

            # Format labels using template
            labels = []
            entity_ids = []
            for entity in entities:
                try:
                    formatted = template.format(**entity.dict())
                    labels.append(formatted)
                    entity_ids.append(entity.entity_id)
                except KeyError as e:
                    print(f" [WARN] Template error for {entity.entity_id}: {e}")
                    continue

            # Encode in batches
            all_embeddings = []
            for i in tqdm(range(0, len(labels), batch_size), desc="Encoding"):
                batch_labels = labels[i:i + batch_size]
                batch_embeddings = encoder_fn(batch_labels)

                # Convert to list if tensor
                if hasattr(batch_embeddings, 'tolist'):
                    batch_embeddings = batch_embeddings.tolist()
                elif hasattr(batch_embeddings, 'cpu'):
                    batch_embeddings = batch_embeddings.cpu().numpy().tolist()

                all_embeddings.extend(batch_embeddings)

            # Update layer
            updated = layer.update_embeddings(entity_ids, all_embeddings, model_id)
            results[layer.config.type] = updated
            print(f" Updated {updated} entities with embeddings")

        return results

    def get_layer(self, layer_type: str) -> DatabaseLayer:
        """Get layer by type"""
        for layer in self.layers:
            if layer.config.type == layer_type:
                return layer
        return None
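
A few usage sketches follow; they are illustrative and go beyond what the diff itself shows. First, wiring a two-layer chain. This is a minimal sketch: the exact LayerConfig/L2Config fields, defaults, and required arguments live in glinker/l2/models.py (not part of this diff), and the assumption that DatabaseChainComponent is constructed directly from an L2Config is inferred only from its BaseComponent[L2Config] base.

    # Hypothetical sketch -- field names mirror component.py above;
    # defaults and the constructor signature are assumptions.
    from glinker.l2.component import DatabaseChainComponent
    from glinker.l2.models import L2Config, LayerConfig

    cfg = L2Config(layers=[
        # In-process dict: highest priority, acts as the write-back cache.
        LayerConfig(type="dict", priority=100, write=True,
                    cache_policy="always", search_mode=["exact", "fuzzy"]),
        # Postgres: source of truth, lowest priority, never written on search.
        LayerConfig(type="postgres", priority=10, write=False,
                    search_mode=["exact", "fuzzy"],
                    config={"host": "localhost", "database": "kb",
                            "user": "glinker", "password": "secret"}),
    ])
    chain = DatabaseChainComponent(cfg)            # assumed constructor

    candidates = chain.search("aspirin")           # dict first, then postgres
    top = chain.limit_candidates(chain.sort_by_popularity(candidates))

On a hit in a lower layer, _cache_write copies the results into every higher-priority writable layer according to its cache_policy: "always" writes unconditionally, "miss" only when the key is absent, "hit" only when it is already present. With the config above, a mention first resolved by Postgres is answered from the dict layer on the next search.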
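
load_entities accepts a JSONL path, a list of dicts, or a dict keyed by entity_id. Continuing the sketch (hypothetical data; DatabaseRecord is assumed to default optional fields such as description and aliases, which this diff does not show):

    loaded = chain.load_entities(
        [
            {"entity_id": "Q1", "label": "Aspirin",
             "aliases": ["acetylsalicylic acid", "ASA"], "popularity": 95},
            {"entity_id": "Q2", "label": "Ibuprofen", "popularity": 90},
        ],
        target_layers=["dict"],    # None targets every writable layer
        overwrite=True,
    )
    print(loaded)                  # e.g. {'dict': 2}

    # Equivalently, from a JSONL file with one record per line:
    # chain.load_entities("entities.jsonl", batch_size=5000)
    print(chain.count_entities())  # per-layer totals, e.g. {'dict': 2, 'postgres': 0}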
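
DictLayer.search_fuzzy is a plain rapidfuzz scan; its per-candidate thresholding is equivalent to the standalone sketch below (the values stand in for FuzzyConfig.min_similarity and FuzzyConfig.prefix_length and are illustrative):

    from rapidfuzz import fuzz

    min_similarity = 0.85  # FuzzyConfig.min_similarity
    prefix_length = 1      # FuzzyConfig.prefix_length

    query, label = "asprin", "aspirin"
    # Prefix gate first, then normalized ratio, as in search_fuzzy above
    if prefix_length == 0 or label[:prefix_length] == query[:prefix_length]:
        score = fuzz.ratio(query, label) / 100.0
        print(score >= min_similarity, round(score, 3))  # True 0.923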
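
precompute_embeddings only assumes encoder_fn maps List[str] to something exposing .tolist() (or a torch-style .cpu()). Any sentence encoder satisfies that contract; a sketch with sentence-transformers (the model name is illustrative, not something glinker ships or requires):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")
    updated = chain.precompute_embeddings(
        encoder_fn=model.encode,            # List[str] -> numpy array
        template="{label}: {description}",  # filled from each record's fields
        model_id="all-MiniLM-L6-v2",        # stored as embedding_model_id
        target_layers=["dict"],
        batch_size=32,
    )
    print(updated)                          # e.g. {'dict': 2}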