hindsight-api 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +252 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/api/http.py +282 -20
- hindsight_api/api/mcp.py +47 -52
- hindsight_api/config.py +238 -6
- hindsight_api/engine/cross_encoder.py +599 -86
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/embeddings.py +453 -26
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +8 -4
- hindsight_api/engine/llm_wrapper.py +241 -27
- hindsight_api/engine/memory_engine.py +609 -122
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/response_models.py +38 -0
- hindsight_api/engine/retain/fact_extraction.py +388 -192
- hindsight_api/engine/retain/fact_storage.py +34 -8
- hindsight_api/engine/retain/link_utils.py +24 -16
- hindsight_api/engine/retain/orchestrator.py +52 -17
- hindsight_api/engine/retain/types.py +9 -0
- hindsight_api/engine/search/graph_retrieval.py +42 -13
- hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +847 -200
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/trace.py +12 -0
- hindsight_api/engine/search/tracer.py +24 -1
- hindsight_api/engine/search/types.py +21 -0
- hindsight_api/engine/task_backend.py +109 -18
- hindsight_api/engine/utils.py +1 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/main.py +56 -4
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -1
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
- hindsight_api-0.3.0.dist-info/RECORD +82 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0

@@ -45,6 +45,7 @@ async def insert_facts_batch(
     metadata_jsons = []
     chunk_ids = []
     document_ids = []
+    tags_list = []
 
     for fact in facts:
         fact_texts.append(fact.fact_text)
@@ -65,16 +66,31 @@ async def insert_facts_batch(
         chunk_ids.append(fact.chunk_id)
         # Use per-fact document_id if available, otherwise fallback to batch-level document_id
         document_ids.append(fact.document_id if fact.document_id else document_id)
+        # Convert tags to JSON string for proper batch insertion (PostgreSQL unnest doesn't handle 2D arrays well)
+        tags_list.append(json.dumps(fact.tags if fact.tags else []))
 
     # Batch insert all facts
+    # Note: tags are passed as JSON strings and converted back to varchar[] via jsonb_array_elements_text + array_agg
     results = await conn.fetch(
         f"""
-
-
-
-
-
+        WITH input_data AS (
+            SELECT * FROM unnest(
+                $2::text[], $3::vector[], $4::timestamptz[], $5::timestamptz[], $6::timestamptz[], $7::timestamptz[],
+                $8::text[], $9::text[], $10::float[], $11::int[], $12::jsonb[], $13::text[], $14::text[], $15::jsonb[]
+            ) AS t(text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                   context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id, tags_json)
         )
+        INSERT INTO {fq_table("memory_units")} (bank_id, text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+                                                context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id, tags)
+        SELECT
+            $1,
+            text, embedding, event_date, occurred_start, occurred_end, mentioned_at,
+            context, fact_type, confidence_score, access_count, metadata, chunk_id, document_id,
+            COALESCE(
+                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
+                '{{}}'::varchar[]
+            )
+        FROM input_data
         RETURNING id
         """,
         bank_id,
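
Each fact's tag list is JSON-encoded on the Python side because, as the comment above notes, `unnest` cannot zip a ragged list-of-lists into a single array parameter. A minimal sketch of the same round trip in isolation, assuming a hypothetical `items(name text, tags varchar[])` table rather than the real `memory_units` schema:

```python
import json

import asyncpg


async def insert_items(conn: asyncpg.Connection, rows: list[tuple[str, list[str]]]) -> None:
    """Batch-insert rows whose per-row tag lists have different lengths."""
    names = [name for name, _ in rows]
    # Each tag list becomes one JSON string, so the whole column is a flat jsonb[].
    tags_json = [json.dumps(tags) for _, tags in rows]
    await conn.execute(
        """
        WITH input_data AS (
            SELECT * FROM unnest($1::text[], $2::jsonb[]) AS t(name, tags_json)
        )
        INSERT INTO items (name, tags)
        SELECT
            name,
            COALESCE(
                (SELECT array_agg(elem) FROM jsonb_array_elements_text(tags_json) AS elem),
                '{}'::varchar[]
            )
        FROM input_data
        """,
        names,
        tags_json,
    )
```

The `COALESCE` mirrors the change above: aggregating an empty JSON array yields NULL, which would otherwise be inserted instead of an empty array.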
@@ -91,6 +107,7 @@ async def insert_facts_batch(
         metadata_jsons,
         chunk_ids,
         document_ids,
+        tags_list,
     )
 
     unit_ids = [str(row["id"]) for row in results]
@@ -121,7 +138,13 @@ async def ensure_bank_exists(conn, bank_id: str) -> None:
 
 
 async def handle_document_tracking(
-    conn,
+    conn,
+    bank_id: str,
+    document_id: str,
+    combined_content: str,
+    is_first_batch: bool,
+    retain_params: dict | None = None,
+    document_tags: list[str] | None = None,
 ) -> None:
     """
     Handle document tracking in the database.
@@ -133,6 +156,7 @@ async def handle_document_tracking(
         combined_content: Combined content text from all content items
         is_first_batch: Whether this is the first batch (for chunked operations)
         retain_params: Optional parameters passed during retain (context, event_date, etc.)
+        document_tags: Optional list of tags to associate with the document
     """
     import hashlib
 
@@ -149,13 +173,14 @@ async def handle_document_tracking(
     # Insert document (or update if exists from concurrent operations)
     await conn.execute(
         f"""
-        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params)
-        VALUES ($1, $2, $3, $4, $5, $6)
+        INSERT INTO {fq_table("documents")} (id, bank_id, original_text, content_hash, metadata, retain_params, tags)
+        VALUES ($1, $2, $3, $4, $5, $6, $7)
         ON CONFLICT (id, bank_id) DO UPDATE
         SET original_text = EXCLUDED.original_text,
             content_hash = EXCLUDED.content_hash,
             metadata = EXCLUDED.metadata,
            retain_params = EXCLUDED.retain_params,
+            tags = EXCLUDED.tags,
            updated_at = NOW()
        """,
        document_id,
@@ -164,4 +189,5 @@ async def handle_document_tracking(
        content_hash,
        json.dumps({}),  # Empty metadata dict
        json.dumps(retain_params) if retain_params else None,
+        document_tags or [],
    )
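
Unlike the per-fact path above, `document_tags or []` is bound directly: asyncpg maps a Python `list[str]` to a PostgreSQL array parameter, so no JSON detour is needed when each statement carries a single array value. A reduced sketch, assuming a `documents` table shaped like the one in this diff:

```python
import asyncpg


async def update_document_tags(conn: asyncpg.Connection, doc_id: str, bank_id: str, tags: list[str]) -> None:
    # A plain list[str] binds straight to a varchar[] column; illustrative helper, not the package's API.
    await conn.execute(
        "UPDATE documents SET tags = $3, updated_at = NOW() WHERE id = $1 AND bank_id = $2",
        doc_id,
        bank_id,
        tags,
    )
```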
@@ -479,14 +479,18 @@ async def create_temporal_links_batch_per_fact(
 
     if links:
         insert_start = time_mod.time()
-
-
-
-
-
-
-
-
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(links), BATCH_SIZE):
+            batch = links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
        _log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")
 
    return len(links)
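
Both this hunk and the semantic-links hunk below replace one large insert with fixed-size `executemany` chunks. The pattern in isolation (the table name and tuple shape follow the diff, but the helper itself is illustrative, not the project's actual function):

```python
import asyncpg

BATCH_SIZE = 1000  # same cap as the diff; bounds the work done per round trip


async def insert_links_chunked(conn: asyncpg.Connection, links: list[tuple]) -> None:
    """Insert link tuples in fixed-size chunks so no single executemany call can stall."""
    for start in range(0, len(links), BATCH_SIZE):
        await conn.executemany(
            "INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id) "
            "VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING",
            links[start : start + BATCH_SIZE],
        )
```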
@@ -644,14 +648,18 @@ async def create_semantic_links_batch(
 
     if all_links:
         insert_start = time_mod.time()
-
-
-
-
-
-
-
-
+        # Batch inserts to avoid timeout on large batches
+        BATCH_SIZE = 1000
+        for batch_start in range(0, len(all_links), BATCH_SIZE):
+            batch = all_links[batch_start : batch_start + BATCH_SIZE]
+            await conn.executemany(
+                f"""
+                INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                """,
+                batch,
+            )
        _log(
            log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s"
        )
@@ -9,6 +9,7 @@ import time
 import uuid
 from datetime import UTC, datetime
 
+from ...config import get_config
 from ..db_utils import acquire_with_retry
 from . import bank_utils
 
@@ -18,6 +19,7 @@ def utcnow():
     return datetime.now(UTC)
 
 
+from ..response_models import TokenUsage
 from . import (
     chunk_storage,
     deduplication,
@@ -47,7 +49,8 @@ async def retain_batch(
     is_first_batch: bool = True,
     fact_type_override: str | None = None,
     confidence_score: float | None = None,
-
+    document_tags: list[str] | None = None,
+) -> tuple[list[list[str]], TokenUsage]:
     """
     Process a batch of content through the retain pipeline.
 
@@ -65,9 +68,10 @@ async def retain_batch(
         is_first_batch: Whether this is the first batch
         fact_type_override: Override fact type for all facts
         confidence_score: Confidence score for opinions
+        document_tags: Tags applied to all items in this batch
 
     Returns:
-
+        Tuple of (unit ID lists, token usage for fact extraction)
     """
     start_time = time.time()
     total_chars = sum(len(item.get("content", "")) for item in contents_dicts)
@@ -86,12 +90,16 @@ async def retain_batch(
     # Convert dicts to RetainContent objects
     contents = []
     for item in contents_dicts:
+        # Merge item-level tags with document-level tags
+        item_tags = item.get("tags", []) or []
+        merged_tags = list(set(item_tags + (document_tags or [])))
         content = RetainContent(
             content=item["content"],
             context=item.get("context", ""),
             event_date=item.get("event_date") or utcnow(),
             metadata=item.get("metadata", {}),
             entities=item.get("entities", []),
+            tags=merged_tags,
         )
         contents.append(content)
 
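
The merge is a plain set union, so duplicate tags collapse and ordering is not preserved. The same logic distilled into a standalone helper (illustrative name, same behavior):

```python
def merge_tags(item_tags: list[str] | None, document_tags: list[str] | None) -> list[str]:
    """Union of per-item and document-level tags; duplicates collapse, order is not guaranteed."""
    return list(set((item_tags or []) + (document_tags or [])))


assert sorted(merge_tags(["alice"], ["team:research", "alice"])) == ["alice", "team:research"]
```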
@@ -99,7 +107,7 @@ async def retain_batch(
     step_start = time.time()
     extract_opinions = fact_type_override == "opinion"
 
-    extracted_facts, chunks = await fact_extraction.extract_facts_from_contents(
+    extracted_facts, chunks, usage = await fact_extraction.extract_facts_from_contents(
         contents, llm_config, agent_name, extract_opinions
     )
     log_buffer.append(
@@ -129,7 +137,7 @@ async def retain_batch(
             if first_item.get("metadata"):
                 retain_params["metadata"] = first_item["metadata"]
             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
             )
         else:
             # Check for per-item document_ids
@@ -157,14 +165,14 @@ async def retain_batch(
                 if first_item.get("metadata"):
                     retain_params["metadata"] = first_item["metadata"]
                 await fact_storage.handle_document_tracking(
-                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params
+                    conn, bank_id, doc_id, combined_content, is_first_batch, retain_params, document_tags
                 )
 
         total_time = time.time() - start_time
         logger.info(
             f"RETAIN_BATCH COMPLETE: 0 facts extracted from {len(contents)} contents in {total_time:.3f}s (document tracked, no facts)"
         )
-        return [[] for _ in contents]
+        return [[] for _ in contents], usage
 
     # Apply fact_type_override if provided
     if fact_type_override:
@@ -223,7 +231,7 @@ async def retain_batch(
                 retain_params["metadata"] = first_item["metadata"]
 
             await fact_storage.handle_document_tracking(
-                conn, bank_id, document_id, combined_content, is_first_batch, retain_params
+                conn, bank_id, document_id, combined_content, is_first_batch, retain_params, document_tags
            )
            document_ids_added.append(document_id)
            doc_id_mapping[None] = document_id  # For backwards compatibility
@@ -267,7 +275,13 @@ async def retain_batch(
                    retain_params["metadata"] = first_item["metadata"]
 
                await fact_storage.handle_document_tracking(
-                    conn,
+                    conn,
+                    bank_id,
+                    actual_doc_id,
+                    combined_content,
+                    is_first_batch,
+                    retain_params,
+                    document_tags,
                )
                document_ids_added.append(actual_doc_id)
 
@@ -344,7 +358,7 @@ async def retain_batch(
    non_duplicate_facts = deduplication.filter_duplicates(processed_facts, is_duplicate_flags)
 
    if not non_duplicate_facts:
-        return [[] for _ in contents]
+        return [[] for _ in contents], usage
 
    # Insert facts (document_id is now stored per-fact)
    step_start = time.time()
@@ -394,16 +408,26 @@ async def retain_batch(
     causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
     log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")
 
-    # Regenerate observations
-
-
-
+    # Regenerate observations - sync (in transaction) or async (background task)
+    config = get_config()
+    if config.retain_observations_async:
+        # Queue for async processing after transaction commits
+        entity_ids_for_async = list(set(link.entity_id for link in entity_links)) if entity_links else []
+        log_buffer.append(
+            f"[11] Observations: queued {len(entity_ids_for_async)} entities for async processing"
+        )
+    else:
+        # Run synchronously inside transaction for atomicity
+        await observation_regeneration.regenerate_observations_batch(
+            conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer
+        )
+        entity_ids_for_async = []
 
     # Map results back to original content items
     result_unit_ids = _map_results_to_contents(contents, extracted_facts, is_duplicate_flags, unit_ids)
 
-    # Trigger background tasks AFTER transaction commits
-    await _trigger_background_tasks(task_backend, bank_id, unit_ids, non_duplicate_facts)
+    # Trigger background tasks AFTER transaction commits
+    await _trigger_background_tasks(task_backend, bank_id, unit_ids, non_duplicate_facts, entity_ids_for_async)
 
     # Log final summary
     total_time = time.time() - start_time
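
The new flag chooses between regenerating observations inside the retain transaction (atomic, but slower commits) and deferring the affected entity IDs to a background task submitted after commit. A sketch of just that control flow; only the `retain_observations_async` flag name is taken from the diff, everything else is illustrative:

```python
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass
class Config:
    retain_observations_async: bool = False  # mirrors the flag read via get_config() above


async def regenerate_or_defer(
    config: Config,
    entity_ids: list[str],
    run_inline: Callable[[list[str]], Awaitable[None]],
) -> list[str]:
    """Return entity IDs to queue after commit, or [] if regeneration already ran inline."""
    if config.retain_observations_async:
        return entity_ids  # caller submits these to the task backend once the transaction commits
    await run_inline(entity_ids)  # synchronous path: atomic with the rest of the retain transaction
    return []
```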
@@ -415,7 +439,7 @@ async def retain_batch(
 
     logger.info("\n" + "\n".join(log_buffer) + "\n")
 
-    return result_unit_ids
+    return result_unit_ids, usage
 
 
 def _map_results_to_contents(
@@ -453,8 +477,9 @@ async def _trigger_background_tasks(
     bank_id: str,
     unit_ids: list[str],
     facts: list[ProcessedFact],
+    entity_ids_for_observations: list[str] | None = None,
 ) -> None:
-    """Trigger
+    """Trigger background tasks after transaction commits."""
     # Trigger opinion reinforcement if there are entities
     fact_entities = [[e.name for e in fact.entities] for fact in facts]
     if any(fact_entities):
@@ -467,3 +492,13 @@ async def _trigger_background_tasks(
                 "unit_entities": fact_entities,
             }
         )
+
+    # Trigger observation regeneration if async mode is enabled
+    if entity_ids_for_observations:
+        await task_backend.submit_task(
+            {
+                "type": "regenerate_observations",
+                "bank_id": bank_id,
+                "entity_ids": entity_ids_for_observations,
+            }
+        )
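
The payload above is consumed by whatever backend handles `submit_task`; the dispatch side is not part of this hunk. A generic sketch of such a consumer, with the handler registry and function names being purely illustrative:

```python
from collections.abc import Awaitable, Callable

Handler = Callable[[dict], Awaitable[None]]


async def dispatch(task: dict, handlers: dict[str, Handler]) -> None:
    """Route a submitted task payload to its handler; unknown types are surfaced loudly."""
    task_type = task.get("type")
    handler = handlers.get(task_type)
    if handler is None:
        raise ValueError(f"no handler registered for task type {task_type!r}")
    await handler(task)  # e.g. a regenerate_observations handler reading bank_id / entity_ids
```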
@@ -21,6 +21,7 @@ class RetainContentDict(TypedDict, total=False):
         metadata: Custom key-value metadata (optional)
         document_id: Document ID for this content item (optional)
         entities: User-provided entities to merge with extracted entities (optional)
+        tags: Visibility scope tags for this content item (optional)
     """
 
     content: str  # Required
@@ -29,6 +30,7 @@ class RetainContentDict(TypedDict, total=False):
     metadata: dict[str, str]
     document_id: str
     entities: list[dict[str, str]]  # [{"text": "...", "type": "..."}]
+    tags: list[str]  # Visibility scope tags
 
 
 def _now_utc() -> datetime:
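
With the new field, a retain payload can scope visibility per content item. An illustrative payload using only fields declared in `RetainContentDict` (the values themselves are made up):

```python
from hindsight_api.engine.retain.types import RetainContentDict

item: RetainContentDict = {
    "content": "Alice presented the Q3 roadmap.",
    "context": "weekly sync",
    "entities": [{"text": "Alice", "type": "person"}],
    "tags": ["team:research", "project:roadmap"],  # visibility scope used later when filtering recall
}
```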
@@ -49,6 +51,7 @@ class RetainContent:
     event_date: datetime = field(default_factory=_now_utc)
     metadata: dict[str, str] = field(default_factory=dict)
     entities: list[dict[str, str]] = field(default_factory=list)  # User-provided entities
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags
 
 
 @dataclass
@@ -113,6 +116,7 @@ class ExtractedFact:
     context: str = ""
     mentioned_at: datetime | None = None
     metadata: dict[str, str] = field(default_factory=dict)
+    tags: list[str] = field(default_factory=list)  # Visibility scope tags
 
 
 @dataclass
@@ -158,6 +162,9 @@ class ProcessedFact:
     # Track which content this fact came from (for user entity merging)
     content_index: int = 0
 
+    # Visibility scope tags
+    tags: list[str] = field(default_factory=list)
+
     @property
     def is_duplicate(self) -> bool:
         """Check if this fact was marked as a duplicate."""
@@ -201,6 +208,7 @@ class ProcessedFact:
             causal_relations=extracted_fact.causal_relations,
             chunk_id=chunk_id,
             content_index=extracted_fact.content_index,
+            tags=extracted_fact.tags,
        )
 
 
@@ -232,6 +240,7 @@ class RetainBatch:
    document_id: str | None = None
    fact_type_override: str | None = None
    confidence_score: float | None = None
+    document_tags: list[str] = field(default_factory=list)  # Tags applied to all items
 
    # Extracted data (populated during processing)
    extracted_facts: list[ExtractedFact] = field(default_factory=list)
@@ -11,7 +11,8 @@ from abc import ABC, abstractmethod
 
 from ..db_utils import acquire_with_retry
 from ..memory_engine import fq_table
-from .
+from .tags import TagsMatch, filter_results_by_tags
+from .types import MPFPTimings, RetrievalResult
 
 logger = logging.getLogger(__name__)
 
@@ -42,7 +43,10 @@ class GraphRetriever(ABC):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-
+        adjacency=None,  # TypedAdjacency, optional pre-loaded graph
+        tags: list[str] | None = None,  # Visibility scope tags for filtering
+        tags_match: TagsMatch = "any",  # How to match tags: 'any' (OR) or 'all' (AND)
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
         Retrieve relevant facts via graph traversal.
 
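
`TagsMatch` is imported from the new `.tags` module, whose contents this diff does not show. Given the `"any"` default and the OR/AND comment above, a plausible definition is simply a string literal alias:

```python
from typing import Literal

# Plausible shape of the alias imported from .tags (actual definition not shown in this diff):
TagsMatch = Literal["any", "all"]  # "any" = OR across the requested tags, "all" = AND
```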
@@ -55,9 +59,11 @@ class GraphRetriever(ABC):
             query_text: Original query text (optional, for some strategies)
             semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
             temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+            adjacency: Pre-loaded typed adjacency graph (optional, for MPFP)
+            tags: Optional list of tags for visibility filtering (OR matching)
 
         Returns:
-            List of RetrievalResult
+            Tuple of (List of RetrievalResult with activation scores, optional timing info)
         """
         pass
 
@@ -111,7 +117,10 @@ class BFSGraphRetriever(GraphRetriever):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-
+        adjacency=None,  # Not used by BFS
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
         Retrieve facts using BFS spreading activation.
 
@@ -122,11 +131,14 @@ class BFSGraphRetriever(GraphRetriever):
         4. Return visited nodes up to budget
 
         Note: BFS finds its own entry points via embedding search.
-        The semantic_seeds and
+        The semantic_seeds, temporal_seeds, and adjacency parameters are accepted
         for interface compatibility but not used.
         """
         async with acquire_with_retry(pool) as conn:
-
+            results = await self._retrieve_with_conn(
+                conn, query_embedding_str, bank_id, fact_type, budget, tags=tags, tags_match=tags_match
+            )
+            return results, None
 
     async def _retrieve_with_conn(
         self,
@@ -135,33 +147,46 @@ class BFSGraphRetriever(GraphRetriever):
         bank_id: str,
         fact_type: str,
         budget: int,
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> list[RetrievalResult]:
         """Internal implementation with connection."""
+        from .tags import build_tags_where_clause_simple
+
+        tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
+        params = [query_embedding_str, bank_id, fact_type, self.entry_point_threshold, self.entry_point_limit]
+        if tags:
+            params.append(tags)
 
         # Step 1: Find entry points
         entry_points = await conn.fetch(
             f"""
             SELECT id, text, context, event_date, occurred_start, occurred_end,
-                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
                    1 - (embedding <=> $1::vector) AS similarity
             FROM {fq_table("memory_units")}
             WHERE bank_id = $2
               AND embedding IS NOT NULL
              AND fact_type = $3
              AND (1 - (embedding <=> $1::vector)) >= $4
+              {tags_clause}
            ORDER BY embedding <=> $1::vector
            LIMIT $5
            """,
-
-            bank_id,
-            fact_type,
-            self.entry_point_threshold,
-            self.entry_point_limit,
+            *params,
        )
 
        if not entry_points:
+            logger.debug(
+                f"[BFS] No entry points found for fact_type={fact_type} (tags={tags}, tags_match={tags_match})"
+            )
            return []
 
+        logger.debug(
+            f"[BFS] Found {len(entry_points)} entry points for fact_type={fact_type} "
+            f"(tags={tags}, tags_match={tags_match})"
+        )
+
        # Step 2: BFS spreading activation
        visited = set()
        results = []
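
`build_tags_where_clause_simple` also lives in the new `tags.py`, which this diff does not include. Given how it is called here (a parameter index of 6 and an `any`/`all` mode, with the caller only binding the extra parameter when tags are present), one plausible shape uses PostgreSQL's array overlap and containment operators:

```python
# A guess at what such a helper could look like; the real implementation in
# hindsight_api/engine/search/tags.py is not shown in this diff and may differ.
def build_tags_where_clause_sketch(tags: list[str] | None, param_index: int, match: str = "any") -> str:
    if not tags:
        return ""  # no filtering; the caller does not bind the extra parameter
    if match == "all":
        return f"AND tags @> ${param_index}::varchar[]"  # row must contain every requested tag
    return f"AND tags && ${param_index}::varchar[]"      # row must share at least one requested tag
```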
@@ -192,7 +217,7 @@ class BFSGraphRetriever(GraphRetriever):
                 f"""
                 SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
                        mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
-                       mu.document_id, mu.chunk_id,
+                       mu.document_id, mu.chunk_id, mu.tags,
                        ml.weight, ml.link_type, ml.from_unit_id
                 FROM {fq_table("memory_links")} ml
                 JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
@@ -232,4 +257,8 @@ class BFSGraphRetriever(GraphRetriever):
                     neighbor_result = RetrievalResult.from_db_row(dict(n))
                     queue.append((neighbor_result, new_activation))
 
+        # Apply tags filtering (BFS may traverse into memories that don't match tags criteria)
+        if tags:
+            results = filter_results_by_tags(results, tags, match=tags_match)
+
         return results