hindsight-api 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
@@ -45,6 +45,7 @@ async def insert_facts_batch(
     metadata_jsons = []
     chunk_ids = []
     document_ids = []
+    tags_list = []

     for fact in facts:
         fact_texts.append(fact.fact_text)
@@ -65,16 +66,31 @@ async def insert_facts_batch(
         chunk_ids.append(fact.chunk_id)
         # Use per-fact document_id if available, otherwise fallback to batch-level document_id
         document_ids.append(fact.document_id if fact.document_id else document_id)
+        # Convert tags to JSON string for proper batch insertion (PostgreSQL unnest doesn't handle 2D arrays well)
+        tags_list.append(json.dumps(fact.tags if fact.tags else []))

     # Batch insert all facts
+    # Note: tags are passed as JSON strings and converted back to varchar[] via jsonb_array_elements_text + array_agg
     results = await conn.fetch(
         f"""
-        INSERT INTO {fq_table("memory_units")} (bank_id, text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
-                                                context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id)
-        SELECT $1, * FROM unnest(
-            $2::text[], $3::vector[], $4::timestamptz[], $5::timestamptz[], $6::timestamptz[], $7::timestamptz[],
-            $8::text[], $9::text[], $10::float[], $11::int[], $12::jsonb[], $13::text[], $14::text[]
+        WITH input_data AS (
+            SELECT * FROM unnest(
+                $2::text[], $3::vector[], $4::timestamptz[], $5::timestamptz[], $6::timestamptz[], $7::timestamptz[],
+                $8::text[], $9::text[], $10::float[], $11::int[], $12::jsonb[], $13::text[], $14::text[], $15::jsonb[]
+            ) AS t(text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                   context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id, tags_json)
         )
+        INSERT INTO {fq_table("memory_units")} (bank_id, text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                                                context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id, tags)
+        SELECT
+            $1,
+            text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+            context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id,
+            COALESCE(
+                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
+                '{{}}'::varchar[]
+            )
+        FROM input_data
         RETURNING id
        """,
        bank_id,
@@ -91,6 +107,7 @@ async def insert_facts_batch(
        metadata_jsons,
        chunk_ids,
        document_ids,
+       tags_list,
     )

     unit_ids = [str(row["id"]) for row in results]
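The new tags column is filled through a WITH … unnest CTE because a single multi-dimensional array parameter would be flattened by unnest; instead each row's tags travel as a jsonb value and are rebuilt into varchar[] via jsonb_array_elements_text + array_agg. A minimal standalone sketch of the same idea, assuming asyncpg and a hypothetical items table with columns name text and tags varchar[]:

import json

import asyncpg


async def insert_items(conn: asyncpg.Connection, names: list[str], tags_per_item: list[list[str]]) -> None:
    # Each row's tags are serialized to JSON so unnest() sees a flat jsonb[]
    # instead of a 2D text array; the SELECT rebuilds them into varchar[].
    await conn.execute(
        """
        WITH input_data AS (
            SELECT * FROM unnest($1::text[], $2::jsonb[]) AS t(name, tags_json)
        )
        INSERT INTO items (name, tags)
        SELECT
            name,
            COALESCE(
                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
                '{}'::varchar[]
            )
        FROM input_data
        """,
        names,
        [json.dumps(tags) for tags in tags_per_item],
    )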
@@ -121,7 +138,13 @@ async def ensure_bank_exists(conn, bank_id: str) -> None:


 async def handle_document_tracking(
-    conn, bank_id: str, document_id: str, combined_content: str, is_first_batch: bool, retain_params: dict | None = None
+    conn,
+    bank_id: str,
+    document_id: str,
+    combined_content: str,
+    is_first_batch: bool,
+    retain_params: dict | None = None,
+    document_tags: list[str] | None = None,
 ) -> None:
     """
     Handle document tracking in the database.
@@ -133,6 +156,7 @@ async def handle_document_tracking(
         combined_content: Combined content text from all content items
         is_first_batch: Whether this is the first batch (for chunked operations)
         retain_params: Optional parameters passed during retain (context, event_date, etc.)
+        document_tags: Optional list of tags to associate with the document
     """
     import hashlib

@@ -149,13 +173,14 @@ async def handle_document_tracking(
     # Insert document (or update if exists from concurrent operations)
     await conn.execute(
         f"""
-        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params)
-        VALUES ($1, $2, $3, $4, $5, $6)
+        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params, tags)
+        VALUES ($1, $2, $3, $4, $5, $6, $7)
         ON CONFLICT (id, bank_id) DO UPDATE
         SET original_text = EXCLUDED.original_text,
             content_hash = EXCLUDED.content_hash,
             metadata = EXCLUDED.metadata,
             retain_params = EXCLUDED.retain_params,
+            tags = EXCLUDED.tags,
             updated_at = NOW()
        """,
        document_id,
@@ -164,4 +189,5 @@ async def handle_document_tracking(
        content_hash,
        json.dumps({}),  # Empty metadata dict
        json.dumps(retain_params) if retain_params else None,
+       document_tags or [],
     )
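Because the upsert now sets tags = EXCLUDED.tags, re-retaining an existing document replaces its stored tags rather than merging them. A minimal sketch of the same ON CONFLICT pattern, assuming asyncpg and a hypothetical docs table with PRIMARY KEY (id, bank_id) and columns original_text text, tags varchar[], updated_at timestamptz:

import asyncpg


async def upsert_doc(conn: asyncpg.Connection, doc_id: str, bank_id: str, text: str, tags: list[str] | None) -> None:
    await conn.execute(
        """
        INSERT INTO docs (id, bank_id, original_text, tags)
        VALUES ($1, $2, $3, $4)
        ON CONFLICT (id, bank_id) DO UPDATE
        SET original_text = EXCLUDED.original_text,
            tags = EXCLUDED.tags,
            updated_at = NOW()
        """,
        doc_id,
        bank_id,
        text,
        tags or [],  # asyncpg encodes a Python list[str] as a PostgreSQL varchar[]/text[]
    )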
@@ -479,14 +479,18 @@ async def create_temporal_links_batch_per_fact(

     if links:
         insert_start = time_mod.time()
-        await conn.executemany(
-            f"""
-            INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
-            VALUES ($1, $2, $3, $4, $5)
-            ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
-            """,
-            links,
-        )
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(links), BATCH_SIZE):
+            batch = links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
         _log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")

     return len(links)
@@ -644,14 +648,18 @@ async def create_semantic_links_batch(

     if all_links:
         insert_start = time_mod.time()
-        await conn.executemany(
-            f"""
-            INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
-            VALUES ($1, $2, $3, $4, $5)
-            ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
-            """,
-            all_links,
-        )
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(all_links), BATCH_SIZE):
+            batch = all_links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
         _log(
             log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s"
         )
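Both link-creation hunks apply the same change: one large executemany call is split into fixed-size chunks so no single call has to cover the whole batch. A standalone sketch of the chunking pattern, assuming asyncpg and a hypothetical edges table:

import asyncpg

BATCH_SIZE = 1000  # keep each executemany call small enough to finish well before any statement timeout


async def insert_edges(conn: asyncpg.Connection, rows: list[tuple[str, str, float]]) -> None:
    # rows is a list of (from_id, to_id, weight) tuples
    for start in range(0, len(rows), BATCH_SIZE):
        batch = rows[start : start + BATCH_SIZE]
        await conn.executemany(
            """
            INSERT INTO edges (from_id, to_id, weight)
            VALUES ($1, $2, $3)
            ON CONFLICT DO NOTHING
            """,
            batch,
        )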
@@ -9,6 +9,7 @@ import time
 import uuid
 from datetime import UTC, datetime

+from ...config import get_config
 from ..db_utils import acquire_with_retry
 from . import bank_utils

@@ -18,6 +19,7 @@ def utcnow():
     return datetime.now(UTC)


+from ..response_models import TokenUsage
 from . import (
     chunk_storage,
     deduplication,
@@ -47,7 +49,8 @@ async def retain_batch(
     is_first_batch: bool = True,
     fact_type_override: str | None = None,
     confidence_score: float | None = None,
-) -> list[list[str]]:
+    document_tags: list[str] | None = None,
+) -> tuple[list[list[str]], TokenUsage]:
     """
     Process a batch of content through the retain pipeline.

@@ -65,9 +68,10 @@
         is_first_batch: Whether this is the first batch
         fact_type_override: Override fact type for all facts
         confidence_score: Confidence score for opinions
+        document_tags: Tags applied to all items in this batch

     Returns:
-        List of unit ID lists (one list per content item)
+        Tuple of (unit ID lists, token usage for fact extraction)
     """
     start_time = time.time()
     total_chars = sum(len(item.get("content", "")) for item in contents_dicts)
@@ -86,12 +90,16 @@
     # Convert dicts to RetainContent objects
     contents = []
     for item in contents_dicts:
+        # Merge item-level tags with document-level tags
+        item_tags = item.get("tags", []) or []
+        merged_tags = list(set(item_tags + (document_tags or [])))
         content = RetainContent(
             content=item["content"],
             context=item.get("context", ""),
             event_date=item.get("event_date") or utcnow(),
             metadata=item.get("metadata", {}),
             entities=item.get("entities", []),
+            tags=merged_tags,
         )
         contents.append(content)
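The merge is a plain set union, so duplicate tags collapse and the resulting order is not guaranteed. A worked example of the same expression, with illustrative tag values:

item_tags = ["team:search", "pii"]        # tags on one content item (example values)
document_tags = ["pii", "tenant:acme"]    # tags passed for the whole document (example values)

merged_tags = list(set(item_tags + (document_tags or [])))
# -> some ordering of ["pii", "team:search", "tenant:acme"]; "pii" appears only once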
 
@@ -99,7 +107,7 @@
     step_start = time.time()
     extract_opinions = fact_type_override == "opinion"

-    extracted_facts, chunks = await fact_extraction.extract_facts_from_contents(
+    extracted_facts, chunks, usage = await fact_extraction.extract_facts_from_contents(
         contents, llm_config, agent_name, extract_opinions
     )
     log_buffer.append(
@@ -129,7 +137,7 @@
             if first_item.get("metadata"):
                 retain_params["metadata"] = first_item["metadata"]
             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
             )
         else:
             # Check for per-item document_ids
@@ -157,14 +165,14 @@
                 if first_item.get("metadata"):
                     retain_params["metadata"] = first_item["metadata"]
                 await fact_storage.handle_document_tracking(
-                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params
+                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params, document_tags
                 )

         total_time = time.time() - start_time
         logger.info(
             f"RETAIN_BATCH COMPLETE: 0 facts extracted from {len(contents)} contents in {total_time:.3f}s (document tracked, no facts)"
         )
-        return [[] for _ in contents]
+        return [[] for _ in contents], usage

     # Apply fact_type_override if provided
     if fact_type_override:
@@ -223,7 +231,7 @@
                 retain_params["metadata"] = first_item["metadata"]

             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
             )
             document_ids_added.append(document_id)
             doc_id_mapping[None] = document_id  # For backwards compatibility
@@ -267,7 +275,13 @@
                     retain_params["metadata"] = first_item["metadata"]

                 await fact_storage.handle_document_tracking(
-                    conn, bank_id, actual_doc_id, combined_content, is_first_batch, retain_params
+                    conn,
+                    bank_id,
+                    actual_doc_id,
+                    combined_content,
+                    is_first_batch,
+                    retain_params,
+                    document_tags,
                 )
                 document_ids_added.append(actual_doc_id)

@@ -344,7 +358,7 @@
     non_duplicate_facts = deduplication.filter_duplicates(processed_facts, is_duplicate_flags)

     if not non_duplicate_facts:
-        return [[] for _ in contents], usage
+        return [[] for _ in contents], usage

     # Insert facts (document_id is now stored per-fact)
     step_start = time.time()
@@ -394,16 +408,26 @@
         causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
         log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")

-        # Regenerate observations INSIDE transaction for atomicity
-        await observation_regeneration.regenerate_observations_batch(
-            conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer
-        )
+        # Regenerate observations - sync (in transaction) or async (background task)
+        config = get_config()
+        if config.retain_observations_async:
+            # Queue for async processing after transaction commits
+            entity_ids_for_async = list(set(link.entity_id for link in entity_links)) if entity_links else []
+            log_buffer.append(
+                f"[11] Observations: queued {len(entity_ids_for_async)} entities for async processing"
+            )
+        else:
+            # Run synchronously inside transaction for atomicity
+            await observation_regeneration.regenerate_observations_batch(
+                conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer
+            )
+            entity_ids_for_async = []

     # Map results back to original content items
     result_unit_ids = _map_results_to_contents(contents, extracted_facts, is_duplicate_flags, unit_ids)

-    # Trigger background tasks AFTER transaction commits (opinion reinforcement only)
-    await _trigger_background_tasks(task_backend, bank_id, unit_ids, non_duplicate_facts)
+    # Trigger background tasks AFTER transaction commits
+    await _trigger_background_tasks(task_backend, bank_id, unit_ids, non_duplicate_facts, entity_ids_for_async)

     # Log final summary
     total_time = time.time() - start_time
@@ -415,7 +439,7 @@

     logger.info("\n" + "\n".join(log_buffer) + "\n")

-    return result_unit_ids
+    return result_unit_ids, usage


 def _map_results_to_contents(
@@ -453,8 +477,9 @@ async def _trigger_background_tasks(
     bank_id: str,
     unit_ids: list[str],
     facts: list[ProcessedFact],
+    entity_ids_for_observations: list[str] | None = None,
 ) -> None:
-    """Trigger opinion reinforcement as background task (after transaction commits)."""
+    """Trigger background tasks after transaction commits."""
     # Trigger opinion reinforcement if there are entities
     fact_entities = [[e.name for e in fact.entities] for fact in facts]
     if any(fact_entities):
@@ -467,3 +492,13 @@
                 "unit_entities": fact_entities,
             }
         )
+
+    # Trigger observation regeneration if async mode is enabled
+    if entity_ids_for_observations:
+        await task_backend.submit_task(
+            {
+                "type": "regenerate_observations",
+                "bank_id": bank_id,
+                "entity_ids": entity_ids_for_observations,
+            }
+        )
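Combined with the orchestrator change above, observation regeneration either runs inside the retain transaction or is handed to the task backend as a regenerate_observations payload after commit. A sketch of that dispatch pattern, with hypothetical stand-ins for the config flag, the task backend, and the regeneration coroutine:

from collections.abc import Awaitable, Callable
from typing import Any


async def regenerate_or_queue(
    observations_async: bool,                                  # stand-in for config.retain_observations_async
    regenerate_now: Callable[[list[str]], Awaitable[None]],    # stand-in for in-transaction regeneration
    submit_task: Callable[[dict[str, Any]], Awaitable[None]],  # stand-in for task_backend.submit_task
    bank_id: str,
    entity_ids: list[str],
) -> None:
    if observations_async:
        # Defer the expensive regeneration so the retain transaction can commit quickly.
        await submit_task({"type": "regenerate_observations", "bank_id": bank_id, "entity_ids": entity_ids})
    else:
        # Previous behavior: regenerate inside the transaction for atomicity.
        await regenerate_now(entity_ids)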
@@ -21,6 +21,7 @@ class RetainContentDict(TypedDict, total=False):
         metadata: Custom key-value metadata (optional)
         document_id: Document ID for this content item (optional)
         entities: User-provided entities to merge with extracted entities (optional)
+        tags: Visibility scope tags for this content item (optional)
     """

     content: str  # Required
@@ -29,6 +30,7 @@ class RetainContentDict(TypedDict, total=False):
     metadata: dict[str, str]
     document_id: str
     entities: list[dict[str, str]]  # [{"text": "...", "type": "..."}]
+    tags: list[str]  # Visibility scope tags


 def _now_utc() -> datetime:
@@ -49,6 +51,7 @@ class RetainContent:
     event_date: datetime = field(default_factory=_now_utc)
     metadata: dict[str, str] = field(default_factory=dict)
     entities: list[dict[str, str]] = field(default_factory=list)  # User-provided entities
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags


 @dataclass
@@ -113,6 +116,7 @@ class ExtractedFact:
     context: str = ""
     mentioned_at: datetime | None = None
     metadata: dict[str, str] = field(default_factory=dict)
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags


 @dataclass
@@ -158,6 +162,9 @@
     # Track which content this fact came from (for user entity merging)
     content_index: int = 0

+    # Visibility scope tags
+    tags: list[str] = field(default_factory=list)
+
     @property
     def is_duplicate(self) -> bool:
         """Check if this fact was marked as a duplicate."""
@@ -201,6 +208,7 @@
             causal_relations=extracted_fact.causal_relations,
             chunk_id=chunk_id,
             content_index=extracted_fact.content_index,
+            tags=extracted_fact.tags,
         )


@@ -232,6 +240,7 @@ class RetainBatch:
     document_id: str | None = None
     fact_type_override: str | None = None
     confidence_score: float | None = None
+    document_tags: list[str] = field(default_factory=list)  # Tags applied to all items

     # Extracted data (populated during processing)
     extracted_facts: list[ExtractedFact] = field(default_factory=list)
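With the updated RetainContentDict, a retain payload item can carry a tags list alongside its other optional fields. An illustrative item (all values below are examples; only content is required):

item = {
    "content": "Acme renewed their enterprise contract in March.",
    "context": "CRM sync",
    "metadata": {"source": "crm"},
    "entities": [{"text": "Acme", "type": "organization"}],
    "tags": ["tenant:acme", "sales"],  # visibility scope tags, merged with any document-level tags
}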
@@ -11,7 +11,8 @@ from abc import ABC, abstractmethod

 from ..db_utils import acquire_with_retry
 from ..memory_engine import fq_table
-from .types import RetrievalResult
+from .tags import TagsMatch, filter_results_by_tags
+from .types import MPFPTimings, RetrievalResult

 logger = logging.getLogger(__name__)

@@ -42,7 +43,10 @@ class GraphRetriever(ABC):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-    ) -> list[RetrievalResult]:
+        adjacency=None,  # TypedAdjacency, optional pre-loaded graph
+        tags: list[str] | None = None,  # Visibility scope tags for filtering
+        tags_match: TagsMatch = "any",  # How to match tags: 'any' (OR) or 'all' (AND)
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
         Retrieve relevant facts via graph traversal.

@@ -55,9 +59,11 @@
             query_text: Original query text (optional, for some strategies)
             semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
             temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+            adjacency: Pre-loaded typed adjacency graph (optional, for MPFP)
+            tags: Optional list of tags for visibility filtering (OR matching)

         Returns:
-            List of RetrievalResult objects with activation scores set
+            Tuple of (List of RetrievalResult with activation scores, optional timing info)
         """
         pass

@@ -111,7 +117,10 @@ class BFSGraphRetriever(GraphRetriever):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-    ) -> list[RetrievalResult]:
+        adjacency=None,  # Not used by BFS
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
         Retrieve facts using BFS spreading activation.

@@ -122,11 +131,14 @@
         4. Return visited nodes up to budget

         Note: BFS finds its own entry points via embedding search.
-        The semantic_seeds and temporal_seeds parameters are accepted
+        The semantic_seeds, temporal_seeds, and adjacency parameters are accepted
         for interface compatibility but not used.
         """
         async with acquire_with_retry(pool) as conn:
-            return await self._retrieve_with_conn(conn, query_embedding_str, bank_id, fact_type, budget)
+            results = await self._retrieve_with_conn(
+                conn, query_embedding_str, bank_id, fact_type, budget, tags=tags, tags_match=tags_match
+            )
+            return results, None

     async def _retrieve_with_conn(
         self,
@@ -135,33 +147,46 @@
         bank_id: str,
         fact_type: str,
         budget: int,
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> list[RetrievalResult]:
         """Internal implementation with connection."""
+        from .tags import build_tags_where_clause_simple
+
+        tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
+        params = [query_embedding_str, bank_id, fact_type, self.entry_point_threshold, self.entry_point_limit]
+        if tags:
+            params.append(tags)

         # Step 1: Find entry points
         entry_points = await conn.fetch(
             f"""
             SELECT id, text, context, event_date, occurred_start, occurred_end,
-                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
                    1 - (embedding <=> $1::vector) AS similarity
             FROM {fq_table("memory_units")}
             WHERE bank_id = $2
               AND embedding IS NOT NULL
               AND fact_type = $3
               AND (1 - (embedding <=> $1::vector)) >= $4
+              {tags_clause}
             ORDER BY embedding <=> $1::vector
             LIMIT $5
             """,
-            query_embedding_str,
-            bank_id,
-            fact_type,
-            self.entry_point_threshold,
-            self.entry_point_limit,
+            *params,
         )

         if not entry_points:
+            logger.debug(
+                f"[BFS] No entry points found for fact_type={fact_type} (tags={tags}, tags_match={tags_match})"
+            )
             return []

+        logger.debug(
+            f"[BFS] Found {len(entry_points)} entry points for fact_type={fact_type} "
+            f"(tags={tags}, tags_match={tags_match})"
+        )
+
         # Step 2: BFS spreading activation
         visited = set()
         results = []
@@ -192,7 +217,7 @@
             f"""
             SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
                    mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
-                   mu.document_id, mu.chunk_id,
+                   mu.document_id, mu.chunk_id, mu.tags,
                    ml.weight, ml.link_type, ml.from_unit_id
             FROM {fq_table("memory_links")} ml
             JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
@@ -232,4 +257,8 @@
                 neighbor_result = RetrievalResult.from_db_row(dict(n))
                 queue.append((neighbor_result, new_activation))

+        # Apply tags filtering (BFS may traverse into memories that don't match tags criteria)
+        if tags:
+            results = filter_results_by_tags(results, tags, match=tags_match)
+
         return results
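build_tags_where_clause_simple and filter_results_by_tags come from the new search/tags.py module, whose implementation is not part of this excerpt. A hypothetical sketch of the 'any'/'all' matching semantics used above (not the package's actual code):

from typing import Literal

TagsMatch = Literal["any", "all"]


def tags_satisfy(result_tags: list[str] | None, wanted: list[str], match: TagsMatch = "any") -> bool:
    have, want = set(result_tags or []), set(wanted)
    return bool(have & want) if match == "any" else want <= have


def filter_by_tags(results: list[dict], tags: list[str], match: TagsMatch = "any") -> list[dict]:
    # Keep only rows whose "tags" field satisfies the requested match mode.
    return [r for r in results if tags_satisfy(r.get("tags"), tags, match)]


# A plausible SQL-side counterpart uses PostgreSQL array operators on the tags varchar[] column:
#   match="any": AND tags && $N::varchar[]   (arrays overlap)
#   match="all": AND tags @> $N::varchar[]   (array contains all requested tags)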