hindsight-api 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. hindsight_api/__init__.py +38 -0
  2. hindsight_api/api/__init__.py +105 -0
  3. hindsight_api/api/http.py +1872 -0
  4. hindsight_api/api/mcp.py +157 -0
  5. hindsight_api/engine/__init__.py +47 -0
  6. hindsight_api/engine/cross_encoder.py +97 -0
  7. hindsight_api/engine/db_utils.py +93 -0
  8. hindsight_api/engine/embeddings.py +113 -0
  9. hindsight_api/engine/entity_resolver.py +575 -0
  10. hindsight_api/engine/llm_wrapper.py +269 -0
  11. hindsight_api/engine/memory_engine.py +3095 -0
  12. hindsight_api/engine/query_analyzer.py +519 -0
  13. hindsight_api/engine/response_models.py +222 -0
  14. hindsight_api/engine/retain/__init__.py +50 -0
  15. hindsight_api/engine/retain/bank_utils.py +423 -0
  16. hindsight_api/engine/retain/chunk_storage.py +82 -0
  17. hindsight_api/engine/retain/deduplication.py +104 -0
  18. hindsight_api/engine/retain/embedding_processing.py +62 -0
  19. hindsight_api/engine/retain/embedding_utils.py +54 -0
  20. hindsight_api/engine/retain/entity_processing.py +90 -0
  21. hindsight_api/engine/retain/fact_extraction.py +1027 -0
  22. hindsight_api/engine/retain/fact_storage.py +176 -0
  23. hindsight_api/engine/retain/link_creation.py +121 -0
  24. hindsight_api/engine/retain/link_utils.py +651 -0
  25. hindsight_api/engine/retain/orchestrator.py +405 -0
  26. hindsight_api/engine/retain/types.py +206 -0
  27. hindsight_api/engine/search/__init__.py +15 -0
  28. hindsight_api/engine/search/fusion.py +122 -0
  29. hindsight_api/engine/search/observation_utils.py +132 -0
  30. hindsight_api/engine/search/reranking.py +103 -0
  31. hindsight_api/engine/search/retrieval.py +503 -0
  32. hindsight_api/engine/search/scoring.py +161 -0
  33. hindsight_api/engine/search/temporal_extraction.py +64 -0
  34. hindsight_api/engine/search/think_utils.py +255 -0
  35. hindsight_api/engine/search/trace.py +215 -0
  36. hindsight_api/engine/search/tracer.py +447 -0
  37. hindsight_api/engine/search/types.py +160 -0
  38. hindsight_api/engine/task_backend.py +223 -0
  39. hindsight_api/engine/utils.py +203 -0
  40. hindsight_api/metrics.py +227 -0
  41. hindsight_api/migrations.py +163 -0
  42. hindsight_api/models.py +309 -0
  43. hindsight_api/pg0.py +425 -0
  44. hindsight_api/web/__init__.py +12 -0
  45. hindsight_api/web/server.py +143 -0
  46. hindsight_api-0.0.13.dist-info/METADATA +41 -0
  47. hindsight_api-0.0.13.dist-info/RECORD +48 -0
  48. hindsight_api-0.0.13.dist-info/WHEEL +4 -0
hindsight_api/engine/search/retrieval.py
@@ -0,0 +1,503 @@
+ """
+ Retrieval module for 4-way parallel search.
+
+ Implements:
+ 1. Semantic retrieval (vector similarity)
+ 2. BM25 retrieval (keyword/full-text search)
+ 3. Graph retrieval (spreading activation)
+ 4. Temporal retrieval (time-aware search with spreading)
+ """
+
+ from typing import List, Dict, Any, Tuple, Optional
+ from datetime import datetime
+ import asyncio
+ from ..db_utils import acquire_with_retry
+ from .types import RetrievalResult
+
+
+ async def retrieve_semantic(
+     conn,
+     query_emb_str: str,
+     bank_id: str,
+     fact_type: str,
+     limit: int
+ ) -> List[RetrievalResult]:
+     """
+     Semantic retrieval via vector similarity.
+
+     Args:
+         conn: Database connection
+         query_emb_str: Query embedding as string
+         bank_id: Bank ID
+         fact_type: Fact type to filter
+         limit: Maximum results to return
+
+     Returns:
+         List of RetrievalResult objects
+     """
+     results = await conn.fetch(
+         """
+         SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+             1 - (embedding <=> $1::vector) AS similarity
+         FROM memory_units
+         WHERE bank_id = $2
+         AND embedding IS NOT NULL
+         AND fact_type = $3
+         AND (1 - (embedding <=> $1::vector)) >= 0.3
+         ORDER BY embedding <=> $1::vector
+         LIMIT $4
+         """,
+         query_emb_str, bank_id, fact_type, limit
+     )
+     return [RetrievalResult.from_db_row(dict(r)) for r in results]
+
+
+ async def retrieve_bm25(
+     conn,
+     query_text: str,
+     bank_id: str,
+     fact_type: str,
+     limit: int
+ ) -> List[RetrievalResult]:
+     """
+     BM25 keyword retrieval via full-text search.
+
+     Args:
+         conn: Database connection
+         query_text: Query text
+         bank_id: Bank ID
+         fact_type: Fact type to filter
+         limit: Maximum results to return
+
+     Returns:
+         List of RetrievalResult objects
+     """
+     import re
+
+     # Sanitize query text: remove special characters that have meaning in tsquery
+     # Keep only alphanumeric characters and spaces
+     sanitized_text = re.sub(r'[^\w\s]', ' ', query_text.lower())
+
+     # Split and filter empty strings
+     tokens = [token for token in sanitized_text.split() if token]
+
+     if not tokens:
+         # If no valid tokens, return empty results
+         return []
+
+     # Convert query to tsquery using OR for more flexible matching
+     # This prevents empty results when some terms are missing
+     query_tsquery = " | ".join(tokens)
+
+     results = await conn.fetch(
+         """
+         SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+             ts_rank_cd(search_vector, to_tsquery('english', $1)) AS bm25_score
+         FROM memory_units
+         WHERE bank_id = $2
+         AND fact_type = $3
+         AND search_vector @@ to_tsquery('english', $1)
+         ORDER BY bm25_score DESC
+         LIMIT $4
+         """,
+         query_tsquery, bank_id, fact_type, limit
+     )
+     return [RetrievalResult.from_db_row(dict(r)) for r in results]
+
+
+ async def retrieve_graph(
+     conn,
+     query_emb_str: str,
+     bank_id: str,
+     fact_type: str,
+     budget: int
+ ) -> List[RetrievalResult]:
+     """
+     Graph retrieval via spreading activation.
+
+     Args:
+         conn: Database connection
+         query_emb_str: Query embedding as string
+         bank_id: Bank ID
+         fact_type: Fact type to filter
+         budget: Node budget for graph traversal
+
+     Returns:
+         List of RetrievalResult objects
+     """
+     # Find entry points
+     entry_points = await conn.fetch(
+         """
+         SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+             1 - (embedding <=> $1::vector) AS similarity
+         FROM memory_units
+         WHERE bank_id = $2
+         AND embedding IS NOT NULL
+         AND fact_type = $3
+         AND (1 - (embedding <=> $1::vector)) >= 0.5
+         ORDER BY embedding <=> $1::vector
+         LIMIT 5
+         """,
+         query_emb_str, bank_id, fact_type
+     )
+
+     if not entry_points:
+         return []
+
+     # BFS-style spreading activation with batched neighbor fetching
+     visited = set()
+     results = []
+     queue = [(RetrievalResult.from_db_row(dict(r)), r["similarity"]) for r in entry_points]
+     budget_remaining = budget
+
+     # Process nodes in batches to reduce DB roundtrips
+     batch_size = 20  # Fetch neighbors for up to 20 nodes at once
+
+     while queue and budget_remaining > 0:
+         # Collect a batch of nodes to process
+         batch_nodes = []
+         batch_activations = {}
+
+         while queue and len(batch_nodes) < batch_size and budget_remaining > 0:
+             current, activation = queue.pop(0)
+             unit_id = current.id
+
+             if unit_id not in visited:
+                 visited.add(unit_id)
+                 budget_remaining -= 1
+                 results.append(current)
+                 batch_nodes.append(current.id)
+                 batch_activations[unit_id] = activation
+
+         # Batch fetch neighbors for all nodes in this batch
+         # Fetch top weighted neighbors (batch_size * 10 = ~200 for good distribution)
+         if batch_nodes and budget_remaining > 0:
+             max_neighbors = len(batch_nodes) * 10
+             neighbors = await conn.fetch(
+                 """
+                 SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end, mu.mentioned_at,
+                     mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
+                     ml.weight, ml.link_type, ml.from_unit_id
+                 FROM memory_links ml
+                 JOIN memory_units mu ON ml.to_unit_id = mu.id
+                 WHERE ml.from_unit_id = ANY($1::uuid[])
+                 AND ml.weight >= 0.1
+                 AND mu.fact_type = $2
+                 ORDER BY ml.weight DESC
+                 LIMIT $3
+                 """,
+                 batch_nodes, fact_type, max_neighbors
+             )
+
+             for n in neighbors:
+                 neighbor_id = str(n["id"])
+                 if neighbor_id not in visited:
+                     # Get parent activation
+                     parent_id = str(n["from_unit_id"])
+                     activation = batch_activations.get(parent_id, 0.5)
+
+                     # Boost activation for causal links (they're high-value relationships)
+                     link_type = n["link_type"]
+                     base_weight = n["weight"]
+
+                     # Causal links get a 1.5-2.0x boost depending on type
+                     if link_type in ("causes", "caused_by"):
+                         # Direct causation - very strong relationship
+                         causal_boost = 2.0
+                     elif link_type in ("enables", "prevents"):
+                         # Conditional causation - strong but not as direct
+                         causal_boost = 1.5
+                     else:
+                         # Temporal, semantic, entity links - standard weight
+                         causal_boost = 1.0
+
+                     effective_weight = base_weight * causal_boost
+                     new_activation = activation * effective_weight * 0.8
+                     if new_activation > 0.1:
+                         neighbor_result = RetrievalResult.from_db_row(dict(n))
+                         queue.append((neighbor_result, new_activation))
+
+     return results
+
+
+ async def retrieve_temporal(
+     conn,
+     query_emb_str: str,
+     bank_id: str,
+     fact_type: str,
+     start_date: datetime,
+     end_date: datetime,
+     budget: int,
+     semantic_threshold: float = 0.1
+ ) -> List[RetrievalResult]:
+     """
+     Temporal retrieval with spreading activation.
+
+     Strategy:
+     1. Find entry points (facts in date range with semantic relevance)
+     2. Spread through temporal links to related facts
+     3. Score by temporal proximity + semantic similarity + link weight
+
+     Args:
+         conn: Database connection
+         query_emb_str: Query embedding as string
+         bank_id: Bank ID
+         fact_type: Fact type to filter
+         start_date: Start of time range
+         end_date: End of time range
+         budget: Node budget for spreading
+         semantic_threshold: Minimum semantic similarity to include
+
+     Returns:
+         List of RetrievalResult objects with temporal scores
+     """
+     from datetime import timezone
+
+     # Ensure start_date and end_date are timezone-aware (UTC) to match database datetimes
+     if start_date.tzinfo is None:
+         start_date = start_date.replace(tzinfo=timezone.utc)
+     if end_date.tzinfo is None:
+         end_date = end_date.replace(tzinfo=timezone.utc)
+
+     entry_points = await conn.fetch(
+         """
+         SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+             1 - (embedding <=> $1::vector) AS similarity
+         FROM memory_units
+         WHERE bank_id = $2
+         AND fact_type = $3
+         AND embedding IS NOT NULL
+         AND (
+             -- Match if occurred range overlaps with query range
+             (occurred_start IS NOT NULL AND occurred_end IS NOT NULL
+              AND occurred_start <= $5 AND occurred_end >= $4)
+             OR
+             -- Match if mentioned_at falls within query range
+             (mentioned_at IS NOT NULL AND mentioned_at BETWEEN $4 AND $5)
+             OR
+             -- Match if any occurred date is set and overlaps (even if only start or end is set)
+             (occurred_start IS NOT NULL AND occurred_start BETWEEN $4 AND $5)
+             OR
+             (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
+         )
+         AND (1 - (embedding <=> $1::vector)) >= $6
+         ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC, (embedding <=> $1::vector) ASC
+         LIMIT 10
+         """,
+         query_emb_str, bank_id, fact_type, start_date, end_date, semantic_threshold
+     )
+
+     if not entry_points:
+         return []
+
+     # Calculate temporal scores for entry points
+     total_days = (end_date - start_date).total_seconds() / 86400
+     mid_date = start_date + (end_date - start_date) / 2  # Calculate once for all comparisons
+     results = []
+     visited = set()
+
+     for ep in entry_points:
+         unit_id = str(ep["id"])
+         visited.add(unit_id)
+
+         # Calculate temporal proximity using the most relevant date
+         # Priority: occurred_start/end (event time) > mentioned_at (mention time)
+         best_date = None
+         if ep["occurred_start"] is not None and ep["occurred_end"] is not None:
+             # Use midpoint of occurred range
+             best_date = ep["occurred_start"] + (ep["occurred_end"] - ep["occurred_start"]) / 2
+         elif ep["occurred_start"] is not None:
+             best_date = ep["occurred_start"]
+         elif ep["occurred_end"] is not None:
+             best_date = ep["occurred_end"]
+         elif ep["mentioned_at"] is not None:
+             best_date = ep["mentioned_at"]
+
+         # Temporal proximity score (closer to range center = higher score)
+         if best_date:
+             days_from_mid = abs((best_date - mid_date).total_seconds() / 86400)
+             temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+         else:
+             temporal_proximity = 0.5  # Fallback if no dates (shouldn't happen due to WHERE clause)
+
+         # Create RetrievalResult with temporal scores
+         ep_result = RetrievalResult.from_db_row(dict(ep))
+         ep_result.temporal_score = temporal_proximity
+         ep_result.temporal_proximity = temporal_proximity
+         results.append(ep_result)
+
+     # Spread through temporal links
+     queue = [(RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points]  # (unit, semantic_sim, temporal_score)
+     budget_remaining = budget - len(entry_points)
+
+     while queue and budget_remaining > 0:
+         current, semantic_sim, temporal_score = queue.pop(0)
+         current_id = current.id
+
+         # Get neighbors via temporal and causal links
+         if budget_remaining > 0:
+             neighbors = await conn.fetch(
+                 """
+                 SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
+                     ml.weight, ml.link_type,
+                     1 - (mu.embedding <=> $1::vector) AS similarity
+                 FROM memory_links ml
+                 JOIN memory_units mu ON ml.to_unit_id = mu.id
+                 WHERE ml.from_unit_id = $2
+                 AND ml.link_type IN ('temporal', 'causes', 'caused_by', 'enables', 'prevents')
+                 AND ml.weight >= 0.1
+                 AND mu.fact_type = $3
+                 AND mu.embedding IS NOT NULL
+                 AND (1 - (mu.embedding <=> $1::vector)) >= $4
+                 ORDER BY ml.weight DESC
+                 LIMIT 10
+                 """,
+                 query_emb_str, current.id, fact_type, semantic_threshold
+             )
+
+             for n in neighbors:
+                 neighbor_id = str(n["id"])
+                 if neighbor_id in visited:
+                     continue
+
+                 visited.add(neighbor_id)
+                 budget_remaining -= 1
+
+                 # Calculate temporal score for neighbor using best available date
+                 neighbor_best_date = None
+                 if n["occurred_start"] is not None and n["occurred_end"] is not None:
+                     neighbor_best_date = n["occurred_start"] + (n["occurred_end"] - n["occurred_start"]) / 2
+                 elif n["occurred_start"] is not None:
+                     neighbor_best_date = n["occurred_start"]
+                 elif n["occurred_end"] is not None:
+                     neighbor_best_date = n["occurred_end"]
+                 elif n["mentioned_at"] is not None:
+                     neighbor_best_date = n["mentioned_at"]
+
+                 if neighbor_best_date:
+                     days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
+                     neighbor_temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+                 else:
+                     neighbor_temporal_proximity = 0.3  # Lower score if no temporal data
+
+                 # Boost causal links (same as graph retrieval)
+                 link_type = n["link_type"]
+                 if link_type in ("causes", "caused_by"):
+                     causal_boost = 2.0
+                 elif link_type in ("enables", "prevents"):
+                     causal_boost = 1.5
+                 else:
+                     causal_boost = 1.0
+
+                 # Propagate temporal score through links (decay, with causal boost)
+                 propagated_temporal = temporal_score * n["weight"] * causal_boost * 0.7
+
+                 # Combined temporal score
+                 combined_temporal = max(neighbor_temporal_proximity, propagated_temporal)
+
+                 # Create RetrievalResult with temporal scores
+                 neighbor_result = RetrievalResult.from_db_row(dict(n))
+                 neighbor_result.temporal_score = combined_temporal
+                 neighbor_result.temporal_proximity = neighbor_temporal_proximity
+                 results.append(neighbor_result)
+
+                 # Add to queue for further spreading
+                 if budget_remaining > 0 and combined_temporal > 0.2:
+                     queue.append((neighbor_result, n["similarity"], combined_temporal))
+
+                 if budget_remaining <= 0:
+                     break
+
+     return results
+
+
+ async def retrieve_parallel(
+     pool,
+     query_text: str,
+     query_embedding_str: str,
+     bank_id: str,
+     fact_type: str,
+     thinking_budget: int,
+     question_date: Optional[datetime] = None,
+     query_analyzer: Optional["QueryAnalyzer"] = None
+ ) -> Tuple[List[RetrievalResult], List[RetrievalResult], List[RetrievalResult], Optional[List[RetrievalResult]], Dict[str, float], Optional[Tuple[datetime, datetime]]]:
+     """
+     Run 3-way or 4-way parallel retrieval (adds temporal if a constraint is detected).
+
+     Args:
+         pool: Database connection pool
+         query_text: Query text
+         query_embedding_str: Query embedding as string
+         bank_id: Bank ID
+         fact_type: Fact type to filter
+         thinking_budget: Budget for graph traversal and retrieval limits
+         question_date: Optional date when the question was asked (for temporal filtering)
+         query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
+
+     Returns:
+         Tuple of (semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint).
+         Each results list contains RetrievalResult objects.
+         temporal_results is None if no temporal constraint was detected.
+         timings is a dict with per-method latencies in seconds.
+         temporal_constraint is the (start_date, end_date) tuple if detected, else None.
+     """
+     # Detect temporal constraint
+     from .temporal_extraction import extract_temporal_constraint
+     import time
+
+     temporal_constraint = extract_temporal_constraint(
+         query_text, reference_date=question_date, analyzer=query_analyzer
+     )
+
+     # Wrapper to track timing for each retrieval method
+     async def timed_retrieval(name: str, coro):
+         start = time.time()
+         result = await coro
+         duration = time.time() - start
+         return result, name, duration
+
+     async def run_semantic():
+         async with acquire_with_retry(pool) as conn:
+             return await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+
+     async def run_bm25():
+         async with acquire_with_retry(pool) as conn:
+             return await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+
+     async def run_graph():
+         async with acquire_with_retry(pool) as conn:
+             return await retrieve_graph(conn, query_embedding_str, bank_id, fact_type, budget=thinking_budget)
+
+     async def run_temporal(start_date, end_date):
+         async with acquire_with_retry(pool) as conn:
+             return await retrieve_temporal(
+                 conn, query_embedding_str, bank_id, fact_type,
+                 start_date, end_date, budget=thinking_budget, semantic_threshold=0.1
+             )
+
+     # Run retrievals in parallel with timing
+     timings = {}
+     if temporal_constraint:
+         start_date, end_date = temporal_constraint
+         results = await asyncio.gather(
+             timed_retrieval("semantic", run_semantic()),
+             timed_retrieval("bm25", run_bm25()),
+             timed_retrieval("graph", run_graph()),
+             timed_retrieval("temporal", run_temporal(start_date, end_date))
+         )
+         semantic_results, _, timings["semantic"] = results[0]
+         bm25_results, _, timings["bm25"] = results[1]
+         graph_results, _, timings["graph"] = results[2]
+         temporal_results, _, timings["temporal"] = results[3]
+     else:
+         results = await asyncio.gather(
+             timed_retrieval("semantic", run_semantic()),
+             timed_retrieval("bm25", run_bm25()),
+             timed_retrieval("graph", run_graph())
+         )
+         semantic_results, _, timings["semantic"] = results[0]
+         bm25_results, _, timings["bm25"] = results[1]
+         graph_results, _, timings["graph"] = results[2]
+         temporal_results = None
+
+     return semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint
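
For orientation, here is a minimal sketch of how a caller might drive retrieve_parallel end to end. It is not part of the wheel: the asyncpg pool setup, the DSN, the pgvector-style embedding literal, and the bank/fact-type values are illustrative assumptions, and in the package itself the fused ranking of the four result lists is handled by the neighbouring fusion and reranking modules.

# Hypothetical driver for retrieve_parallel (not part of hindsight-api itself).
# Assumes an asyncpg pool and a pgvector-style text literal for the embedding.
import asyncio
from datetime import datetime, timezone

import asyncpg

from hindsight_api.engine.search.retrieval import retrieve_parallel


async def main():
    pool = await asyncpg.create_pool(dsn="postgresql://localhost/hindsight")  # assumed DSN

    # pgvector accepts a bracketed list serialized as text; a real caller would obtain
    # this from the package's embeddings module. 768 dimensions is an assumption here.
    query_embedding_str = "[" + ", ".join("0.01" for _ in range(768)) + "]"

    semantic, bm25, graph, temporal, timings, constraint = await retrieve_parallel(
        pool,
        query_text="what happened with the billing migration last March?",
        query_embedding_str=query_embedding_str,
        bank_id="example-bank",   # assumed bank identifier
        fact_type="event",        # assumed fact type
        thinking_budget=50,
        question_date=datetime.now(timezone.utc),
    )
    print(timings, constraint)
    print(len(semantic), len(bm25), len(graph), len(temporal or []))

    await pool.close()


if __name__ == "__main__":
    asyncio.run(main())
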
hindsight_api/engine/search/scoring.py
@@ -0,0 +1,161 @@
+ """
+ Scoring functions for memory search and retrieval.
+
+ Includes recency weighting, frequency weighting, temporal proximity,
+ and similarity calculations used in memory activation and ranking.
+ """
+ from datetime import datetime
+ from typing import List
+
+
+ def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
+     """
+     Calculate cosine similarity between two vectors.
+
+     Args:
+         vec1: First vector
+         vec2: Second vector
+
+     Returns:
+         Similarity score between -1 and 1 (0 to 1 for non-negative vectors)
+     """
+     if len(vec1) != len(vec2):
+         raise ValueError("Vectors must have same dimension")
+
+     dot_product = sum(a * b for a, b in zip(vec1, vec2))
+     magnitude1 = sum(a * a for a in vec1) ** 0.5
+     magnitude2 = sum(b * b for b in vec2) ** 0.5
+
+     if magnitude1 == 0 or magnitude2 == 0:
+         return 0.0
+
+     return dot_product / (magnitude1 * magnitude2)
+
+
+ def calculate_recency_weight(days_since: float, half_life_days: float = 365.0) -> float:
+     """
+     Calculate recency weight using logarithmic decay.
+
+     This provides much better differentiation over long time periods compared to
+     exponential decay. Uses a log-based decay where half_life_days sets the time
+     scale over which the weight falls off.
+
+     Examples (with default half_life=365):
+     - Today (0 days): 1.0
+     - 1 year (365 days): ~0.59
+     - 2 years (730 days): ~0.48
+     - 5 years (1825 days): ~0.36
+     - 10 years (3650 days): ~0.29
+
+     This ensures that 2-year-old and 5-year-old memories have meaningfully
+     different weights, unlike exponential decay which makes them both ~0.
+
+     Args:
+         days_since: Number of days since the memory was created
+         half_life_days: Decay time scale in days (default: 1 year)
+
+     Returns:
+         Weight between 0 and 1
+     """
+     import math
+     # Logarithmic decay: 1 / (1 + log(1 + days_since/half_life))
+     # This decays much slower than exponential, giving better long-term differentiation
+     normalized_age = days_since / half_life_days
+     return 1.0 / (1.0 + math.log1p(normalized_age))
+
+
+ def calculate_frequency_weight(access_count: int, max_boost: float = 2.0) -> float:
+     """
+     Calculate frequency weight based on access count.
+
+     Frequently accessed memories are weighted higher.
+     Uses logarithmic scaling to avoid over-weighting.
+
+     Args:
+         access_count: Number of times the memory was accessed
+         max_boost: Maximum multiplier for frequently accessed memories
+
+     Returns:
+         Weight between 1.0 and max_boost
+     """
+     import math
+     if access_count <= 0:
+         return 1.0
+
+     # Logarithmic scaling: log(access_count + 1) / log(10)
+     # With the default max_boost=2.0: 0 accesses = 1.0, 2 accesses ~= 1.5, 9+ accesses = 2.0 (capped)
+     normalized = math.log(access_count + 1) / math.log(10)
+     return 1.0 + min(normalized, max_boost - 1.0)
+
+
+ def calculate_temporal_anchor(occurred_start: datetime, occurred_end: datetime) -> datetime:
+     """
+     Calculate a single temporal anchor point from a temporal range.
+
+     Used for spreading activation - we need a single representative date
+     to calculate temporal proximity between facts. This simplifies the
+     range-to-range distance problem.
+
+     Strategy: Use the midpoint of the range for balanced representation.
+
+     Args:
+         occurred_start: Start of temporal range
+         occurred_end: End of temporal range
+
+     Returns:
+         Single datetime representing the temporal anchor (midpoint)
+
+     Examples:
+     - Point event (July 14): start=July 14, end=July 14 → anchor=July 14
+     - Month range (February): start=Feb 1, end=Feb 28 → anchor=Feb 14
+     - Year range (2023): start=Jan 1, end=Dec 31 → anchor=July 1
+     """
+     # Calculate midpoint
+     time_delta = occurred_end - occurred_start
+     midpoint = occurred_start + (time_delta / 2)
+     return midpoint
+
+
+ def calculate_temporal_proximity(
+     anchor_a: datetime,
+     anchor_b: datetime,
+     half_life_days: float = 30.0
+ ) -> float:
+     """
+     Calculate temporal proximity between two temporal anchors.
+
+     Used for spreading activation to determine how "close" two facts are
+     in time. Uses logarithmic decay so that temporal similarity doesn't
+     drop off too quickly.
+
+     Args:
+         anchor_a: Temporal anchor of first fact
+         anchor_b: Temporal anchor of second fact
+         half_life_days: Decay time scale in days
+             (default: 30 days = 1 month)
+
+     Returns:
+         Proximity score in [0, 1] where:
+         - 1.0 = same day
+         - ~0.6 = about half_life days apart
+         - 0.0 = very distant in time
+
+     Examples:
+     - Same day: 1.0
+     - 1 week apart (half_life=30): ~0.8
+     - 1 month apart (half_life=30): ~0.6
+     - 1 year apart (half_life=30): ~0.3
+     """
+     import math
+
+     days_apart = abs((anchor_a - anchor_b).days)
+
+     if days_apart == 0:
+         return 1.0
+
+     # Logarithmic decay: 1 / (1 + log(1 + days_apart/half_life))
+     # Similar to calculate_recency_weight but for proximity between events
+     normalized_distance = days_apart / half_life_days
+     proximity = 1.0 / (1.0 + math.log1p(normalized_distance))
+
+     return proximity
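
As a quick sanity check on the decay curves documented above, the following sketch evaluates the scoring helpers for a few representative inputs (import path per the file listing; outputs rounded to two decimals):

# Illustrative check of the scoring helpers; values follow from the formulas above.
from datetime import datetime, timezone

from hindsight_api.engine.search.scoring import (
    calculate_frequency_weight,
    calculate_recency_weight,
    calculate_temporal_anchor,
    calculate_temporal_proximity,
)

# Logarithmic recency decay: 1 / (1 + ln(1 + days / half_life))
print([round(calculate_recency_weight(d), 2) for d in (0, 365, 730, 1825, 3650)])
# -> [1.0, 0.59, 0.48, 0.36, 0.29]

# Frequency boost saturates at max_boost (2.0 by default) from about 9 accesses onward
print([round(calculate_frequency_weight(n), 2) for n in (0, 2, 9, 99)])
# -> [1.0, 1.48, 2.0, 2.0]

# Anchor a month-long event at its midpoint, then compare it to a point event about a month later
feb = calculate_temporal_anchor(
    datetime(2023, 2, 1, tzinfo=timezone.utc),
    datetime(2023, 2, 28, tzinfo=timezone.utc),
)  # -> 2023-02-14 12:00 UTC
print(round(calculate_temporal_proximity(feb, datetime(2023, 3, 15, tzinfo=timezone.utc)), 2))
# -> 0.6 with the default 30-day half-life
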