hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0
--- /dev/null
+++ hindsight_api/engine/entity_resolver.py
@@ -0,0 +1,575 @@
+"""
+Entity extraction and resolution for memory system.
+
+Uses spaCy for entity extraction and implements resolution logic
+to disambiguate entities across memory units.
+"""
+import asyncpg
+from typing import List, Dict, Optional, Set, Any
+from difflib import SequenceMatcher
+from datetime import datetime, timezone
+from .db_utils import acquire_with_retry
+
+
+# Load spaCy model (singleton)
+_nlp = None
+
+
+class EntityResolver:
+    """
+    Resolves entities to canonical IDs with disambiguation.
+    """
+
+    def __init__(self, pool: asyncpg.Pool):
+        """
+        Initialize entity resolver.
+
+        Args:
+            pool: asyncpg connection pool
+        """
+        self.pool = pool
+
+    async def resolve_entities_batch(
+        self,
+        bank_id: str,
+        entities_data: List[Dict],
+        context: str,
+        unit_event_date,
+        conn=None,
+    ) -> List[str]:
+        """
+        Resolve multiple entities in batch (MUCH faster than sequential).
+
+        Groups entities by type, queries candidates in bulk, and resolves
+        all entities with minimal DB queries.
+
+        Args:
+            bank_id: bank ID
+            entities_data: List of dicts with 'text', 'type', 'nearby_entities'
+            context: Context where entities appear
+            unit_event_date: When this unit was created
+            conn: Optional connection to use (if None, acquires from pool)
+
+        Returns:
+            List of entity IDs in same order as input
+        """
+        if not entities_data:
+            return []
+
+        if conn is None:
+            async with acquire_with_retry(self.pool) as conn:
+                return await self._resolve_entities_batch_impl(conn, bank_id, entities_data, context, unit_event_date)
+        else:
+            return await self._resolve_entities_batch_impl(conn, bank_id, entities_data, context, unit_event_date)
+
+    async def _resolve_entities_batch_impl(self, conn, bank_id: str, entities_data: List[Dict], context: str, unit_event_date) -> List[str]:
+        # Query ALL candidates for this bank
+        all_entities = await conn.fetch(
+            """
+            SELECT canonical_name, id, metadata, last_seen, mention_count
+            FROM entities
+            WHERE bank_id = $1
+            """,
+            bank_id
+        )
+
+        # Build entity ID to name mapping for co-occurrence lookups
+        entity_id_to_name = {row['id']: row['canonical_name'].lower() for row in all_entities}
+
+        # Query ALL co-occurrences for this bank's entities in one query
+        # This builds a map of entity_id -> set of co-occurring entity names
+        all_cooccurrences = await conn.fetch(
+            """
+            SELECT ec.entity_id_1, ec.entity_id_2, ec.cooccurrence_count
+            FROM entity_cooccurrences ec
+            WHERE ec.entity_id_1 IN (SELECT id FROM entities WHERE bank_id = $1)
+               OR ec.entity_id_2 IN (SELECT id FROM entities WHERE bank_id = $1)
+            """,
+            bank_id
+        )
+
+        # Build co-occurrence map: entity_id -> set of co-occurring entity names (lowercase)
+        cooccurrence_map: Dict[str, Set[str]] = {}
+        for row in all_cooccurrences:
+            eid1, eid2 = row['entity_id_1'], row['entity_id_2']
+            # Add both directions
+            if eid1 not in cooccurrence_map:
+                cooccurrence_map[eid1] = set()
+            if eid2 not in cooccurrence_map:
+                cooccurrence_map[eid2] = set()
+            # Map to canonical names for comparison with nearby_entities
+            if eid2 in entity_id_to_name:
+                cooccurrence_map[eid1].add(entity_id_to_name[eid2])
+            if eid1 in entity_id_to_name:
+                cooccurrence_map[eid2].add(entity_id_to_name[eid1])
+
+        # Build candidate map for each entity text
+        all_candidates = {}  # Maps entity_text -> list of candidates
+        entity_texts = list(set(e['text'] for e in entities_data))
+
+        for entity_text in entity_texts:
+            matching = []
+            entity_text_lower = entity_text.lower()
+            for row in all_entities:
+                canonical_name = row['canonical_name']
+                ent_id = row['id']
+                metadata = row['metadata']
+                last_seen = row['last_seen']
+                mention_count = row['mention_count']
+                canonical_lower = canonical_name.lower()
+                # Match if exact or substring match
+                if (entity_text_lower == canonical_lower or
+                        entity_text_lower in canonical_lower or
+                        canonical_lower in entity_text_lower):
+                    matching.append((ent_id, canonical_name, metadata, last_seen, mention_count))
+            all_candidates[entity_text] = matching
+
+        # Resolve each entity using pre-fetched candidates
+        entity_ids = [None] * len(entities_data)
+        entities_to_update = []  # (entity_id, unit_event_date)
+        entities_to_create = []  # (idx, entity_data)
+
+        for idx, entity_data in enumerate(entities_data):
+            entity_text = entity_data['text']
+            nearby_entities = entity_data.get('nearby_entities', [])
+
+            candidates = all_candidates.get(entity_text, [])
+
+            if not candidates:
+                # Will create new entity
+                entities_to_create.append((idx, entity_data))
+                continue
+
+            # Score candidates
+            best_candidate = None
+            best_score = 0.0
+
+            nearby_entity_set = {e['text'].lower() for e in nearby_entities if e['text'] != entity_text}
+
+            for candidate_id, canonical_name, metadata, last_seen, mention_count in candidates:
+                score = 0.0
+
+                # 1. Name similarity (0-0.5)
+                name_similarity = SequenceMatcher(
+                    None,
+                    entity_text.lower(),
+                    canonical_name.lower()
+                ).ratio()
+                score += name_similarity * 0.5
+
+                # 2. Co-occurring entities (0-0.3)
+                if nearby_entity_set:
+                    co_entities = cooccurrence_map.get(candidate_id, set())
+                    overlap = len(nearby_entity_set & co_entities)
+                    co_entity_score = overlap / len(nearby_entity_set)
+                    score += co_entity_score * 0.3
+
+                # 3. Temporal proximity (0-0.2)
+                if last_seen:
+                    # Normalize timezone awareness for comparison
+                    event_date_utc = unit_event_date if unit_event_date.tzinfo else unit_event_date.replace(tzinfo=timezone.utc)
+                    last_seen_utc = last_seen if last_seen.tzinfo else last_seen.replace(tzinfo=timezone.utc)
+                    days_diff = abs((event_date_utc - last_seen_utc).total_seconds() / 86400)
+                    if days_diff < 7:
+                        temporal_score = max(0, 1.0 - (days_diff / 7))
+                        score += temporal_score * 0.2
+
+                if score > best_score:
+                    best_score = score
+                    best_candidate = candidate_id
+
+            # Apply unified threshold
+            threshold = 0.6
+
+            if best_score > threshold:
+                entity_ids[idx] = best_candidate
+                entities_to_update.append((best_candidate, unit_event_date))
+            else:
+                entities_to_create.append((idx, entity_data))
+
+        # Batch update existing entities
+        if entities_to_update:
+            await conn.executemany(
+                """
+                UPDATE entities SET
+                    mention_count = mention_count + 1,
+                    last_seen = $2
+                WHERE id = $1::uuid
+                """,
+                entities_to_update
+            )
+
+        # Create new entities using INSERT ... ON CONFLICT to handle race conditions
+        # This ensures that if two concurrent transactions try to create the same entity,
+        # only one succeeds and the other gets the existing ID
+        if entities_to_create:
+            for idx, entity_data in entities_to_create:
+                # Use INSERT ... ON CONFLICT to atomically get-or-create
+                # The unique index is on (bank_id, LOWER(canonical_name))
+                row = await conn.fetchrow(
+                    """
+                    INSERT INTO entities (bank_id, canonical_name, first_seen, last_seen, mention_count)
+                    VALUES ($1, $2, $3, $4, 1)
+                    ON CONFLICT (bank_id, LOWER(canonical_name))
+                    DO UPDATE SET
+                        mention_count = entities.mention_count + 1,
+                        last_seen = EXCLUDED.last_seen
+                    RETURNING id
+                    """,
+                    bank_id,
+                    entity_data['text'],
+                    unit_event_date,
+                    unit_event_date
+                )
+                entity_ids[idx] = row['id']
+
+        return entity_ids
+
+    async def resolve_entity(
+        self,
+        bank_id: str,
+        entity_text: str,
+        context: str,
+        nearby_entities: List[Dict],
+        unit_event_date,
+    ) -> str:
+        """
+        Resolve an entity to a canonical entity ID.
+
+        Args:
+            bank_id: bank ID (entities are scoped to banks)
+            entity_text: Entity text ("Alice", "Google", etc.)
+            context: Context where entity appears
+            nearby_entities: Other entities in the same unit
+            unit_event_date: When this unit was created
+
+        Returns:
+            Entity ID (creates new entity if needed)
+        """
+        async with acquire_with_retry(self.pool) as conn:
+            # Find candidate entities with similar name
+            candidates = await conn.fetch(
+                """
+                SELECT id, canonical_name, metadata, last_seen
+                FROM entities
+                WHERE bank_id = $1
+                  AND (
+                      canonical_name ILIKE $2
+                      OR canonical_name ILIKE $3
+                      OR $2 ILIKE canonical_name || '%'
+                  )
+                ORDER BY mention_count DESC
+                """,
+                bank_id, entity_text, f"%{entity_text}%"
+            )
+
+            if not candidates:
+                # New entity - create it
+                return await self._create_entity(
+                    conn, bank_id, entity_text, unit_event_date
+                )
+
+            # Score candidates based on:
+            # 1. Name similarity
+            # 2. Context overlap (TODO: could use embeddings)
+            # 3. Co-occurring entities
+            # 4. Temporal proximity
+
+            best_candidate = None
+            best_score = 0.0
+
+            nearby_entity_set = {e['text'].lower() for e in nearby_entities if e['text'] != entity_text}
+
+            for row in candidates:
+                candidate_id = row['id']
+                canonical_name = row['canonical_name']
+                metadata = row['metadata']
+                last_seen = row['last_seen']
+                score = 0.0
+
+                # 1. Name similarity (0-0.5)
+                name_similarity = SequenceMatcher(
+                    None,
+                    entity_text.lower(),
+                    canonical_name.lower()
+                ).ratio()
+                score += name_similarity * 0.5
+
+                # 2. Co-occurring entities (0-0.3)
+                # Get entities that co-occurred with this candidate before
+                # Use the materialized co-occurrence cache for fast lookup
+                co_entity_rows = await conn.fetch(
+                    """
+                    SELECT e.canonical_name, ec.cooccurrence_count
+                    FROM entity_cooccurrences ec
+                    JOIN entities e ON (
+                        CASE
+                            WHEN ec.entity_id_1 = $1 THEN ec.entity_id_2
+                            WHEN ec.entity_id_2 = $1 THEN ec.entity_id_1
+                        END = e.id
+                    )
+                    WHERE ec.entity_id_1 = $1 OR ec.entity_id_2 = $1
+                    """,
+                    candidate_id
+                )
+                co_entities = {r['canonical_name'].lower() for r in co_entity_rows}
+
+                # Check overlap with nearby entities
+                overlap = len(nearby_entity_set & co_entities)
+                if nearby_entity_set:
+                    co_entity_score = overlap / len(nearby_entity_set)
+                    score += co_entity_score * 0.3
+
+                # 3. Temporal proximity (0-0.2)
+                if last_seen:
+                    # Normalize timezone awareness, as in the batch path above
+                    event_date_utc = unit_event_date if unit_event_date.tzinfo else unit_event_date.replace(tzinfo=timezone.utc)
+                    last_seen_utc = last_seen if last_seen.tzinfo else last_seen.replace(tzinfo=timezone.utc)
+                    days_diff = abs((event_date_utc - last_seen_utc).total_seconds() / 86400)
+                    if days_diff < 7:  # Within a week
+                        temporal_score = max(0, 1.0 - (days_diff / 7))
+                        score += temporal_score * 0.2
+
+                if score > best_score:
+                    best_score = score
+                    best_candidate = candidate_id
+
+            # Threshold for considering it the same entity
+            threshold = 0.6
+
+            if best_score > threshold:
+                # Update entity
+                await conn.execute(
+                    """
+                    UPDATE entities
+                    SET mention_count = mention_count + 1,
+                        last_seen = $1
+                    WHERE id = $2
+                    """,
+                    unit_event_date, best_candidate
+                )
+                return best_candidate
+            else:
+                # Not confident - create new entity
+                return await self._create_entity(
+                    conn, bank_id, entity_text, unit_event_date
+                )
+
+    async def _create_entity(
+        self,
+        conn,
+        bank_id: str,
+        entity_text: str,
+        event_date,
+    ) -> str:
+        """
+        Create a new entity or get existing one if it already exists.
+
+        Uses INSERT ... ON CONFLICT to handle race conditions where
+        two concurrent transactions try to create the same entity.
+
+        Args:
+            conn: Database connection
+            bank_id: bank ID
+            entity_text: Entity text
+            event_date: When first seen
+
+        Returns:
+            Entity ID
+        """
+        entity_id = await conn.fetchval(
+            """
+            INSERT INTO entities (bank_id, canonical_name, first_seen, last_seen, mention_count)
+            VALUES ($1, $2, $3, $4, 1)
+            ON CONFLICT (bank_id, LOWER(canonical_name))
+            DO UPDATE SET
+                mention_count = entities.mention_count + 1,
+                last_seen = EXCLUDED.last_seen
+            RETURNING id
+            """,
+            bank_id, entity_text, event_date, event_date
+        )
+        return entity_id
+
+    async def link_unit_to_entity(self, unit_id: str, entity_id: str):
+        """
+        Link a memory unit to an entity.
+        Also updates co-occurrence cache with other entities in the same unit.
+
+        Args:
+            unit_id: Memory unit ID
+            entity_id: Entity ID
+        """
+        async with acquire_with_retry(self.pool) as conn:
+            # Insert unit-entity link
+            await conn.execute(
+                """
+                INSERT INTO unit_entities (unit_id, entity_id)
+                VALUES ($1, $2)
+                ON CONFLICT DO NOTHING
+                """,
+                unit_id, entity_id
+            )
+
+            # Update co-occurrence cache: find other entities in this unit
+            rows = await conn.fetch(
+                """
+                SELECT entity_id
+                FROM unit_entities
+                WHERE unit_id = $1 AND entity_id != $2
+                """,
+                unit_id, entity_id
+            )
+
+            other_entities = [row['entity_id'] for row in rows]
+
+            # Update co-occurrences for each pair
+            for other_entity_id in other_entities:
+                await self._update_cooccurrence(conn, entity_id, other_entity_id)
+
+    async def _update_cooccurrence(self, conn, entity_id_1: str, entity_id_2: str):
+        """
+        Update the co-occurrence cache for two entities.
+
+        Uses CHECK constraint ordering (entity_id_1 < entity_id_2) to avoid duplicates.
+
+        Args:
+            conn: Database connection
+            entity_id_1: First entity ID
+            entity_id_2: Second entity ID
+        """
+        # Ensure consistent ordering (smaller UUID first)
+        if entity_id_1 > entity_id_2:
+            entity_id_1, entity_id_2 = entity_id_2, entity_id_1
+
+        await conn.execute(
+            """
+            INSERT INTO entity_cooccurrences (entity_id_1, entity_id_2, cooccurrence_count, last_cooccurred)
+            VALUES ($1, $2, 1, NOW())
+            ON CONFLICT (entity_id_1, entity_id_2)
+            DO UPDATE SET
+                cooccurrence_count = entity_cooccurrences.cooccurrence_count + 1,
+                last_cooccurred = NOW()
+            """,
+            entity_id_1, entity_id_2
+        )
+
+    async def link_units_to_entities_batch(self, unit_entity_pairs: List[tuple[str, str]], conn=None):
+        """
+        Link multiple memory units to entities in batch (MUCH faster than sequential).
+
+        Also updates co-occurrence cache for entities that appear in the same unit.
+
+        Args:
+            unit_entity_pairs: List of (unit_id, entity_id) tuples
+            conn: Optional connection to use (if None, acquires from pool)
+        """
+        if not unit_entity_pairs:
+            return
+
+        if conn is None:
+            async with acquire_with_retry(self.pool) as conn:
+                return await self._link_units_to_entities_batch_impl(conn, unit_entity_pairs)
+        else:
+            return await self._link_units_to_entities_batch_impl(conn, unit_entity_pairs)
+
+    async def _link_units_to_entities_batch_impl(self, conn, unit_entity_pairs: List[tuple[str, str]]):
+        # Batch insert all unit-entity links
+        await conn.executemany(
+            """
+            INSERT INTO unit_entities (unit_id, entity_id)
+            VALUES ($1, $2)
+            ON CONFLICT DO NOTHING
+            """,
+            unit_entity_pairs
+        )
+
+        # Build map of unit -> entities for co-occurrence calculation
+        # Use sets to avoid duplicate entities in the same unit
+        unit_to_entities = {}
+        for unit_id, entity_id in unit_entity_pairs:
+            if unit_id not in unit_to_entities:
+                unit_to_entities[unit_id] = set()
+            unit_to_entities[unit_id].add(entity_id)
+
+        # Update co-occurrences for all pairs in each unit
+        cooccurrence_pairs = set()  # Use set to avoid duplicates
+        for unit_id, entity_ids in unit_to_entities.items():
+            entity_list = list(entity_ids)  # Convert set to list for iteration
+            # For each pair of entities in this unit, create co-occurrence
+            for i, entity_id_1 in enumerate(entity_list):
+                for entity_id_2 in entity_list[i+1:]:
+                    # Skip if same entity (shouldn't happen with set, but be safe)
+                    if entity_id_1 == entity_id_2:
+                        continue
+                    # Order each pair into a local tuple to ensure consistent
+                    # ordering (entity_id_1 < entity_id_2); reassigning the loop
+                    # variables here would corrupt later inner-loop iterations
+                    if entity_id_1 < entity_id_2:
+                        pair = (entity_id_1, entity_id_2)
+                    else:
+                        pair = (entity_id_2, entity_id_1)
+                    cooccurrence_pairs.add(pair)
+
+        # Batch update co-occurrences
+        if cooccurrence_pairs:
+            now = datetime.now(timezone.utc)
+            await conn.executemany(
+                """
+                INSERT INTO entity_cooccurrences (entity_id_1, entity_id_2, cooccurrence_count, last_cooccurred)
+                VALUES ($1, $2, $3, $4)
+                ON CONFLICT (entity_id_1, entity_id_2)
+                DO UPDATE SET
+                    cooccurrence_count = entity_cooccurrences.cooccurrence_count + 1,
+                    last_cooccurred = EXCLUDED.last_cooccurred
+                """,
+                [(e1, e2, 1, now) for e1, e2 in cooccurrence_pairs]
+            )
+
+    async def get_units_by_entity(self, entity_id: str, limit: int = 100) -> List[str]:
+        """
+        Get all units that mention an entity.
+
+        Args:
+            entity_id: Entity ID
+            limit: Max results
+
+        Returns:
+            List of unit IDs
+        """
+        async with acquire_with_retry(self.pool) as conn:
+            rows = await conn.fetch(
+                """
+                SELECT unit_id
+                FROM unit_entities
+                WHERE entity_id = $1
+                ORDER BY unit_id
+                LIMIT $2
+                """,
+                entity_id, limit
+            )
+            return [row['unit_id'] for row in rows]
+
+    async def get_entity_by_text(
+        self,
+        bank_id: str,
+        entity_text: str,
+    ) -> Optional[str]:
+        """
+        Find an entity by text (for query resolution).
+
+        Args:
+            bank_id: bank ID
+            entity_text: Entity text to search for
+
+        Returns:
+            Entity ID if found, None otherwise
+        """
+        async with acquire_with_retry(self.pool) as conn:
+            row = await conn.fetchrow(
+                """
+                SELECT id FROM entities
+                WHERE bank_id = $1
+                  AND canonical_name ILIKE $2
+                ORDER BY mention_count DESC
+                LIMIT 1
+                """,
+                bank_id, entity_text
+            )
+
+            return row['id'] if row else None
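
For orientation, here is a minimal usage sketch of the resolver added above. It is not part of the package: the DSN, bank ID, and unit ID are placeholders, and it assumes a Postgres database that already contains the `entities`, `unit_entities`, and `entity_cooccurrences` tables (including the `(bank_id, LOWER(canonical_name))` unique index) that the queries in the diff reference.

```python
import asyncio
from datetime import datetime, timezone

import asyncpg

from hindsight_api.engine.entity_resolver import EntityResolver


async def main() -> None:
    # Placeholder DSN; point this at a database with the expected schema.
    pool = await asyncpg.create_pool("postgresql://localhost/hindsight")
    resolver = EntityResolver(pool)

    # Entities extracted from one memory unit. 'nearby_entities' feeds the
    # co-occurrence term of the candidate score.
    entities_data = [
        {"text": "Alice", "type": "PERSON", "nearby_entities": [{"text": "Google"}]},
        {"text": "Google", "type": "ORG", "nearby_entities": [{"text": "Alice"}]},
    ]
    entity_ids = await resolver.resolve_entities_batch(
        bank_id="bank-123",  # hypothetical bank ID
        entities_data=entities_data,
        context="Alice joined Google last week.",
        unit_event_date=datetime.now(timezone.utc),
    )

    # Link the unit to its entities; this also refreshes the
    # entity_cooccurrences cache for every entity pair in the unit.
    await resolver.link_units_to_entities_batch(
        [("unit-456", eid) for eid in entity_ids]  # hypothetical unit ID
    )

    await pool.close()


asyncio.run(main())
```

To see how the scoring weights in the diff interact: a candidate earns up to 0.5 from name similarity, 0.3 from co-occurrence overlap, and 0.2 from temporal proximity, against a reuse threshold of 0.6. So a 0.9 name match alone scores 0.45 and falls through to entity creation, while the same name match with full co-occurrence overlap scores 0.45 + 0.3 = 0.75 and resolves to the existing entity.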