hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +311 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
  6. hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
  7. hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
  8. hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
  9. hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
  10. hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
  11. hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
  12. hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
  13. hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
  14. hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
  15. hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
  16. hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
  17. hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
  18. hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
  19. hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
  20. hindsight_api/api/http.py +1406 -118
  21. hindsight_api/api/mcp.py +11 -196
  22. hindsight_api/config.py +359 -27
  23. hindsight_api/engine/consolidation/__init__.py +5 -0
  24. hindsight_api/engine/consolidation/consolidator.py +859 -0
  25. hindsight_api/engine/consolidation/prompts.py +69 -0
  26. hindsight_api/engine/cross_encoder.py +706 -88
  27. hindsight_api/engine/db_budget.py +284 -0
  28. hindsight_api/engine/db_utils.py +11 -0
  29. hindsight_api/engine/directives/__init__.py +5 -0
  30. hindsight_api/engine/directives/models.py +37 -0
  31. hindsight_api/engine/embeddings.py +553 -29
  32. hindsight_api/engine/entity_resolver.py +8 -5
  33. hindsight_api/engine/interface.py +40 -17
  34. hindsight_api/engine/llm_wrapper.py +744 -68
  35. hindsight_api/engine/memory_engine.py +2505 -1017
  36. hindsight_api/engine/mental_models/__init__.py +14 -0
  37. hindsight_api/engine/mental_models/models.py +53 -0
  38. hindsight_api/engine/query_analyzer.py +4 -3
  39. hindsight_api/engine/reflect/__init__.py +18 -0
  40. hindsight_api/engine/reflect/agent.py +933 -0
  41. hindsight_api/engine/reflect/models.py +109 -0
  42. hindsight_api/engine/reflect/observations.py +186 -0
  43. hindsight_api/engine/reflect/prompts.py +483 -0
  44. hindsight_api/engine/reflect/tools.py +437 -0
  45. hindsight_api/engine/reflect/tools_schema.py +250 -0
  46. hindsight_api/engine/response_models.py +168 -4
  47. hindsight_api/engine/retain/bank_utils.py +79 -201
  48. hindsight_api/engine/retain/fact_extraction.py +424 -195
  49. hindsight_api/engine/retain/fact_storage.py +35 -12
  50. hindsight_api/engine/retain/link_utils.py +29 -24
  51. hindsight_api/engine/retain/orchestrator.py +24 -43
  52. hindsight_api/engine/retain/types.py +11 -2
  53. hindsight_api/engine/search/graph_retrieval.py +43 -14
  54. hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
  55. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  56. hindsight_api/engine/search/reranking.py +2 -2
  57. hindsight_api/engine/search/retrieval.py +848 -201
  58. hindsight_api/engine/search/tags.py +172 -0
  59. hindsight_api/engine/search/think_utils.py +42 -141
  60. hindsight_api/engine/search/trace.py +12 -1
  61. hindsight_api/engine/search/tracer.py +26 -6
  62. hindsight_api/engine/search/types.py +21 -3
  63. hindsight_api/engine/task_backend.py +113 -106
  64. hindsight_api/engine/utils.py +1 -152
  65. hindsight_api/extensions/__init__.py +10 -1
  66. hindsight_api/extensions/builtin/tenant.py +5 -1
  67. hindsight_api/extensions/context.py +10 -1
  68. hindsight_api/extensions/operation_validator.py +81 -4
  69. hindsight_api/extensions/tenant.py +26 -0
  70. hindsight_api/main.py +69 -6
  71. hindsight_api/mcp_local.py +12 -53
  72. hindsight_api/mcp_tools.py +494 -0
  73. hindsight_api/metrics.py +433 -48
  74. hindsight_api/migrations.py +141 -1
  75. hindsight_api/models.py +3 -3
  76. hindsight_api/pg0.py +53 -0
  77. hindsight_api/server.py +39 -2
  78. hindsight_api/worker/__init__.py +11 -0
  79. hindsight_api/worker/main.py +296 -0
  80. hindsight_api/worker/poller.py +486 -0
  81. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
  82. hindsight_api-0.4.0.dist-info/RECORD +112 -0
  83. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
  84. hindsight_api/engine/retain/observation_regeneration.py +0 -254
  85. hindsight_api/engine/search/observation_utils.py +0 -125
  86. hindsight_api/engine/search/scoring.py +0 -159
  87. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  88. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
@@ -41,10 +41,10 @@ async def insert_facts_batch(
     contexts = []
     fact_types = []
     confidence_scores = []
-    access_counts = []
     metadata_jsons = []
     chunk_ids = []
     document_ids = []
+    tags_list = []
 
     for fact in facts:
         fact_texts.append(fact.fact_text)
@@ -60,21 +60,35 @@ async def insert_facts_batch(
         fact_types.append(fact.fact_type)
         # confidence_score is only for opinion facts
         confidence_scores.append(1.0 if fact.fact_type == "opinion" else None)
-        access_counts.append(0)  # Initial access count
         metadata_jsons.append(json.dumps(fact.metadata))
         chunk_ids.append(fact.chunk_id)
         # Use per-fact document_id if available, otherwise fallback to batch-level document_id
         document_ids.append(fact.document_id if fact.document_id else document_id)
+        # Convert tags to JSON string for proper batch insertion (PostgreSQL unnest doesn't handle 2D arrays well)
+        tags_list.append(json.dumps(fact.tags if fact.tags else []))
 
     # Batch insert all facts
+    # Note: tags are passed as JSON strings and converted back to varchar[] via jsonb_array_elements_text + array_agg
     results = await conn.fetch(
         f"""
-        INSERT INTO {fq_table("memory_units")} (bank_id, text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
-                                                context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id)
-        SELECT $1, * FROM unnest(
-            $2::text[], $3::vector[], $4::timestamptz[], $5::timestamptz[], $6::timestamptz[], $7::timestamptz[],
-            $8::text[], $9::text[], $10::float[], $11::int[], $12::jsonb[], $13::text[], $14::text[]
+        WITH input_data AS (
+            SELECT * FROM unnest(
+                $2::text[], $3::vector[], $4::timestamptz[], $5::timestamptz[], $6::timestamptz[], $7::timestamptz[],
+                $8::text[], $9::text[], $10::float[], $11::jsonb[], $12::text[], $13::text[], $14::jsonb[]
+            ) AS t(text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                   context, fact_type, confidence_score, metadata, chunk_id, document_id, tags_json)
         )
+        INSERT INTO {fq_table("memory_units")} (bank_id, text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                                                context, fact_type, confidence_score, metadata, chunk_id, document_id, tags)
+        SELECT
+            $1,
+            text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+            context, fact_type, confidence_score, metadata, chunk_id, document_id,
+            COALESCE(
+                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
+                '{{}}'::varchar[]
+            )
+        FROM input_data
         RETURNING id
         """,
         bank_id,
@@ -87,10 +101,10 @@ async def insert_facts_batch(
         contexts,
         fact_types,
         confidence_scores,
-        access_counts,
         metadata_jsons,
         chunk_ids,
         document_ids,
+        tags_list,
     )
 
     unit_ids = [str(row["id"]) for row in results]
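The bulk-insert change above relies on a workaround: PostgreSQL's unnest() over parallel arrays cannot easily carry one varchar[] value per row (that would require a 2D array), so each fact's tags are serialized to JSON on the Python side and expanded back into an array inside the query. Below is a minimal standalone sketch of the same technique, with a hypothetical demo_items table and asyncpg assumed; it is not the package's code, just an illustration of the pattern.

import asyncio
import json

import asyncpg


# Hypothetical demo: insert rows whose "tags" column is varchar[],
# passing per-row tag lists as JSON strings so unnest() stays one-dimensional.
async def insert_items_with_tags(
    conn: asyncpg.Connection, names: list[str], tags_per_item: list[list[str]]
) -> None:
    tags_json = [json.dumps(tags or []) for tags in tags_per_item]
    await conn.execute(
        """
        WITH input_data AS (
            SELECT * FROM unnest($1::text[], $2::jsonb[]) AS t(name, tags_json)
        )
        INSERT INTO demo_items (name, tags)
        SELECT
            name,
            COALESCE(
                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
                '{}'::varchar[]
            )
        FROM input_data
        """,
        names,
        tags_json,
    )


async def main() -> None:
    # Assumes a local "demo" database is reachable; adjust the DSN as needed.
    conn = await asyncpg.connect("postgresql://localhost/demo")
    await conn.execute("CREATE TABLE IF NOT EXISTS demo_items (name text, tags varchar[])")
    await insert_items_with_tags(conn, ["a", "b"], [["red", "urgent"], []])
    await conn.close()


asyncio.run(main())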
@@ -109,7 +123,7 @@ async def ensure_bank_exists(conn, bank_id: str) -> None:
     """
     await conn.execute(
         f"""
-        INSERT INTO {fq_table("banks")} (bank_id, disposition, background)
+        INSERT INTO {fq_table("banks")} (bank_id, disposition, mission)
         VALUES ($1, $2::jsonb, $3)
         ON CONFLICT (bank_id) DO UPDATE
         SET updated_at = NOW()
@@ -121,7 +135,13 @@ async def ensure_bank_exists(conn, bank_id: str) -> None:
 
 
 async def handle_document_tracking(
-    conn, bank_id: str, document_id: str, combined_content: str, is_first_batch: bool, retain_params: dict | None = None
+    conn,
+    bank_id: str,
+    document_id: str,
+    combined_content: str,
+    is_first_batch: bool,
+    retain_params: dict | None = None,
+    document_tags: list[str] | None = None,
 ) -> None:
     """
     Handle document tracking in the database.
@@ -133,6 +153,7 @@ async def handle_document_tracking(
         combined_content: Combined content text from all content items
         is_first_batch: Whether this is the first batch (for chunked operations)
         retain_params: Optional parameters passed during retain (context, event_date, etc.)
+        document_tags: Optional list of tags to associate with the document
     """
     import hashlib
 
@@ -149,13 +170,14 @@ async def handle_document_tracking(
     # Insert document (or update if exists from concurrent operations)
     await conn.execute(
         f"""
-        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params)
-        VALUES ($1, $2, $3, $4, $5, $6)
+        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params, tags)
+        VALUES ($1, $2, $3, $4, $5, $6, $7)
         ON CONFLICT (id, bank_id) DO UPDATE
         SET original_text = EXCLUDED.original_text,
            content_hash = EXCLUDED.content_hash,
            metadata = EXCLUDED.metadata,
            retain_params = EXCLUDED.retain_params,
+           tags = EXCLUDED.tags,
            updated_at = NOW()
        """,
        document_id,
@@ -164,4 +186,5 @@ async def handle_document_tracking(
        content_hash,
        json.dumps({}),  # Empty metadata dict
        json.dumps(retain_params) if retain_params else None,
+       document_tags or [],
    )
@@ -479,14 +479,18 @@ async def create_temporal_links_batch_per_fact(
 
     if links:
         insert_start = time_mod.time()
-        await conn.executemany(
-            f"""
-            INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
-            VALUES ($1, $2, $3, $4, $5)
-            ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
-            """,
-            links,
-        )
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(links), BATCH_SIZE):
+            batch = links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
         _log(log_buffer, f"  [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")
 
     return len(links)
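Both this hunk and the semantic-links hunk that follows replace one huge executemany call with fixed-size chunks, so a single oversized statement batch can no longer run into a timeout. A small generic sketch of that pattern, using a hypothetical helper with asyncpg assumed:

import asyncpg


# Hypothetical helper: run executemany in fixed-size chunks so a very large
# argument list is split into several smaller round trips.
async def executemany_chunked(
    conn: asyncpg.Connection,
    query: str,
    rows: list[tuple],
    batch_size: int = 1000,
) -> None:
    for start in range(0, len(rows), batch_size):
        await conn.executemany(query, rows[start : start + batch_size])


# Usage sketch: links would be (from_id, to_id, link_type, weight, entity_id) tuples.
# await executemany_chunked(
#     conn,
#     "INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id) "
#     "VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING",
#     links,
# )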
@@ -644,14 +648,18 @@ async def create_semantic_links_batch(
 
     if all_links:
         insert_start = time_mod.time()
-        await conn.executemany(
-            f"""
-            INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
-            VALUES ($1, $2, $3, $4, $5)
-            ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
-            """,
-            all_links,
-        )
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(all_links), BATCH_SIZE):
+            batch = all_links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
         _log(
             log_buffer, f"  [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s"
         )
@@ -746,17 +754,14 @@ async def create_causal_links_batch(
         causal_relations_per_fact: List of causal relations for each fact.
             Each element is a list of dicts with:
             - target_fact_index: Index into unit_ids for the target fact
-            - relation_type: "causes", "caused_by", "enables", or "prevents"
+            - relation_type: "caused_by"
             - strength: Float in [0.0, 1.0] representing relationship strength
 
     Returns:
         Number of causal links created
 
-    Causal link types:
-    - "causes": This fact directly causes the target fact (forward causation)
-    - "caused_by": This fact was caused by the target fact (backward causation)
-    - "enables": This fact enables/allows the target fact (enablement)
-    - "prevents": This fact prevents/blocks the target fact (prevention)
+    Causal link type:
+    - "caused_by": This fact was caused by the target fact
     """
     if not unit_ids or not causal_relations_per_fact:
         return 0
@@ -779,8 +784,8 @@ async def create_causal_links_batch(
             relation_type = relation["relation_type"]
             strength = relation.get("strength", 1.0)
 
-            # Validate relation_type - must match database constraint
-            valid_types = {"causes", "caused_by", "enables", "prevents"}
+            # Validate relation_type - only "caused_by" is supported (DB constraint)
+            valid_types = {"caused_by"}
             if relation_type not in valid_types:
                 logger.error(
                     f"Invalid relation_type '{relation_type}' (type: {type(relation_type).__name__}) "
@@ -18,6 +18,7 @@ def utcnow():
     return datetime.now(UTC)
 
 
+from ..response_models import TokenUsage
 from . import (
     chunk_storage,
     deduplication,
@@ -26,9 +27,8 @@ from . import (
     fact_extraction,
     fact_storage,
     link_creation,
-    observation_regeneration,
 )
-from .types import ExtractedFact, ProcessedFact, RetainContent, RetainContentDict
+from .types import EntityLink, ExtractedFact, ProcessedFact, RetainContent, RetainContentDict
 
 logger = logging.getLogger(__name__)
 
@@ -38,7 +38,6 @@ async def retain_batch(
     embeddings_model,
     llm_config,
     entity_resolver,
-    task_backend,
     format_date_fn,
     duplicate_checker_fn,
     bank_id: str,
@@ -47,7 +46,8 @@ async def retain_batch(
     is_first_batch: bool = True,
     fact_type_override: str | None = None,
     confidence_score: float | None = None,
-) -> list[list[str]]:
+    document_tags: list[str] | None = None,
+) -> tuple[list[list[str]], TokenUsage]:
     """
     Process a batch of content through the retain pipeline.
 
@@ -56,7 +56,6 @@ async def retain_batch(
         embeddings_model: Embeddings model for generating embeddings
         llm_config: LLM configuration for fact extraction
         entity_resolver: Entity resolver for entity processing
-        task_backend: Task backend for background jobs
         format_date_fn: Function to format datetime to readable string
         duplicate_checker_fn: Function to check for duplicate facts
         bank_id: Bank identifier
@@ -65,9 +64,10 @@ async def retain_batch(
         is_first_batch: Whether this is the first batch
         fact_type_override: Override fact type for all facts
        confidence_score: Confidence score for opinions
+       document_tags: Tags applied to all items in this batch
 
    Returns:
-       List of unit ID lists (one list per content item)
+       Tuple of (unit ID lists, token usage for fact extraction)
    """
    start_time = time.time()
    total_chars = sum(len(item.get("content", "")) for item in contents_dicts)
@@ -86,12 +86,16 @@ async def retain_batch(
     # Convert dicts to RetainContent objects
     contents = []
     for item in contents_dicts:
+        # Merge item-level tags with document-level tags
+        item_tags = item.get("tags", []) or []
+        merged_tags = list(set(item_tags + (document_tags or [])))
         content = RetainContent(
             content=item["content"],
             context=item.get("context", ""),
             event_date=item.get("event_date") or utcnow(),
             metadata=item.get("metadata", {}),
             entities=item.get("entities", []),
+            tags=merged_tags,
         )
         contents.append(content)
 
@@ -99,7 +103,7 @@ async def retain_batch(
     step_start = time.time()
     extract_opinions = fact_type_override == "opinion"
 
-    extracted_facts, chunks = await fact_extraction.extract_facts_from_contents(
+    extracted_facts, chunks, usage = await fact_extraction.extract_facts_from_contents(
         contents, llm_config, agent_name, extract_opinions
     )
     log_buffer.append(
@@ -129,7 +133,7 @@ async def retain_batch(
             if first_item.get("metadata"):
                 retain_params["metadata"] = first_item["metadata"]
             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
             )
         else:
             # Check for per-item document_ids
@@ -157,14 +161,14 @@ async def retain_batch(
                 if first_item.get("metadata"):
                     retain_params["metadata"] = first_item["metadata"]
                 await fact_storage.handle_document_tracking(
-                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params
+                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params, document_tags
                 )
 
        total_time = time.time() - start_time
        logger.info(
            f"RETAIN_BATCH COMPLETE: 0 facts extracted from {len(contents)} contents in {total_time:.3f}s (document tracked, no facts)"
        )
-       return [[] for _ in contents]
+       return [[] for _ in contents], usage
 
    # Apply fact_type_override if provided
    if fact_type_override:
@@ -223,7 +227,7 @@ async def retain_batch(
                 retain_params["metadata"] = first_item["metadata"]
 
             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
             )
             document_ids_added.append(document_id)
             doc_id_mapping[None] = document_id  # For backwards compatibility
@@ -267,7 +271,13 @@ async def retain_batch(
                 retain_params["metadata"] = first_item["metadata"]
 
             await fact_storage.handle_document_tracking(
-                conn, bank_id, actual_doc_id, combined_content, is_first_batch, retain_params
+                conn,
+                bank_id,
+                actual_doc_id,
+                combined_content,
+                is_first_batch,
+                retain_params,
+                document_tags,
             )
             document_ids_added.append(actual_doc_id)
 
@@ -344,7 +354,7 @@ async def retain_batch(
         non_duplicate_facts = deduplication.filter_duplicates(processed_facts, is_duplicate_flags)
 
         if not non_duplicate_facts:
-            return [[] for _ in contents]
+            return [[] for _ in contents], usage
 
         # Insert facts (document_id is now stored per-fact)
         step_start = time.time()
@@ -394,17 +404,9 @@ async def retain_batch(
         causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
         log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")
 
-        # Regenerate observations INSIDE transaction for atomicity
-        await observation_regeneration.regenerate_observations_batch(
-            conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer
-        )
-
     # Map results back to original content items
     result_unit_ids = _map_results_to_contents(contents, extracted_facts, is_duplicate_flags, unit_ids)
 
-    # Trigger background tasks AFTER transaction commits (opinion reinforcement only)
-    await _trigger_background_tasks(task_backend, bank_id, unit_ids, non_duplicate_facts)
-
     # Log final summary
     total_time = time.time() - start_time
     log_buffer.append(f"{'=' * 60}")
@@ -415,7 +417,7 @@ async def retain_batch(
 
     logger.info("\n" + "\n".join(log_buffer) + "\n")
 
-    return result_unit_ids
+    return result_unit_ids, usage
 
 
 def _map_results_to_contents(
@@ -446,24 +448,3 @@ def _map_results_to_contents(
         result_unit_ids.append(content_unit_ids)
 
     return result_unit_ids
-
-
-async def _trigger_background_tasks(
-    task_backend,
-    bank_id: str,
-    unit_ids: list[str],
-    facts: list[ProcessedFact],
-) -> None:
-    """Trigger opinion reinforcement as background task (after transaction commits)."""
-    # Trigger opinion reinforcement if there are entities
-    fact_entities = [[e.name for e in fact.entities] for fact in facts]
-    if any(fact_entities):
-        await task_backend.submit_task(
-            {
-                "type": "reinforce_opinion",
-                "bank_id": bank_id,
-                "created_unit_ids": unit_ids,
-                "unit_texts": [fact.fact_text for fact in facts],
-                "unit_entities": fact_entities,
-            }
-        )
@@ -21,6 +21,7 @@ class RetainContentDict(TypedDict, total=False):
         metadata: Custom key-value metadata (optional)
         document_id: Document ID for this content item (optional)
         entities: User-provided entities to merge with extracted entities (optional)
+        tags: Visibility scope tags for this content item (optional)
     """
 
     content: str  # Required
@@ -29,6 +30,7 @@ class RetainContentDict(TypedDict, total=False):
     metadata: dict[str, str]
     document_id: str
     entities: list[dict[str, str]]  # [{"text": "...", "type": "..."}]
+    tags: list[str]  # Visibility scope tags
 
 
 def _now_utc() -> datetime:
@@ -49,6 +51,7 @@ class RetainContent:
     event_date: datetime = field(default_factory=_now_utc)
     metadata: dict[str, str] = field(default_factory=dict)
     entities: list[dict[str, str]] = field(default_factory=list)  # User-provided entities
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags
 
 
 @dataclass
@@ -83,10 +86,10 @@ class CausalRelation:
     """
     Causal relationship between facts.
 
-    Represents how one fact causes, enables, or prevents another.
+    Represents how one fact was caused by another.
     """
 
-    relation_type: str  # "causes", "enables", "prevents", "caused_by"
+    relation_type: str  # "caused_by"
     target_fact_index: int  # Index of the target fact in the batch
     strength: float = 1.0  # Strength of the causal relationship
 
@@ -113,6 +116,7 @@ class ExtractedFact:
     context: str = ""
     mentioned_at: datetime | None = None
     metadata: dict[str, str] = field(default_factory=dict)
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags
 
 
 @dataclass
@@ -158,6 +162,9 @@ class ProcessedFact:
     # Track which content this fact came from (for user entity merging)
     content_index: int = 0
 
+    # Visibility scope tags
+    tags: list[str] = field(default_factory=list)
+
     @property
     def is_duplicate(self) -> bool:
         """Check if this fact was marked as a duplicate."""
@@ -201,6 +208,7 @@ class ProcessedFact:
         causal_relations=extracted_fact.causal_relations,
         chunk_id=chunk_id,
         content_index=extracted_fact.content_index,
+        tags=extracted_fact.tags,
     )
 
 
@@ -232,6 +240,7 @@ class RetainBatch:
     document_id: str | None = None
     fact_type_override: str | None = None
    confidence_score: float | None = None
+    document_tags: list[str] = field(default_factory=list)  # Tags applied to all items
 
    # Extracted data (populated during processing)
    extracted_facts: list[ExtractedFact] = field(default_factory=list)
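Taken together, the types changes mean every retained content item can carry its own visibility-scope tags, and a whole batch can carry document-level tags that get merged into each item. A hedged sketch of what a caller might pass, using the RetainContentDict shape shown above; the retain_batch call site and its other arguments are simplified and illustrative, not the package's actual wiring.

from datetime import datetime, timezone

# Sketch only: content dicts shaped like RetainContentDict, with per-item tags
# that retain_batch merges with the batch-level document_tags.
contents_dicts = [
    {
        "content": "Acme signed the renewal on 2024-03-01.",
        "context": "CRM sync",
        "event_date": datetime(2024, 3, 1, tzinfo=timezone.utc),
        "metadata": {"source": "crm"},
        "entities": [{"text": "Acme", "type": "organization"}],
        "tags": ["team:sales"],  # item-level visibility scope
    },
    {
        "content": "Renewal paperwork was countersigned a week later.",
        "tags": [],  # falls back to document-level tags only
    },
]

# Hypothetical call site (other arguments elided); per the new signature,
# retain_batch now returns (unit_ids_per_content, token_usage).
# unit_ids, usage = await retain_batch(
#     ...,
#     bank_id="demo-bank",
#     contents_dicts=contents_dicts,
#     document_tags=["tenant:acme"],  # merged into every item's tags
# )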
@@ -11,7 +11,8 @@ from abc import ABC, abstractmethod
 
 from ..db_utils import acquire_with_retry
 from ..memory_engine import fq_table
-from .types import RetrievalResult
+from .tags import TagsMatch, filter_results_by_tags
+from .types import MPFPTimings, RetrievalResult
 
 logger = logging.getLogger(__name__)
 
@@ -42,7 +43,10 @@ class GraphRetriever(ABC):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-    ) -> list[RetrievalResult]:
+        adjacency=None,  # TypedAdjacency, optional pre-loaded graph
+        tags: list[str] | None = None,  # Visibility scope tags for filtering
+        tags_match: TagsMatch = "any",  # How to match tags: 'any' (OR) or 'all' (AND)
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
         Retrieve relevant facts via graph traversal.
 
@@ -55,9 +59,11 @@ class GraphRetriever(ABC):
             query_text: Original query text (optional, for some strategies)
             semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
             temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+            adjacency: Pre-loaded typed adjacency graph (optional, for MPFP)
+            tags: Optional list of tags for visibility filtering (OR matching)
 
        Returns:
-           List of RetrievalResult objects with activation scores set
+           Tuple of (List of RetrievalResult with activation scores, optional timing info)
        """
        pass
 
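With the new abstract signature, every retriever returns a (results, timings) tuple and accepts tags plus a tags_match mode. A hedged sketch of how a caller might consume it; the parameter order and values are assumptions based on the hunks above, not the package's documented API.

# Illustrative only: consuming the new retrieve() signature.
# pool and query_embedding_str are assumed to come from the surrounding engine setup.
async def fetch_sales_memories(pool, query_embedding_str: str):
    retriever = BFSGraphRetriever()  # constructor arguments omitted / assumed defaults
    results, timings = await retriever.retrieve(
        pool,
        query_embedding_str,
        bank_id="demo-bank",
        fact_type="world",  # example value only
        budget=200,
        tags=["team:sales"],
        tags_match="any",  # OR semantics; "all" would require every tag
    )
    # BFS reports no per-phase timing info, so timings is None here.
    return results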
@@ -111,7 +117,10 @@ class BFSGraphRetriever(GraphRetriever):
     query_text: str | None = None,
     semantic_seeds: list[RetrievalResult] | None = None,
     temporal_seeds: list[RetrievalResult] | None = None,
-    ) -> list[RetrievalResult]:
+    adjacency=None,  # Not used by BFS
+    tags: list[str] | None = None,
+    tags_match: TagsMatch = "any",
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
     """
     Retrieve facts using BFS spreading activation.
 
@@ -122,11 +131,14 @@ class BFSGraphRetriever(GraphRetriever):
     4. Return visited nodes up to budget
 
     Note: BFS finds its own entry points via embedding search.
-    The semantic_seeds and temporal_seeds parameters are accepted
+    The semantic_seeds, temporal_seeds, and adjacency parameters are accepted
     for interface compatibility but not used.
     """
     async with acquire_with_retry(pool) as conn:
-        return await self._retrieve_with_conn(conn, query_embedding_str, bank_id, fact_type, budget)
+        results = await self._retrieve_with_conn(
+            conn, query_embedding_str, bank_id, fact_type, budget, tags=tags, tags_match=tags_match
+        )
+        return results, None
 
     async def _retrieve_with_conn(
         self,
@@ -135,33 +147,46 @@ class BFSGraphRetriever(GraphRetriever):
     bank_id: str,
     fact_type: str,
     budget: int,
+    tags: list[str] | None = None,
+    tags_match: TagsMatch = "any",
     ) -> list[RetrievalResult]:
     """Internal implementation with connection."""
+    from .tags import build_tags_where_clause_simple
+
+    tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
+    params = [query_embedding_str, bank_id, fact_type, self.entry_point_threshold, self.entry_point_limit]
+    if tags:
+        params.append(tags)
 
     # Step 1: Find entry points
     entry_points = await conn.fetch(
         f"""
         SELECT id, text, context, event_date, occurred_start, occurred_end,
-               mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+               mentioned_at, embedding, fact_type, document_id, chunk_id, tags,
                1 - (embedding <=> $1::vector) AS similarity
         FROM {fq_table("memory_units")}
         WHERE bank_id = $2
           AND embedding IS NOT NULL
          AND fact_type = $3
          AND (1 - (embedding <=> $1::vector)) >= $4
+         {tags_clause}
         ORDER BY embedding <=> $1::vector
         LIMIT $5
        """,
-       query_embedding_str,
-       bank_id,
-       fact_type,
-       self.entry_point_threshold,
-       self.entry_point_limit,
+       *params,
    )
 
    if not entry_points:
+       logger.debug(
+           f"[BFS] No entry points found for fact_type={fact_type} (tags={tags}, tags_match={tags_match})"
+       )
        return []
 
+   logger.debug(
+       f"[BFS] Found {len(entry_points)} entry points for fact_type={fact_type} "
+       f"(tags={tags}, tags_match={tags_match})"
+   )
+
    # Step 2: BFS spreading activation
    visited = set()
    results = []
@@ -191,8 +216,8 @@ class BFSGraphRetriever(GraphRetriever):
     neighbors = await conn.fetch(
         f"""
         SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
-               mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
-               mu.document_id, mu.chunk_id,
+               mu.mentioned_at, mu.embedding, mu.fact_type,
+               mu.document_id, mu.chunk_id, mu.tags,
                ml.weight, ml.link_type, ml.from_unit_id
         FROM {fq_table("memory_links")} ml
         JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
@@ -232,4 +257,8 @@ class BFSGraphRetriever(GraphRetriever):
         neighbor_result = RetrievalResult.from_db_row(dict(n))
         queue.append((neighbor_result, new_activation))
 
+    # Apply tags filtering (BFS may traverse into memories that don't match tags criteria)
+    if tags:
+        results = filter_results_by_tags(results, tags, match=tags_match)
+
     return results
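The new hindsight_api/engine/search/tags.py module itself is not shown in this diff. For orientation, here is a hedged reconstruction of what the OR/AND matching used by filter_results_by_tags and build_tags_where_clause_simple could look like, inferred purely from the call sites above; the package's actual implementation may differ.

from typing import Literal

TagsMatch = Literal["any", "all"]


# Reconstruction only: "any" keeps results sharing at least one tag,
# "all" keeps results that carry every requested tag. Assumes each result
# exposes a list-valued .tags attribute (as suggested by the tags column
# now selected into RetrievalResult.from_db_row).
def filter_results_by_tags(results, tags, match: TagsMatch = "any"):
    wanted = set(tags)
    if match == "any":
        return [r for r in results if wanted & set(r.tags or [])]
    return [r for r in results if wanted <= set(r.tags or [])]


def build_tags_where_clause_simple(tags, param_index: int, match: TagsMatch = "any") -> str:
    """Return a SQL fragment filtering a varchar[] tags column against $param_index."""
    if not tags:
        return ""  # matches the call site: the tags parameter is only appended when tags are given
    operator = "&&" if match == "any" else "@>"  # array-overlap (OR) vs array-contains (AND)
    return f"AND tags {operator} ${param_index}::varchar[]"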