hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0

hindsight_api/engine/retain/link_utils.py
@@ -0,0 +1,651 @@
"""
Link creation utilities for temporal, semantic, and entity links.
"""

import time
import logging
from typing import List
from datetime import timedelta, datetime, timezone

logger = logging.getLogger(__name__)


def _normalize_datetime(dt):
    """Normalize datetime to be timezone-aware (UTC) for consistent comparison."""
    if dt is None:
        return None
    if dt.tzinfo is None:
        # Naive datetime - assume UTC
        return dt.replace(tzinfo=timezone.utc)
    return dt


def compute_temporal_links(
    new_units: dict,
    candidates: list,
    time_window_hours: int = 24,
) -> list:
    """
    Compute temporal links between new units and candidate neighbors.

    This is a pure function that takes query results and returns link tuples,
    making it easy to test without database access.

    Args:
        new_units: Dict mapping unit_id (str) to event_date (datetime)
        candidates: List of dicts with 'id' and 'event_date' keys (candidate neighbors)
        time_window_hours: Time window in hours for temporal links

    Returns:
        List of tuples: (from_unit_id, to_unit_id, 'temporal', weight, None)
    """
    if not new_units:
        return []

    links = []
    for unit_id, unit_event_date in new_units.items():
        # Normalize unit_event_date for consistent comparison
        unit_event_date_norm = _normalize_datetime(unit_event_date)

        # Calculate time window bounds with overflow protection
        try:
            time_lower = unit_event_date_norm - timedelta(hours=time_window_hours)
        except OverflowError:
            time_lower = datetime.min.replace(tzinfo=timezone.utc)
        try:
            time_upper = unit_event_date_norm + timedelta(hours=time_window_hours)
        except OverflowError:
            time_upper = datetime.max.replace(tzinfo=timezone.utc)

        # Filter candidates within this unit's time window
        matching_neighbors = [
            (row['id'], row['event_date'])
            for row in candidates
            if time_lower <= _normalize_datetime(row['event_date']) <= time_upper
        ][:10]  # Limit to top 10

        for recent_id, recent_event_date in matching_neighbors:
            # Calculate temporal proximity weight
            time_diff_hours = abs((unit_event_date_norm - _normalize_datetime(recent_event_date)).total_seconds() / 3600)
            weight = max(0.3, 1.0 - (time_diff_hours / time_window_hours))
            links.append((unit_id, str(recent_id), 'temporal', weight, None))

    return links


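# Illustrative usage (editor's sketch, not part of the original module): because
# compute_temporal_links is pure, it can be exercised with plain dicts and datetimes;
# the unit IDs below are made up. Weight = max(0.3, 1 - time_diff / window).
#
#   >>> base = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
#   >>> compute_temporal_links(
#   ...     {"unit-a": base},
#   ...     [{"id": "unit-b", "event_date": base + timedelta(hours=6)},
#   ...      {"id": "unit-c", "event_date": base + timedelta(hours=48)}],  # outside 24h window
#   ...     time_window_hours=24,
#   ... )
#   [('unit-a', 'unit-b', 'temporal', 0.75, None)]

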
def compute_temporal_query_bounds(
    new_units: dict,
    time_window_hours: int = 24,
) -> tuple:
    """
    Compute the min/max date bounds for querying temporal neighbors.

    Args:
        new_units: Dict mapping unit_id (str) to event_date (datetime)
        time_window_hours: Time window in hours

    Returns:
        Tuple of (min_date, max_date) with overflow protection
    """
    if not new_units:
        return None, None

    # Normalize all dates to be timezone-aware to avoid comparison issues
    all_dates = [_normalize_datetime(d) for d in new_units.values()]

    try:
        min_date = min(all_dates) - timedelta(hours=time_window_hours)
    except OverflowError:
        min_date = datetime.min.replace(tzinfo=timezone.utc)

    try:
        max_date = max(all_dates) + timedelta(hours=time_window_hours)
    except OverflowError:
        max_date = datetime.max.replace(tzinfo=timezone.utc)

    return min_date, max_date


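# Illustrative usage (editor's sketch, not in the original file): the bounds span the
# earliest and latest event_date padded by the window, and values near datetime.min or
# datetime.max are clamped instead of raising OverflowError.
#
#   >>> base = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
#   >>> compute_temporal_query_bounds({"a": base, "b": base + timedelta(days=2)}, 24)
#   (datetime.datetime(2023, 12, 31, 12, 0, tzinfo=datetime.timezone.utc), datetime.datetime(2024, 1, 4, 12, 0, tzinfo=datetime.timezone.utc))

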
def _log(log_buffer, message, level='info'):
    """Helper to log to buffer if available, otherwise use logger."""
    if log_buffer is not None:
        log_buffer.append(message)
    else:
        if level == 'info':
            logger.info(message)
        else:
            logger.log(logging.WARNING if level == 'warning' else logging.ERROR, message)


async def extract_entities_batch_optimized(
    entity_resolver,
    conn,
    bank_id: str,
    unit_ids: List[str],
    sentences: List[str],
    context: str,
    fact_dates: List,
    llm_entities: List[List[dict]],
    log_buffer: List[str] = None,
) -> List[tuple]:
    """
    Process LLM-extracted entities for ALL facts in batch.

    Uses entities provided by the LLM (no spaCy needed), then resolves
    and links them in bulk.

    Args:
        entity_resolver: EntityResolver instance for entity resolution
        conn: Database connection
        bank_id: Bank identifier
        unit_ids: List of unit IDs
        sentences: List of fact sentences
        context: Context string
        fact_dates: List of fact dates
        llm_entities: List of entity lists from LLM extraction
        log_buffer: Optional buffer for logging

    Returns:
        List of tuples for batch insertion: (from_unit_id, to_unit_id, link_type, weight, entity_id)
    """
    try:
        # Step 1: Convert LLM entities to the format expected by entity resolver
        substep_start = time.time()
        all_entities = []
        for entity_list in llm_entities:
            # Convert List[Entity] or List[dict] to List[Dict] format
            formatted_entities = []
            for ent in entity_list:
                # Handle both Entity objects and dicts
                if hasattr(ent, 'text'):
                    # Entity objects only have 'text', default type to 'CONCEPT'
                    formatted_entities.append({'text': ent.text, 'type': 'CONCEPT'})
                elif isinstance(ent, dict):
                    formatted_entities.append({'text': ent.get('text', ''), 'type': ent.get('type', 'CONCEPT')})
            all_entities.append(formatted_entities)

        total_entities = sum(len(ents) for ents in all_entities)
        _log(log_buffer, f"  [6.1] Process LLM entities: {total_entities} entities from {len(sentences)} facts in {time.time() - substep_start:.3f}s")

        # Step 2: Resolve entities in BATCH (much faster!)
        substep_start = time.time()
        step_6_2_start = time.time()

        # [6.2.1] Prepare all entities for batch resolution
        substep_6_2_1_start = time.time()
        all_entities_flat = []
        entity_to_unit = []  # Maps flat index to (unit_id, local_index)

        for unit_id, entities, fact_date in zip(unit_ids, all_entities, fact_dates):
            if not entities:
                continue

            for local_idx, entity in enumerate(entities):
                all_entities_flat.append({
                    'text': entity['text'],
                    'type': entity['type'],
                    'nearby_entities': entities,
                })
                entity_to_unit.append((unit_id, local_idx, fact_date))
        _log(log_buffer, f"    [6.2.1] Prepare entities: {len(all_entities_flat)} entities in {time.time() - substep_6_2_1_start:.3f}s")

        # Resolve ALL entities in one batch call
        if all_entities_flat:
            # [6.2.2] Batch resolve entities
            substep_6_2_2_start = time.time()
            # Group by date for batch resolution (round to hour to reduce buckets)
            entities_by_date = {}
            for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
                # Round to hour to group facts from same time period
                date_key = fact_date.replace(minute=0, second=0, microsecond=0)
                if date_key not in entities_by_date:
                    entities_by_date[date_key] = []
                entities_by_date[date_key].append((idx, all_entities_flat[idx]))

            _log(log_buffer, f"    [6.2.2] Grouped into {len(entities_by_date)} date buckets, resolving in parallel...")

            # Resolve all date groups in PARALLEL using asyncio.gather
            resolved_entity_ids = [None] * len(all_entities_flat)

            # Prepare all resolution tasks
            async def resolve_date_bucket(date_idx, date_key, entities_group):
                date_bucket_start = time.time()
                indices = [idx for idx, _ in entities_group]
                entities_data = [entity_data for _, entity_data in entities_group]
                # Use the first fact's date for this bucket (all should be in same hour)
                fact_date = entity_to_unit[indices[0]][2]

                # Pass conn=None to let each parallel task acquire its own connection
                batch_resolved = await entity_resolver.resolve_entities_batch(
                    bank_id=bank_id,
                    entities_data=entities_data,
                    context=context,
                    unit_event_date=fact_date,
                    conn=None  # Each task gets its own connection from pool
                )

                if len(entities_by_date) <= 10:  # Only log individual buckets if there aren't too many
                    _log(log_buffer, f"    [6.2.2.{date_idx}] Resolved {len(entities_data)} entities in {time.time() - date_bucket_start:.3f}s")

                return indices, batch_resolved

            # Execute all resolution tasks in parallel
            import asyncio
            tasks = [
                resolve_date_bucket(date_idx, date_key, entities_group)
                for date_idx, (date_key, entities_group) in enumerate(entities_by_date.items(), 1)
            ]
            results = await asyncio.gather(*tasks)

            # Map results back to resolved_entity_ids
            for indices, batch_resolved in results:
                for idx, entity_id in zip(indices, batch_resolved):
                    resolved_entity_ids[idx] = entity_id

            _log(log_buffer, f"    [6.2.2] Resolve entities: {len(all_entities_flat)} entities across {len(entities_by_date)} buckets in {time.time() - substep_6_2_2_start:.3f}s")

            # [6.2.3] Create unit-entity links in BATCH
            substep_6_2_3_start = time.time()
            # Map resolved entities back to units and collect all (unit, entity) pairs
            unit_to_entity_ids = {}
            unit_entity_pairs = []
            for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
                if unit_id not in unit_to_entity_ids:
                    unit_to_entity_ids[unit_id] = []

                entity_id = resolved_entity_ids[idx]
                unit_to_entity_ids[unit_id].append(entity_id)
                unit_entity_pairs.append((unit_id, entity_id))

            # Batch insert all unit-entity links (MUCH faster!)
            await entity_resolver.link_units_to_entities_batch(unit_entity_pairs, conn=conn)
            _log(log_buffer, f"    [6.2.3] Create unit-entity links (batched): {len(unit_entity_pairs)} links in {time.time() - substep_6_2_3_start:.3f}s")

            _log(log_buffer, f"  [6.2] Entity resolution (batched): {len(all_entities_flat)} entities resolved in {time.time() - step_6_2_start:.3f}s")
        else:
            unit_to_entity_ids = {}
            _log(log_buffer, f"  [6.2] Entity resolution (batched): 0 entities in {time.time() - step_6_2_start:.3f}s")

        # Step 3: Create entity links between units that share entities
        substep_start = time.time()
        # Collect all unique entity IDs
        all_entity_ids = set()
        for entity_ids in unit_to_entity_ids.values():
            all_entity_ids.update(entity_ids)

        _log(log_buffer, f"  [6.3] Creating entity links for {len(all_entity_ids)} unique entities...")

        # Find all units that reference these entities (ONE batched query)
        entity_to_units = {}
        if all_entity_ids:
            query_start = time.time()
            import uuid
            entity_id_list = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in all_entity_ids]
            rows = await conn.fetch(
                """
                SELECT entity_id, unit_id
                FROM unit_entities
                WHERE entity_id = ANY($1::uuid[])
                """,
                entity_id_list
            )
            _log(log_buffer, f"    [6.3.1] Query unit_entities: {len(rows)} rows in {time.time() - query_start:.3f}s")

            # Group by entity_id
            group_start = time.time()
            for row in rows:
                entity_id = row['entity_id']
                if entity_id not in entity_to_units:
                    entity_to_units[entity_id] = []
                entity_to_units[entity_id].append(row['unit_id'])
            _log(log_buffer, f"    [6.3.2] Group by entity_id: {time.time() - group_start:.3f}s")

        # Create bidirectional links between units that share entities
        link_gen_start = time.time()
        links = []
        for entity_id, units_with_entity in entity_to_units.items():
            # For each pair of units with this entity, create bidirectional links
            for i, unit_id_1 in enumerate(units_with_entity):
                for unit_id_2 in units_with_entity[i+1:]:
                    # Bidirectional links
                    links.append((unit_id_1, unit_id_2, 'entity', 1.0, entity_id))
                    links.append((unit_id_2, unit_id_1, 'entity', 1.0, entity_id))

        _log(log_buffer, f"    [6.3.3] Generate {len(links)} links: {time.time() - link_gen_start:.3f}s")
        _log(log_buffer, f"  [6.3] Entity link creation: {len(links)} links for {len(all_entity_ids)} unique entities in {time.time() - substep_start:.3f}s")

        return links

    except Exception as e:
        logger.error(f"Failed to extract entities in batch: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


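# Illustrative call shape (editor's sketch, not in the original file): the resolver and
# connection come from the caller; only the data shapes are shown, and the IDs, dates,
# and sentences are made up.
#
#   links = await extract_entities_batch_optimized(
#       entity_resolver, conn, bank_id="bank-1",
#       unit_ids=["u1", "u2"],
#       sentences=["Alice joined Acme.", "Acme acquired Beta."],
#       context="news",
#       fact_dates=[dt1, dt2],
#       llm_entities=[[{"text": "Alice", "type": "PERSON"}, {"text": "Acme", "type": "ORG"}],
#                     [{"text": "Acme", "type": "ORG"}, {"text": "Beta", "type": "ORG"}]],
#   )
#   # -> entity links such as ("u1", "u2", "entity", 1.0, <acme_entity_id>) and the reverse

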
async def create_temporal_links_batch_per_fact(
    conn,
    bank_id: str,
    unit_ids: List[str],
    time_window_hours: int = 24,
    log_buffer: List[str] = None,
):
    """
    Create temporal links for multiple units, each with their own event_date.

    Queries the event_date for each unit from the database and creates temporal
    links based on individual dates (supports per-fact dating).

    Args:
        conn: Database connection
        bank_id: Bank identifier
        unit_ids: List of unit IDs
        time_window_hours: Time window in hours for temporal links
        log_buffer: Optional buffer for logging
    """
    if not unit_ids:
        return

    try:
        import time as time_mod

        # Get the event_date for each new unit
        fetch_dates_start = time_mod.time()
        rows = await conn.fetch(
            """
            SELECT id, event_date
            FROM memory_units
            WHERE id::text = ANY($1)
            """,
            unit_ids
        )
        new_units = {str(row['id']): row['event_date'] for row in rows}
        _log(log_buffer, f"  [7.1] Fetch event_dates for {len(unit_ids)} units: {time_mod.time() - fetch_dates_start:.3f}s")

        # Fetch ALL potential temporal neighbors in ONE query (much faster!)
        # Get time range across all units with overflow protection
        min_date, max_date = compute_temporal_query_bounds(new_units, time_window_hours)

        fetch_neighbors_start = time_mod.time()
        all_candidates = await conn.fetch(
            """
            SELECT id, event_date
            FROM memory_units
            WHERE bank_id = $1
              AND event_date BETWEEN $2 AND $3
              AND id::text != ALL($4)
            ORDER BY event_date DESC
            """,
            bank_id,
            min_date,
            max_date,
            unit_ids
        )
        _log(log_buffer, f"  [7.2] Fetch {len(all_candidates)} candidate neighbors (1 query): {time_mod.time() - fetch_neighbors_start:.3f}s")

        # Filter and create links in memory (much faster than N queries)
        link_gen_start = time_mod.time()
        links = compute_temporal_links(new_units, all_candidates, time_window_hours)
        _log(log_buffer, f"  [7.3] Generate {len(links)} temporal links: {time_mod.time() - link_gen_start:.3f}s")

        if links:
            insert_start = time_mod.time()
            await conn.executemany(
                """
                INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
                """,
                links
            )
            _log(log_buffer, f"  [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")

    except Exception as e:
        logger.error(f"Failed to create temporal links: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


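# Illustrative usage (editor's sketch; assumes an asyncpg-style connection, which matches
# how conn.fetch/conn.executemany and $1 placeholders are used above; IDs are made up):
#
#   async with pool.acquire() as conn:
#       await create_temporal_links_batch_per_fact(
#           conn,
#           bank_id="bank-1",
#           unit_ids=["unit-uuid-1", "unit-uuid-2"],
#           time_window_hours=24,
#       )

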
async def create_semantic_links_batch(
    conn,
    bank_id: str,
    unit_ids: List[str],
    embeddings: List[List[float]],
    top_k: int = 5,
    threshold: float = 0.7,
    log_buffer: List[str] = None,
):
    """
    Create semantic links for multiple units efficiently.

    For each unit, finds similar units and creates links.

    Args:
        conn: Database connection
        bank_id: Bank identifier
        unit_ids: List of unit IDs
        embeddings: List of embedding vectors
        top_k: Number of top similar units to link
        threshold: Minimum similarity threshold
        log_buffer: Optional buffer for logging
    """
    if not unit_ids or not embeddings:
        return

    try:
        import time as time_mod
        import numpy as np

        # Fetch ALL existing units with embeddings in ONE query
        fetch_start = time_mod.time()
        all_existing = await conn.fetch(
            """
            SELECT id, embedding
            FROM memory_units
            WHERE bank_id = $1
              AND embedding IS NOT NULL
              AND id::text != ALL($2)
            """,
            bank_id,
            unit_ids
        )
        _log(log_buffer, f"  [8.1] Fetch {len(all_existing)} existing embeddings (1 query): {time_mod.time() - fetch_start:.3f}s")

        # Convert to numpy for vectorized similarity computation
        compute_start = time_mod.time()
        all_links = []

        if all_existing:
            # Convert existing embeddings to numpy array
            existing_ids = [str(row['id']) for row in all_existing]
            # Stack embeddings as 2D array: (num_embeddings, embedding_dim)
            embedding_arrays = []
            for row in all_existing:
                raw_emb = row['embedding']
                # Handle different pgvector formats
                if isinstance(raw_emb, str):
                    # Parse string format: "[1.0, 2.0, ...]"
                    import json
                    emb = np.array(json.loads(raw_emb), dtype=np.float32)
                elif isinstance(raw_emb, (list, tuple)):
                    emb = np.array(raw_emb, dtype=np.float32)
                else:
                    # Try direct conversion (works for numpy arrays, pgvector objects, etc.)
                    emb = np.array(raw_emb, dtype=np.float32)

                # Ensure it's 1D
                if emb.ndim != 1:
                    raise ValueError(f"Expected 1D embedding, got shape {emb.shape}")
                embedding_arrays.append(emb)

            if not embedding_arrays:
                existing_embeddings = np.array([])
            elif len(embedding_arrays) == 1:
                # Single embedding: reshape to (1, dim)
                existing_embeddings = embedding_arrays[0].reshape(1, -1)
            else:
                # Multiple embeddings: vstack
                existing_embeddings = np.vstack(embedding_arrays)

            # For each new unit, compute similarities with ALL existing units
            for unit_id, new_embedding in zip(unit_ids, embeddings):
                new_emb_array = np.array(new_embedding)

                # Compute cosine similarities (dot product for normalized vectors)
                similarities = np.dot(existing_embeddings, new_emb_array)

                # Find top-k above threshold
                # Get indices of similarities above threshold
                above_threshold = np.where(similarities >= threshold)[0]

                if len(above_threshold) > 0:
                    # Sort by similarity (descending) and take top-k
                    sorted_indices = above_threshold[np.argsort(-similarities[above_threshold])][:top_k]

                    for idx in sorted_indices:
                        similar_id = existing_ids[idx]
                        similarity = float(similarities[idx])
                        all_links.append((unit_id, similar_id, 'semantic', similarity, None))

        _log(log_buffer, f"  [8.2] Compute similarities & generate {len(all_links)} semantic links: {time_mod.time() - compute_start:.3f}s")

        if all_links:
            insert_start = time_mod.time()
            await conn.executemany(
                """
                INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
                """,
                all_links
            )
            _log(log_buffer, f"  [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s")

    except Exception as e:
        logger.error(f"Failed to create semantic links: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


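# The dot product above equals cosine similarity only for unit-length vectors, so the
# stored embeddings are assumed to be L2-normalized. A minimal numpy sketch of the
# top-k-above-threshold selection used in [8.2] (editor's example, not in the original file):
#
#   >>> import numpy as np
#   >>> existing = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]], dtype=np.float32)  # normalized rows
#   >>> query = np.array([1.0, 0.0], dtype=np.float32)
#   >>> sims = existing @ query
#   >>> keep = np.where(sims >= 0.5)[0]
#   >>> keep[np.argsort(-sims[keep])][:5]
#   array([0, 1])

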
async def insert_entity_links_batch(conn, links: List[tuple]):
    """
    Insert all entity links in a single batch.

    Args:
        conn: Database connection
        links: List of tuples (from_unit_id, to_unit_id, link_type, weight, entity_id)
    """
    if not links:
        return

    await conn.executemany(
        """
        INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
        """,
        links
    )


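# Re-inserting a tuple that already exists is a no-op (ON CONFLICT ... DO NOTHING), so a
# batch can be safely retried. Illustrative call (editor's sketch, made-up IDs):
#
#   await insert_entity_links_batch(conn, [
#       ("unit-1", "unit-2", "entity", 1.0, acme_entity_id),
#       ("unit-2", "unit-1", "entity", 1.0, acme_entity_id),
#   ])

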
async def create_causal_links_batch(
    conn,
    unit_ids: List[str],
    causal_relations_per_fact: List[List[dict]],
) -> int:
    """
    Create causal links between facts based on LLM-extracted causal relationships.

    Args:
        conn: Database connection
        unit_ids: List of unit IDs (in same order as causal_relations_per_fact)
        causal_relations_per_fact: List of causal relations for each fact.
            Each element is a list of dicts with:
            - target_fact_index: Index into unit_ids for the target fact
            - relation_type: "causes", "caused_by", "enables", or "prevents"
            - strength: Float in [0.0, 1.0] representing relationship strength

    Returns:
        Number of causal links created

    Causal link types:
    - "causes": This fact directly causes the target fact (forward causation)
    - "caused_by": This fact was caused by the target fact (backward causation)
    - "enables": This fact enables/allows the target fact (enablement)
    - "prevents": This fact prevents/blocks the target fact (prevention)
    """
    if not unit_ids or not causal_relations_per_fact:
        return 0

    try:
        import time as time_mod
        create_start = time_mod.time()

        # Build links list
        links = []
        for fact_idx, causal_relations in enumerate(causal_relations_per_fact):
            if not causal_relations:
                continue

            from_unit_id = unit_ids[fact_idx]

            for relation in causal_relations:
                target_idx = relation['target_fact_index']
                relation_type = relation['relation_type']
                strength = relation.get('strength', 1.0)

                # Validate relation_type - must match database constraint
                valid_types = {'causes', 'caused_by', 'enables', 'prevents'}
                if relation_type not in valid_types:
                    logger.error(
                        f"Invalid relation_type '{relation_type}' (type: {type(relation_type).__name__}) "
                        f"from fact {fact_idx}. Must be one of: {valid_types}. "
                        f"Relation data: {relation}"
                    )
                    continue

                # Validate target index
                if target_idx < 0 or target_idx >= len(unit_ids):
                    logger.warning(f"Invalid target_fact_index {target_idx} in causal relation from fact {fact_idx}")
                    continue

                to_unit_id = unit_ids[target_idx]

                # Don't create self-links
                if from_unit_id == to_unit_id:
                    continue

                # Add the causal link
                # link_type is the relation_type (e.g., "causes", "caused_by")
                # weight is the strength of the relationship
                links.append((from_unit_id, to_unit_id, relation_type, strength, None))

        if links:
            insert_start = time_mod.time()
            try:
                await conn.executemany(
                    """
                    INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
                    VALUES ($1, $2, $3, $4, $5)
                    ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
                    """,
                    links
                )
            except Exception as db_error:
                # Log the actual data being inserted for debugging
                logger.error(f"Database insert failed for causal links. Error: {db_error}")
                logger.error(f"Attempted to insert {len(links)} links. First few:")
                for i, link in enumerate(links[:3]):
                    logger.error(f"  Link {i}: from={link[0]}, to={link[1]}, type='{link[2]}' (repr={repr(link[2])}), weight={link[3]}, entity={link[4]}")
                raise

        return len(links)

    except Exception as e:
        logger.error(f"Failed to create causal links: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
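
# Illustrative input shape (editor's sketch, not in the original file): each inner list
# describes the causal relations of the fact at the same index in unit_ids, and
# target_fact_index points back into unit_ids.
#
#   n = await create_causal_links_batch(
#       conn,
#       unit_ids=["u1", "u2", "u3"],
#       causal_relations_per_fact=[
#           [{"target_fact_index": 1, "relation_type": "causes", "strength": 0.9}],
#           [],
#           [{"target_fact_index": 0, "relation_type": "caused_by", "strength": 0.6}],
#       ],
#   )
#   # n == 2: ("u1", "u2", "causes", 0.9, None) and ("u3", "u1", "caused_by", 0.6, None)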