hindsight-api 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +252 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/api/http.py +282 -20
- hindsight_api/api/mcp.py +47 -52
- hindsight_api/config.py +238 -6
- hindsight_api/engine/cross_encoder.py +599 -86
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/embeddings.py +453 -26
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +8 -4
- hindsight_api/engine/llm_wrapper.py +241 -27
- hindsight_api/engine/memory_engine.py +609 -122
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/response_models.py +38 -0
- hindsight_api/engine/retain/fact_extraction.py +388 -192
- hindsight_api/engine/retain/fact_storage.py +34 -8
- hindsight_api/engine/retain/link_utils.py +24 -16
- hindsight_api/engine/retain/orchestrator.py +52 -17
- hindsight_api/engine/retain/types.py +9 -0
- hindsight_api/engine/search/graph_retrieval.py +42 -13
- hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +847 -200
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/trace.py +12 -0
- hindsight_api/engine/search/tracer.py +24 -1
- hindsight_api/engine/search/types.py +21 -0
- hindsight_api/engine/task_backend.py +109 -18
- hindsight_api/engine/utils.py +1 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/main.py +56 -4
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -1
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
- hindsight_api-0.3.0.dist-info/RECORD +82 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/search/tags.py ADDED

@@ -0,0 +1,172 @@
+"""
+Tags filtering utilities for retrieval.
+
+Provides SQL building functions for filtering memories by tags.
+Supports four matching modes via TagsMatch enum:
+- "any": OR matching, includes untagged memories (default, backward compatible)
+- "all": AND matching, includes untagged memories
+- "any_strict": OR matching, excludes untagged memories
+- "all_strict": AND matching, excludes untagged memories
+
+OR matching (any/any_strict): Memory matches if ANY of its tags overlap with request tags
+AND matching (all/all_strict): Memory matches if ALL request tags are present in its tags
+"""
+
+from typing import Literal
+
+TagsMatch = Literal["any", "all", "any_strict", "all_strict"]
+
+
+def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]:
+    """
+    Parse TagsMatch into operator and include_untagged flag.
+
+    Returns:
+        Tuple of (operator, include_untagged)
+        - operator: "&&" for any/any_strict, "@>" for all/all_strict
+        - include_untagged: True for any/all, False for any_strict/all_strict
+    """
+    if match == "any":
+        return "&&", True
+    elif match == "all":
+        return "@>", True
+    elif match == "any_strict":
+        return "&&", False
+    elif match == "all_strict":
+        return "@>", False
+    else:
+        # Default to "any" behavior
+        return "&&", True
+
+
+def build_tags_where_clause(
+    tags: list[str] | None,
+    param_offset: int = 1,
+    table_alias: str = "",
+    match: TagsMatch = "any",
+) -> tuple[str, list, int]:
+    """
+    Build a SQL WHERE clause for filtering by tags.
+
+    Supports four matching modes:
+    - "any" (default): OR matching, includes untagged memories
+    - "all": AND matching, includes untagged memories
+    - "any_strict": OR matching, excludes untagged memories
+    - "all_strict": AND matching, excludes untagged memories
+
+    Args:
+        tags: List of tags to filter by. If None or empty, returns empty clause (no filtering).
+        param_offset: Starting parameter number for SQL placeholders (default 1).
+        table_alias: Optional table alias prefix (e.g., "mu." for "memory_units mu").
+        match: Matching mode. Defaults to "any".
+
+    Returns:
+        Tuple of (sql_clause, params, next_param_offset):
+        - sql_clause: SQL WHERE clause string
+        - params: List of parameter values to bind
+        - next_param_offset: Next available parameter number
+
+    Example:
+        >>> clause, params, next_offset = build_tags_where_clause(['user_a'], 3, 'mu.', 'any_strict')
+        >>> print(clause)  # "AND mu.tags IS NOT NULL AND mu.tags != '{}' AND mu.tags && $3"
+    """
+    if not tags:
+        return "", [], param_offset
+
+    column = f"{table_alias}tags" if table_alias else "tags"
+    operator, include_untagged = _parse_tags_match(match)
+
+    if include_untagged:
+        # Include untagged memories (NULL or empty array) OR matching tags
+        clause = f"AND ({column} IS NULL OR {column} = '{{}}' OR {column} {operator} ${param_offset})"
+    else:
+        # Strict: only memories with matching tags (exclude NULL and empty)
+        clause = f"AND {column} IS NOT NULL AND {column} != '{{}}' AND {column} {operator} ${param_offset}"
+
+    return clause, [tags], param_offset + 1
+
+
+def build_tags_where_clause_simple(
+    tags: list[str] | None,
+    param_num: int,
+    table_alias: str = "",
+    match: TagsMatch = "any",
+) -> str:
+    """
+    Build a simple SQL WHERE clause for tags filtering.
+
+    This is a convenience version that returns just the clause string,
+    assuming the caller will add the tags array to their params list.
+
+    Args:
+        tags: List of tags to filter by. If None or empty, returns empty string.
+        param_num: Parameter number to use in the clause.
+        table_alias: Optional table alias prefix.
+        match: Matching mode. Defaults to "any".
+
+    Returns:
+        SQL clause string or empty string.
+    """
+    if not tags:
+        return ""
+
+    column = f"{table_alias}tags" if table_alias else "tags"
+    operator, include_untagged = _parse_tags_match(match)
+
+    if include_untagged:
+        # Include untagged memories (NULL or empty array) OR matching tags
+        return f"AND ({column} IS NULL OR {column} = '{{}}' OR {column} {operator} ${param_num})"
+    else:
+        # Strict: only memories with matching tags (exclude NULL and empty)
+        return f"AND {column} IS NOT NULL AND {column} != '{{}}' AND {column} {operator} ${param_num}"
+
+
+def filter_results_by_tags(
+    results: list,
+    tags: list[str] | None,
+    match: TagsMatch = "any",
+) -> list:
+    """
+    Filter retrieval results by tags in Python (for post-processing).
+
+    Used when SQL filtering isn't possible (e.g., graph traversal results).
+
+    Args:
+        results: List of RetrievalResult objects with a 'tags' attribute.
+        tags: List of tags to filter by. If None or empty, returns all results.
+        match: Matching mode. Defaults to "any".
+
+    Returns:
+        Filtered list of results.
+    """
+    if not tags:
+        return results
+
+    _, include_untagged = _parse_tags_match(match)
+    is_any_match = match in ("any", "any_strict")
+
+    tags_set = set(tags)
+    filtered = []
+
+    for result in results:
+        result_tags = getattr(result, "tags", None)
+
+        # Check if untagged
+        is_untagged = result_tags is None or len(result_tags) == 0
+
+        if is_untagged:
+            if include_untagged:
+                filtered.append(result)
+            # else: skip untagged
+        else:
+            result_tags_set = set(result_tags)
+            if is_any_match:
+                # Any overlap
+                if result_tags_set & tags_set:
+                    filtered.append(result)
+            else:
+                # All tags must be present
+                if tags_set <= result_tags_set:
+                    filtered.append(result)
+
+    return filtered
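For orientation, a minimal usage sketch of the new helpers. The import path and both functions come from the diff above; the surrounding query shape and the FakeResult stand-in are illustrative assumptions, not code from this release (the "memory_units mu" alias follows the docstring example).

    from dataclasses import dataclass

    from hindsight_api.engine.search.tags import build_tags_where_clause, filter_results_by_tags

    # SQL-side filtering: compose the clause into a larger query.
    # $1 is assumed to be taken by bank_id, so the tags array binds to $2.
    clause, params, next_offset = build_tags_where_clause(
        ["user_a", "team_x"], param_offset=2, table_alias="mu.", match="any"
    )
    sql = f"SELECT mu.id FROM memory_units mu WHERE mu.bank_id = $1 {clause}"
    # clause == "AND (mu.tags IS NULL OR mu.tags = '{}' OR mu.tags && $2)"
    # params == [["user_a", "team_x"]], next_offset == 3

    # Python-side filtering, e.g. for graph traversal results.
    @dataclass
    class FakeResult:  # stand-in for RetrievalResult, illustration only
        tags: list[str] | None

    results = [FakeResult(None), FakeResult(["user_a"]), FakeResult(["other"])]
    kept = filter_results_by_tags(results, ["user_a"], match="any_strict")
    # -> only FakeResult(["user_a"]): *_strict modes drop the untagged entry

Note the asymmetry the module encodes: the default "any"/"all" modes keep untagged memories (backward compatible), while the "*_strict" modes treat tags as a hard visibility scope.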
hindsight_api/engine/search/think_utils.py CHANGED

@@ -172,7 +172,7 @@ def get_system_message(disposition: DispositionTraits) -> str:
         " ".join(instructions) if instructions else "Balance your disposition traits when interpreting information."
     )

-    return f"You are a person with your own thoughts, experiences, opinions, and disposition. When asked a question, respond naturally like a human would - share what you genuinely think based on what you know and have experienced. {disposition_instruction} Be direct, express your views confidently, and use 'I think', 'I believe', 'in my view', etc. Respond in plain text without markdown formatting."
+    return f"You are a person with your own thoughts, experiences, opinions, and disposition. When asked a question, respond naturally like a human would - share what you genuinely think based on what you know and have experienced. {disposition_instruction} Be direct, express your views confidently, and use 'I think', 'I believe', 'in my view', etc. Respond in plain text without markdown formatting. IMPORTANT: Detect the language of the question and respond in the SAME language. Do not translate to English if the question is in another language."


 async def extract_opinions_from_text(llm_config, text: str, query: str) -> list[Opinion]:
hindsight_api/engine/search/trace.py CHANGED

@@ -11,6 +11,13 @@ from typing import Any, Literal
 from pydantic import BaseModel, Field


+class TemporalConstraint(BaseModel):
+    """Detected temporal constraint from query analysis."""
+
+    start: datetime | None = Field(default=None, description="Start of temporal range")
+    end: datetime | None = Field(default=None, description="End of temporal range")
+
+
 class QueryInfo(BaseModel):
     """Information about the search query."""

@@ -19,6 +26,11 @@ class QueryInfo(BaseModel):
     timestamp: datetime = Field(description="When the query was executed")
     budget: int = Field(description="Maximum nodes to explore")
     max_tokens: int = Field(description="Maximum tokens to return in results")
+    tags: list[str] | None = Field(default=None, description="Tags filter applied to recall")
+    tags_match: str | None = Field(default=None, description="Tags matching mode: any, all, any_strict, all_strict")
+    temporal_constraint: TemporalConstraint | None = Field(
+        default=None, description="Detected temporal range from query"
+    )


 class EntryPoint(BaseModel):
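A small sketch of exercising the new trace model (the import path follows the hunk above; model_dump assumes pydantic v2, which the codebase's Field usage suggests but this diff does not confirm):

    from datetime import UTC, datetime

    from hindsight_api.engine.search.trace import TemporalConstraint

    # An open-ended range such as "since January 2024" detected from a query
    tc = TemporalConstraint(start=datetime(2024, 1, 1, tzinfo=UTC), end=None)
    print(tc.model_dump())  # {'start': datetime(2024, 1, 1, ...), 'end': None}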
hindsight_api/engine/search/tracer.py CHANGED

@@ -22,6 +22,7 @@ from .trace import (
     SearchPhaseMetrics,
     SearchSummary,
     SearchTrace,
+    TemporalConstraint,
     WeightComponents,
 )

@@ -45,7 +46,14 @@ class SearchTracer:
         json_output = trace.to_json()
     """

-    def __init__(
+    def __init__(
+        self,
+        query: str,
+        budget: int,
+        max_tokens: int,
+        tags: list[str] | None = None,
+        tags_match: str | None = None,
+    ):
         """
         Initialize tracer.

@@ -53,10 +61,14 @@ class SearchTracer:
             query: Search query text
             budget: Maximum nodes to explore
             max_tokens: Maximum tokens to return in results
+            tags: Tags filter applied to recall
+            tags_match: Tags matching mode (any, all, any_strict, all_strict)
         """
         self.query_text = query
         self.budget = budget
         self.max_tokens = max_tokens
+        self.tags = tags
+        self.tags_match = tags_match

         # Trace data
         self.query_embedding: list[float] | None = None

@@ -66,6 +78,9 @@ class SearchTracer:
         self.pruned: list[PruningDecision] = []
         self.phase_metrics: list[SearchPhaseMetrics] = []

+        # Temporal constraint detected from query
+        self.temporal_constraint: TemporalConstraint | None = None
+
         # New 4-way retrieval tracking
         self.retrieval_results: list[RetrievalMethodResults] = []
         self.rrf_merged: list[RRFMergeResult] = []

@@ -88,6 +103,11 @@ class SearchTracer:
         """Record the query embedding."""
         self.query_embedding = embedding

+    def record_temporal_constraint(self, start: datetime | None, end: datetime | None):
+        """Record the detected temporal constraint from query analysis."""
+        if start is not None or end is not None:
+            self.temporal_constraint = TemporalConstraint(start=start, end=end)
+
     def add_entry_point(self, node_id: str, text: str, similarity: float, rank: int):
         """
         Record an entry point.

@@ -428,6 +448,9 @@ class SearchTracer:
             timestamp=datetime.now(UTC),
             budget=self.budget,
             max_tokens=self.max_tokens,
+            tags=self.tags,
+            tags_match=self.tags_match,
+            temporal_constraint=self.temporal_constraint,
         )

         # Create summary
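Putting the tracer changes together, a hedged sketch of the new call surface (argument values are illustrative; the keyword names come from the new __init__ and record_temporal_constraint signatures above):

    from datetime import UTC, datetime

    from hindsight_api.engine.search.tracer import SearchTracer

    tracer = SearchTracer(
        query="what changed since last week?",
        budget=200,
        max_tokens=4096,
        tags=["user_a"],
        tags_match="any_strict",
    )
    # Recorded only when the analyzer actually detected a range;
    # (None, None) leaves self.temporal_constraint unset.
    tracer.record_temporal_constraint(start=datetime(2025, 1, 1, tzinfo=UTC), end=None)

Both the tags filter and the detected range then flow into the QueryInfo block of the emitted trace, so a recall trace now records what visibility scope and time window were in effect.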
hindsight_api/engine/search/types.py CHANGED

@@ -10,6 +10,24 @@ from datetime import datetime
 from typing import Any


+@dataclass
+class MPFPTimings:
+    """Timing breakdown for a single MPFP retrieval call."""
+
+    fact_type: str
+    edge_count: int = 0  # Total edges loaded
+    db_queries: int = 0  # Number of DB queries for edge loading
+    edge_load_time: float = 0.0  # Time spent loading edges from DB
+    traverse: float = 0.0  # Total traversal time (includes edge loading)
+    pattern_count: int = 0  # Number of patterns executed
+    fusion: float = 0.0  # Time for RRF fusion
+    fetch: float = 0.0  # Time to fetch memory unit details
+    seeds_time: float = 0.0  # Time to find semantic seeds (if fallback used)
+    result_count: int = 0  # Number of results returned
+    # Detailed per-hop timing: list of {hop, exec_time, uncached, load_time, edges_loaded, total_time}
+    hop_details: list[dict] = field(default_factory=list)
+
+
 @dataclass
 class RetrievalResult:
     """

@@ -30,6 +48,7 @@ class RetrievalResult:
     chunk_id: str | None = None
     access_count: int = 0
     embedding: list[float] | None = None
+    tags: list[str] | None = None  # Visibility scope tags

     # Retrieval-specific scores (only one will be set depending on retrieval method)
     similarity: float | None = None  # Semantic retrieval

@@ -54,6 +73,7 @@ class RetrievalResult:
             chunk_id=row.get("chunk_id"),
             access_count=row.get("access_count", 0),
             embedding=row.get("embedding"),
+            tags=row.get("tags"),
             similarity=row.get("similarity"),
             bm25_score=row.get("bm25_score"),
             activation=row.get("activation"),

@@ -138,6 +158,7 @@ class ScoredResult:
             "chunk_id": self.retrieval.chunk_id,
             "access_count": self.retrieval.access_count,
             "embedding": self.retrieval.embedding,
+            "tags": self.retrieval.tags,
             "semantic_similarity": self.retrieval.similarity,
             "bm25_score": self.retrieval.bm25_score,
         }
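MPFPTimings is plain bookkeeping; a sketch of how a retrieval pass might fill it in. The field names and the hop-detail dict keys follow the dataclass comments above; the fact_type value, timing numbers, and the surrounding measurement code are illustrative assumptions:

    import time

    from hindsight_api.engine.search.types import MPFPTimings

    timings = MPFPTimings(fact_type="world")  # fact_type value is illustrative
    t0 = time.perf_counter()
    # ... execute one traversal hop against the graph here ...
    timings.traverse += time.perf_counter() - t0
    timings.db_queries += 1
    timings.hop_details.append(
        {"hop": 1, "exec_time": 0.004, "uncached": 3, "load_time": 0.002,
         "edges_loaded": 128, "total_time": 0.006}
    )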
hindsight_api/engine/task_backend.py CHANGED

@@ -121,6 +121,29 @@ class SyncTaskBackend(TaskBackend):
         logger.debug("SyncTaskBackend shutdown")


+class NoopTaskBackend(TaskBackend):
+    """
+    No-op task backend that discards all tasks.
+
+    This is useful for tests where background task execution is not needed
+    and would only slow down the test suite.
+    """
+
+    async def initialize(self):
+        """No-op."""
+        self._initialized = True
+        logger.debug("NoopTaskBackend initialized")
+
+    async def submit_task(self, task_dict: dict[str, Any]):
+        """Discard the task (do nothing)."""
+        pass
+
+    async def shutdown(self):
+        """No-op."""
+        self._initialized = False
+        logger.debug("NoopTaskBackend shutdown")
+
+
 class AsyncIOQueueBackend(TaskBackend):
     """
     Task backend implementation using asyncio queues.

@@ -129,7 +152,7 @@ class AsyncIOQueueBackend(TaskBackend):
     and a periodic consumer worker.
     """

-    def __init__(self, batch_size: int =
+    def __init__(self, batch_size: int = 10, batch_interval: float = 1.0):
         """
         Initialize AsyncIO queue backend.

@@ -143,6 +166,8 @@ class AsyncIOQueueBackend(TaskBackend):
         self._shutdown_event: asyncio.Event | None = None
         self._batch_size = batch_size
         self._batch_interval = batch_interval
+        self._in_flight_count = 0
+        self._in_flight_lock = asyncio.Lock()

     async def initialize(self):
         """Initialize the queue and start the worker."""

@@ -166,33 +191,31 @@ class AsyncIOQueueBackend(TaskBackend):
             await self.initialize()

         await self._queue.put(task_dict)
-        task_type = task_dict.get("type", "unknown")
-        task_id = task_dict.get("id")

-    async def wait_for_pending_tasks(self, timeout: float =
+    async def wait_for_pending_tasks(self, timeout: float = 120.0):
         """
-        Wait for all pending tasks in the queue to
+        Wait for all pending tasks in the queue and in-flight tasks to complete.

         This is useful in tests to ensure background tasks complete before assertions.

         Args:
-            timeout: Maximum time to wait in seconds
+            timeout: Maximum time to wait in seconds (default 120s for long-running tasks)
         """
         if not self._initialized or self._queue is None:
             return

-        # Wait for queue to be empty
+        # Wait for queue to be empty AND no in-flight tasks
         start_time = asyncio.get_event_loop().time()
         while asyncio.get_event_loop().time() - start_time < timeout:
-
-
-
-
-
-
-
-
-
+            async with self._in_flight_lock:
+                in_flight = self._in_flight_count
+
+            if self._queue.empty() and in_flight == 0:
+                # Queue is empty and no tasks in flight, we're done
+                return
+
+            # Wait a bit before checking again
+            await asyncio.sleep(0.5)

     async def shutdown(self):
         """Shutdown the worker and drain the queue."""

@@ -215,6 +238,39 @@ class AsyncIOQueueBackend(TaskBackend):
         self._initialized = False
         logger.info("AsyncIOQueueBackend shutdown complete")

+    async def _execute_task_with_tracking(self, task_dict: dict[str, Any]):
+        """Execute a task and track its in-flight status."""
+        async with self._in_flight_lock:
+            self._in_flight_count += 1
+        try:
+            await self._execute_task(task_dict)
+        finally:
+            async with self._in_flight_lock:
+                self._in_flight_count -= 1
+
+    async def _execute_task_no_tracking(self, task_dict: dict[str, Any]):
+        """Execute a task without in-flight tracking (tracking done at batch level)."""
+        await self._execute_task(task_dict)
+
+    def _get_queue_stats(self) -> tuple[int, dict[str, int]]:
+        """Get current queue size and bank_id distribution."""
+        queue_size = self._queue.qsize() if self._queue else 0
+        bank_distribution: dict[str, int] = {}
+
+        if queue_size > 0 and self._queue:
+            # Peek at queue items without removing them
+            # Note: This is a snapshot and may not be perfectly accurate due to concurrency
+            try:
+                # Access internal deque for logging purposes only
+                items = list(self._queue._queue)  # type: ignore[attr-defined]
+                for item in items:
+                    bank_id = item.get("bank_id", "unknown")
+                    bank_distribution[bank_id] = bank_distribution.get(bank_id, 0) + 1
+            except Exception:
+                pass  # Queue access failed, return empty distribution
+
+        return queue_size, bank_distribution
+
     async def _worker(self):
         """
         Background worker that processes tasks in batches.

@@ -232,17 +288,52 @@ class AsyncIOQueueBackend(TaskBackend):
            try:
                remaining_time = max(0.1, deadline - asyncio.get_event_loop().time())
                task_dict = await asyncio.wait_for(self._queue.get(), timeout=remaining_time)
+                # Track task as in-flight immediately when picked up from queue
+                # This prevents wait_for_pending_tasks from returning too early
+                async with self._in_flight_lock:
+                    self._in_flight_count += 1
                tasks.append(task_dict)
            except TimeoutError:
                break

            # Process batch
            if tasks:
-                #
+                # Log batch start with queue stats
+                queue_size, bank_distribution = self._get_queue_stats()
+
+                # Summarize batch by task type and bank
+                batch_summary: dict[str, dict[str, int]] = {}
+                for task_dict in tasks:
+                    task_type = task_dict.get("type", "unknown")
+                    bank_id = task_dict.get("bank_id", "unknown")
+                    if task_type not in batch_summary:
+                        batch_summary[task_type] = {}
+                    batch_summary[task_type][bank_id] = batch_summary[task_type].get(bank_id, 0) + 1
+
+                # Build log message
+                batch_parts = []
+                for task_type, banks in sorted(batch_summary.items()):
+                    bank_str = ", ".join(f"{b}:{c}" for b, c in sorted(banks.items()))
+                    batch_parts.append(f"{task_type}[{bank_str}]")
+                batch_str = ", ".join(batch_parts)
+
+                if queue_size > 0:
+                    pending_str = ", ".join(f"{k}:{v}" for k, v in sorted(bank_distribution.items()))
+                    logger.info(
+                        f"Processing {len(tasks)} tasks: {batch_str} (pending={queue_size} [{pending_str}])"
+                    )
+                else:
+                    logger.info(f"Processing {len(tasks)} tasks: {batch_str}")
+
+                # Execute tasks concurrently (in_flight already tracked when picked up)
                await asyncio.gather(
-                    *[self.
+                    *[self._execute_task_no_tracking(task_dict) for task_dict in tasks], return_exceptions=True
                )

+                # Decrement in_flight count after all tasks complete
+                async with self._in_flight_lock:
+                    self._in_flight_count -= len(tasks)
+
        except asyncio.CancelledError:
            break
        except Exception as e:
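The core of the task-backend change is the in-flight counter: a task is counted the moment it leaves the queue, so a waiter that previously only checked queue emptiness can no longer return while a batch is still executing. A self-contained sketch of the pattern (simplified, not the actual backend):

    import asyncio


    async def demo() -> None:
        queue: asyncio.Queue[int] = asyncio.Queue()
        in_flight = 0
        lock = asyncio.Lock()

        async def worker() -> None:
            nonlocal in_flight
            while True:
                _item = await queue.get()
                async with lock:
                    in_flight += 1  # counted as soon as the item leaves the queue
                await asyncio.sleep(0.1)  # simulate task execution
                async with lock:
                    in_flight -= 1

        async def wait_for_pending(timeout: float = 5.0) -> None:
            deadline = asyncio.get_event_loop().time() + timeout
            while asyncio.get_event_loop().time() < deadline:
                async with lock:
                    pending = in_flight
                if queue.empty() and pending == 0:
                    return  # nothing queued, nothing executing
                await asyncio.sleep(0.05)

        task = asyncio.create_task(worker())
        for i in range(3):
            queue.put_nowait(i)
        await wait_for_pending()
        task.cancel()


    asyncio.run(demo())

Without the counter, queue.empty() becomes true as soon as the worker drains the queue into its batch list, and a test's wait would race ahead of the still-running tasks; that is exactly the window the new increment-on-dequeue closes.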
hindsight_api/extensions/context.py CHANGED

@@ -96,7 +96,7 @@ class DefaultExtensionContext(ExtensionContext):

     async def run_migration(self, schema: str) -> None:
         """Run migrations for a specific schema."""
-        from hindsight_api.migrations import run_migrations
+        from hindsight_api.migrations import ensure_embedding_dimension, run_migrations

         # Prefer getting URL from memory engine (handles pg0 case where URL is set after init)
         db_url = self._database_url

@@ -107,6 +107,15 @@ class DefaultExtensionContext(ExtensionContext):

         run_migrations(db_url, schema=schema)

+        # Ensure embedding column dimension matches the model's dimension
+        # This is needed because migrations create columns with default dimension
+        if self._memory_engine is not None:
+            embeddings = getattr(self._memory_engine, "embeddings", None)
+            if embeddings is not None:
+                dimension = getattr(embeddings, "dimension", None)
+                if dimension is not None:
+                    ensure_embedding_dimension(db_url, dimension, schema=schema)
+
     def get_memory_engine(self) -> "MemoryEngineInterface":
         """Get the memory engine interface."""
         if self._memory_engine is None:
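ensure_embedding_dimension itself is not shown in this diff; for context, a hedged sketch of what such a helper typically does with pgvector. Everything here is an assumption: the driver (psycopg), the memory_units table name (taken from the tags.py docstring example), and the destructive re-type; the real implementation lives in hindsight_api.migrations and may differ entirely.

    import psycopg  # assumption: any synchronous Postgres driver works similarly


    def ensure_embedding_dimension_sketch(db_url: str, dimension: int, schema: str = "public") -> None:
        """Recreate the embedding column type if its pgvector dimension differs."""
        with psycopg.connect(db_url) as conn, conn.cursor() as cur:
            # For pgvector columns, atttypmod stores the declared dimension
            cur.execute(
                "SELECT atttypmod FROM pg_attribute "
                "WHERE attrelid = %s::regclass AND attname = 'embedding'",
                (f"{schema}.memory_units",),
            )
            row = cur.fetchone()
            if row and row[0] != dimension:
                # NOTE: discards existing embeddings; real code would re-embed or guard this
                cur.execute(
                    f"ALTER TABLE {schema}.memory_units "
                    f"ALTER COLUMN embedding TYPE vector({dimension}) "
                    f"USING NULL::vector({dimension})"
                )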
hindsight_api/main.py CHANGED

@@ -23,7 +23,7 @@ import uvicorn
 from . import MemoryEngine
 from .api import create_app
 from .banner import print_banner
-from .config import HindsightConfig, get_config
+from .config import DEFAULT_WORKERS, ENV_WORKERS, HindsightConfig, get_config
 from .daemon import (
     DEFAULT_DAEMON_PORT,
     DEFAULT_IDLE_TIMEOUT,

@@ -95,7 +95,12 @@ def main():

     # Development options
     parser.add_argument("--reload", action="store_true", help="Enable auto-reload on code changes (development only)")
-    parser.add_argument(
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=int(os.getenv(ENV_WORKERS, str(DEFAULT_WORKERS))),
+        help=f"Number of worker processes (env: {ENV_WORKERS}, default: {DEFAULT_WORKERS})",
+    )

     # Access log options
     parser.add_argument("--access-log", action="store_true", help="Enable access log")

@@ -171,21 +176,51 @@ def main():
        llm_base_url=config.llm_base_url,
        llm_max_concurrent=config.llm_max_concurrent,
        llm_timeout=config.llm_timeout,
+        retain_llm_provider=config.retain_llm_provider,
+        retain_llm_api_key=config.retain_llm_api_key,
+        retain_llm_model=config.retain_llm_model,
+        retain_llm_base_url=config.retain_llm_base_url,
+        reflect_llm_provider=config.reflect_llm_provider,
+        reflect_llm_api_key=config.reflect_llm_api_key,
+        reflect_llm_model=config.reflect_llm_model,
+        reflect_llm_base_url=config.reflect_llm_base_url,
        embeddings_provider=config.embeddings_provider,
        embeddings_local_model=config.embeddings_local_model,
        embeddings_tei_url=config.embeddings_tei_url,
+        embeddings_openai_base_url=config.embeddings_openai_base_url,
+        embeddings_cohere_base_url=config.embeddings_cohere_base_url,
        reranker_provider=config.reranker_provider,
        reranker_local_model=config.reranker_local_model,
        reranker_tei_url=config.reranker_tei_url,
+        reranker_tei_batch_size=config.reranker_tei_batch_size,
+        reranker_tei_max_concurrent=config.reranker_tei_max_concurrent,
+        reranker_max_candidates=config.reranker_max_candidates,
+        reranker_cohere_base_url=config.reranker_cohere_base_url,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        mcp_enabled=config.mcp_enabled,
        graph_retriever=config.graph_retriever,
+        mpfp_top_k_neighbors=config.mpfp_top_k_neighbors,
+        recall_max_concurrent=config.recall_max_concurrent,
+        recall_connection_budget=config.recall_connection_budget,
        observation_min_facts=config.observation_min_facts,
        observation_top_entities=config.observation_top_entities,
+        retain_max_completion_tokens=config.retain_max_completion_tokens,
+        retain_chunk_size=config.retain_chunk_size,
+        retain_extract_causal_links=config.retain_extract_causal_links,
+        retain_extraction_mode=config.retain_extraction_mode,
+        retain_observations_async=config.retain_observations_async,
        skip_llm_verification=config.skip_llm_verification,
        lazy_reranker=config.lazy_reranker,
+        run_migrations_on_startup=config.run_migrations_on_startup,
+        db_pool_min_size=config.db_pool_min_size,
+        db_pool_max_size=config.db_pool_max_size,
+        db_command_timeout=config.db_command_timeout,
+        db_acquire_timeout=config.db_acquire_timeout,
+        task_backend=config.task_backend,
+        task_backend_memory_batch_size=config.task_backend_memory_batch_size,
+        task_backend_memory_batch_interval=config.task_backend_memory_batch_interval,
    )
    config.configure_logging()
    if not args.daemon:

@@ -211,7 +246,11 @@ def main():
        logging.info(f"Loaded tenant extension: {tenant_extension.__class__.__name__}")

    # Create MemoryEngine (reads configuration from environment)
-    _memory = MemoryEngine(
+    _memory = MemoryEngine(
+        operation_validator=operation_validator,
+        tenant_extension=tenant_extension,
+        run_migrations=config.run_migrations_on_startup,
+    )

    # Set extension context on tenant extension (needed for schema provisioning)
    if tenant_extension:

@@ -238,14 +277,27 @@ def main():
        app = idle_middleware

    # Prepare uvicorn config
+    # When using workers or reload, we must use import string so each worker can import the app
+    use_import_string = args.workers > 1 or args.reload
+    # Check for uvloop availability
+    try:
+        import uvloop  # noqa: F401
+
+        loop_impl = "uvloop"
+        print("uvloop available, will use for event loop")
+    except ImportError:
+        loop_impl = "asyncio"
+        print("uvloop not installed, using default asyncio event loop")
+
    uvicorn_config = {
-        "app": app,
+        "app": "hindsight_api.server:app" if use_import_string else app,
        "host": args.host,
        "port": args.port,
        "log_level": args.log_level,
        "access_log": args.access_log,
        "proxy_headers": args.proxy_headers,
        "ws": "wsproto",  # Use wsproto instead of websockets to avoid deprecation warnings
+        "loop": loop_impl,  # Explicitly set event loop implementation
    }

    # Add optional parameters if provided