hindsight-api 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,172 @@
+ """
+ Tags filtering utilities for retrieval.
+
+ Provides SQL building functions for filtering memories by tags.
+ Supports four matching modes via TagsMatch enum:
+ - "any": OR matching, includes untagged memories (default, backward compatible)
+ - "all": AND matching, includes untagged memories
+ - "any_strict": OR matching, excludes untagged memories
+ - "all_strict": AND matching, excludes untagged memories
+
+ OR matching (any/any_strict): Memory matches if ANY of its tags overlap with request tags
+ AND matching (all/all_strict): Memory matches if ALL request tags are present in its tags
+ """
+
+ from typing import Literal
+
+ TagsMatch = Literal["any", "all", "any_strict", "all_strict"]
+
+
+ def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]:
+     """
+     Parse TagsMatch into operator and include_untagged flag.
+
+     Returns:
+         Tuple of (operator, include_untagged)
+         - operator: "&&" for any/any_strict, "@>" for all/all_strict
+         - include_untagged: True for any/all, False for any_strict/all_strict
+     """
+     if match == "any":
+         return "&&", True
+     elif match == "all":
+         return "@>", True
+     elif match == "any_strict":
+         return "&&", False
+     elif match == "all_strict":
+         return "@>", False
+     else:
+         # Default to "any" behavior
+         return "&&", True
+
+
+ def build_tags_where_clause(
+     tags: list[str] | None,
+     param_offset: int = 1,
+     table_alias: str = "",
+     match: TagsMatch = "any",
+ ) -> tuple[str, list, int]:
+     """
+     Build a SQL WHERE clause for filtering by tags.
+
+     Supports four matching modes:
+     - "any" (default): OR matching, includes untagged memories
+     - "all": AND matching, includes untagged memories
+     - "any_strict": OR matching, excludes untagged memories
+     - "all_strict": AND matching, excludes untagged memories
+
+     Args:
+         tags: List of tags to filter by. If None or empty, returns empty clause (no filtering).
+         param_offset: Starting parameter number for SQL placeholders (default 1).
+         table_alias: Optional table alias prefix (e.g., "mu." for "memory_units mu").
+         match: Matching mode. Defaults to "any".
+
+     Returns:
+         Tuple of (sql_clause, params, next_param_offset):
+         - sql_clause: SQL WHERE clause string
+         - params: List of parameter values to bind
+         - next_param_offset: Next available parameter number
+
+     Example:
+         >>> clause, params, next_offset = build_tags_where_clause(['user_a'], 3, 'mu.', 'any_strict')
+         >>> print(clause)  # "AND mu.tags IS NOT NULL AND mu.tags != '{}' AND mu.tags && $3"
+     """
+     if not tags:
+         return "", [], param_offset
+
+     column = f"{table_alias}tags" if table_alias else "tags"
+     operator, include_untagged = _parse_tags_match(match)
+
+     if include_untagged:
+         # Include untagged memories (NULL or empty array) OR matching tags
+         clause = f"AND ({column} IS NULL OR {column} = '{{}}' OR {column} {operator} ${param_offset})"
+     else:
+         # Strict: only memories with matching tags (exclude NULL and empty)
+         clause = f"AND {column} IS NOT NULL AND {column} != '{{}}' AND {column} {operator} ${param_offset}"
+
+     return clause, [tags], param_offset + 1
+
+
+ def build_tags_where_clause_simple(
+     tags: list[str] | None,
+     param_num: int,
+     table_alias: str = "",
+     match: TagsMatch = "any",
+ ) -> str:
+     """
+     Build a simple SQL WHERE clause for tags filtering.
+
+     This is a convenience version that returns just the clause string,
+     assuming the caller will add the tags array to their params list.
+
+     Args:
+         tags: List of tags to filter by. If None or empty, returns empty string.
+         param_num: Parameter number to use in the clause.
+         table_alias: Optional table alias prefix.
+         match: Matching mode. Defaults to "any".
+
+     Returns:
+         SQL clause string or empty string.
+     """
+     if not tags:
+         return ""
+
+     column = f"{table_alias}tags" if table_alias else "tags"
+     operator, include_untagged = _parse_tags_match(match)
+
+     if include_untagged:
+         # Include untagged memories (NULL or empty array) OR matching tags
+         return f"AND ({column} IS NULL OR {column} = '{{}}' OR {column} {operator} ${param_num})"
+     else:
+         # Strict: only memories with matching tags (exclude NULL and empty)
+         return f"AND {column} IS NOT NULL AND {column} != '{{}}' AND {column} {operator} ${param_num}"
+
+
+ def filter_results_by_tags(
+     results: list,
+     tags: list[str] | None,
+     match: TagsMatch = "any",
+ ) -> list:
+     """
+     Filter retrieval results by tags in Python (for post-processing).
+
+     Used when SQL filtering isn't possible (e.g., graph traversal results).
+
+     Args:
+         results: List of RetrievalResult objects with a 'tags' attribute.
+         tags: List of tags to filter by. If None or empty, returns all results.
+         match: Matching mode. Defaults to "any".
+
+     Returns:
+         Filtered list of results.
+     """
+     if not tags:
+         return results
+
+     _, include_untagged = _parse_tags_match(match)
+     is_any_match = match in ("any", "any_strict")
+
+     tags_set = set(tags)
+     filtered = []
+
+     for result in results:
+         result_tags = getattr(result, "tags", None)
+
+         # Check if untagged
+         is_untagged = result_tags is None or len(result_tags) == 0
+
+         if is_untagged:
+             if include_untagged:
+                 filtered.append(result)
+             # else: skip untagged
+         else:
+             result_tags_set = set(result_tags)
+             if is_any_match:
+                 # Any overlap
+                 if result_tags_set & tags_set:
+                     filtered.append(result)
+             else:
+                 # All tags must be present
+                 if tags_set <= result_tags_set:
+                     filtered.append(result)
+
+     return filtered
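
For reference, a minimal usage sketch of the new tag helpers (illustrative only — the Memory stand-in class and the example tag values are not part of the package):

    from hindsight_api.engine.search.tags import build_tags_where_clause, filter_results_by_tags

    # SQL path: build a clause with PostgreSQL-style $N placeholders.
    clause, params, next_param = build_tags_where_clause(["user_a"], param_offset=2, table_alias="mu.")
    # clause     -> "AND (mu.tags IS NULL OR mu.tags = '{}' OR mu.tags && $2)"
    # params     -> [["user_a"]]
    # next_param -> 3

    # Python path: post-filter results that did not go through SQL (e.g., graph traversal).
    class Memory:  # stand-in for RetrievalResult; only a `tags` attribute is required
        def __init__(self, text, tags=None):
            self.text, self.tags = text, tags

    results = [Memory("untagged"), Memory("mine", ["user_a"]), Memory("other", ["user_b"])]
    kept = filter_results_by_tags(results, ["user_a"], match="any_strict")
    print([m.text for m in kept])  # ['mine'] -- strict mode also drops the untagged memory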
@@ -172,7 +172,7 @@ def get_system_message(disposition: DispositionTraits) -> str:
          " ".join(instructions) if instructions else "Balance your disposition traits when interpreting information."
      )
 
-     return f"You are a person with your own thoughts, experiences, opinions, and disposition. When asked a question, respond naturally like a human would - share what you genuinely think based on what you know and have experienced. {disposition_instruction} Be direct, express your views confidently, and use 'I think', 'I believe', 'in my view', etc. Respond in plain text without markdown formatting."
+     return f"You are a person with your own thoughts, experiences, opinions, and disposition. When asked a question, respond naturally like a human would - share what you genuinely think based on what you know and have experienced. {disposition_instruction} Be direct, express your views confidently, and use 'I think', 'I believe', 'in my view', etc. Respond in plain text without markdown formatting. IMPORTANT: Detect the language of the question and respond in the SAME language. Do not translate to English if the question is in another language."
 
 
 
 async def extract_opinions_from_text(llm_config, text: str, query: str) -> list[Opinion]:
@@ -11,6 +11,13 @@ from typing import Any, Literal
 from pydantic import BaseModel, Field
 
 
+ class TemporalConstraint(BaseModel):
+     """Detected temporal constraint from query analysis."""
+
+     start: datetime | None = Field(default=None, description="Start of temporal range")
+     end: datetime | None = Field(default=None, description="End of temporal range")
+
+
 class QueryInfo(BaseModel):
     """Information about the search query."""
 
@@ -19,6 +26,11 @@ class QueryInfo(BaseModel):
     timestamp: datetime = Field(description="When the query was executed")
     budget: int = Field(description="Maximum nodes to explore")
     max_tokens: int = Field(description="Maximum tokens to return in results")
+     tags: list[str] | None = Field(default=None, description="Tags filter applied to recall")
+     tags_match: str | None = Field(default=None, description="Tags matching mode: any, all, any_strict, all_strict")
+     temporal_constraint: TemporalConstraint | None = Field(
+         default=None, description="Detected temporal range from query"
+     )
 
 
 class EntryPoint(BaseModel):
@@ -22,6 +22,7 @@ from .trace import (
     SearchPhaseMetrics,
     SearchSummary,
     SearchTrace,
+     TemporalConstraint,
     WeightComponents,
 )
 
@@ -45,7 +46,14 @@ class SearchTracer:
         json_output = trace.to_json()
     """
 
-     def __init__(self, query: str, budget: int, max_tokens: int):
+     def __init__(
+         self,
+         query: str,
+         budget: int,
+         max_tokens: int,
+         tags: list[str] | None = None,
+         tags_match: str | None = None,
+     ):
         """
         Initialize tracer.
 
@@ -53,10 +61,14 @@
             query: Search query text
             budget: Maximum nodes to explore
             max_tokens: Maximum tokens to return in results
+             tags: Tags filter applied to recall
+             tags_match: Tags matching mode (any, all, any_strict, all_strict)
         """
         self.query_text = query
         self.budget = budget
         self.max_tokens = max_tokens
+         self.tags = tags
+         self.tags_match = tags_match
 
         # Trace data
         self.query_embedding: list[float] | None = None
@@ -66,6 +78,9 @@
         self.pruned: list[PruningDecision] = []
         self.phase_metrics: list[SearchPhaseMetrics] = []
 
+         # Temporal constraint detected from query
+         self.temporal_constraint: TemporalConstraint | None = None
+
         # New 4-way retrieval tracking
         self.retrieval_results: list[RetrievalMethodResults] = []
         self.rrf_merged: list[RRFMergeResult] = []
@@ -88,6 +103,11 @@
         """Record the query embedding."""
         self.query_embedding = embedding
 
+     def record_temporal_constraint(self, start: datetime | None, end: datetime | None):
+         """Record the detected temporal constraint from query analysis."""
+         if start is not None or end is not None:
+             self.temporal_constraint = TemporalConstraint(start=start, end=end)
+
     def add_entry_point(self, node_id: str, text: str, similarity: float, rank: int):
         """
         Record an entry point.
@@ -428,6 +448,9 @@
             timestamp=datetime.now(UTC),
             budget=self.budget,
             max_tokens=self.max_tokens,
+             tags=self.tags,
+             tags_match=self.tags_match,
+             temporal_constraint=self.temporal_constraint,
         )
 
         # Create summary
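
A short sketch of how the extended tracer could be driven (illustrative; the query, tag values, and dates are invented, and the surrounding recall pipeline is not shown in this diff):

    from datetime import UTC, datetime

    from hindsight_api.engine.search.tracer import SearchTracer

    tracer = SearchTracer(
        query="what did we ship last quarter?",
        budget=200,
        max_tokens=2048,
        tags=["team_alpha"],
        tags_match="any_strict",
    )
    # If query analysis detects a date range, it is now carried into the trace output.
    tracer.record_temporal_constraint(
        start=datetime(2024, 10, 1, tzinfo=UTC),
        end=datetime(2024, 12, 31, tzinfo=UTC),
    )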
@@ -10,6 +10,24 @@ from datetime import datetime
 from typing import Any
 
 
+ @dataclass
+ class MPFPTimings:
+     """Timing breakdown for a single MPFP retrieval call."""
+
+     fact_type: str
+     edge_count: int = 0  # Total edges loaded
+     db_queries: int = 0  # Number of DB queries for edge loading
+     edge_load_time: float = 0.0  # Time spent loading edges from DB
+     traverse: float = 0.0  # Total traversal time (includes edge loading)
+     pattern_count: int = 0  # Number of patterns executed
+     fusion: float = 0.0  # Time for RRF fusion
+     fetch: float = 0.0  # Time to fetch memory unit details
+     seeds_time: float = 0.0  # Time to find semantic seeds (if fallback used)
+     result_count: int = 0  # Number of results returned
+     # Detailed per-hop timing: list of {hop, exec_time, uncached, load_time, edges_loaded, total_time}
+     hop_details: list[dict] = field(default_factory=list)
+
+
 @dataclass
 class RetrievalResult:
     """
@@ -30,6 +48,7 @@ class RetrievalResult:
     chunk_id: str | None = None
     access_count: int = 0
     embedding: list[float] | None = None
+     tags: list[str] | None = None  # Visibility scope tags
 
     # Retrieval-specific scores (only one will be set depending on retrieval method)
     similarity: float | None = None  # Semantic retrieval
@@ -54,6 +73,7 @@
             chunk_id=row.get("chunk_id"),
             access_count=row.get("access_count", 0),
             embedding=row.get("embedding"),
+             tags=row.get("tags"),
             similarity=row.get("similarity"),
             bm25_score=row.get("bm25_score"),
             activation=row.get("activation"),
@@ -138,6 +158,7 @@ class ScoredResult:
             "chunk_id": self.retrieval.chunk_id,
             "access_count": self.retrieval.access_count,
             "embedding": self.retrieval.embedding,
+             "tags": self.retrieval.tags,
             "semantic_similarity": self.retrieval.similarity,
             "bm25_score": self.retrieval.bm25_score,
         }
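
MPFPTimings is a plain dataclass, so a retrieval call can fill it incrementally; a hypothetical populated record might look like this (all values below are invented for illustration):

    from hindsight_api.engine.search.types import MPFPTimings

    timings = MPFPTimings(fact_type="semantic")
    timings.db_queries = 3
    timings.edge_count = 1250
    timings.edge_load_time = 0.042
    timings.traverse = 0.118
    timings.hop_details.append(
        {"hop": 1, "exec_time": 0.03, "uncached": 420, "load_time": 0.02, "edges_loaded": 420, "total_time": 0.05}
    )
    timings.result_count = 25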
@@ -121,6 +121,29 @@ class SyncTaskBackend(TaskBackend):
         logger.debug("SyncTaskBackend shutdown")
 
 
+ class NoopTaskBackend(TaskBackend):
+     """
+     No-op task backend that discards all tasks.
+
+     This is useful for tests where background task execution is not needed
+     and would only slow down the test suite.
+     """
+
+     async def initialize(self):
+         """No-op."""
+         self._initialized = True
+         logger.debug("NoopTaskBackend initialized")
+
+     async def submit_task(self, task_dict: dict[str, Any]):
+         """Discard the task (do nothing)."""
+         pass
+
+     async def shutdown(self):
+         """No-op."""
+         self._initialized = False
+         logger.debug("NoopTaskBackend shutdown")
+
+
 class AsyncIOQueueBackend(TaskBackend):
     """
     Task backend implementation using asyncio queues.
@@ -129,7 +152,7 @@ class AsyncIOQueueBackend(TaskBackend):
     and a periodic consumer worker.
     """
 
-     def __init__(self, batch_size: int = 100, batch_interval: float = 1.0):
+     def __init__(self, batch_size: int = 10, batch_interval: float = 1.0):
         """
         Initialize AsyncIO queue backend.
 
@@ -143,6 +166,8 @@
         self._shutdown_event: asyncio.Event | None = None
         self._batch_size = batch_size
         self._batch_interval = batch_interval
+         self._in_flight_count = 0
+         self._in_flight_lock = asyncio.Lock()
 
     async def initialize(self):
         """Initialize the queue and start the worker."""
@@ -166,33 +191,31 @@
             await self.initialize()
 
         await self._queue.put(task_dict)
-         task_type = task_dict.get("type", "unknown")
-         task_id = task_dict.get("id")
 
-     async def wait_for_pending_tasks(self, timeout: float = 5.0):
+     async def wait_for_pending_tasks(self, timeout: float = 120.0):
         """
-         Wait for all pending tasks in the queue to be processed.
+         Wait for all pending tasks in the queue and in-flight tasks to complete.
 
         This is useful in tests to ensure background tasks complete before assertions.
 
         Args:
-             timeout: Maximum time to wait in seconds
+             timeout: Maximum time to wait in seconds (default 120s for long-running tasks)
         """
         if not self._initialized or self._queue is None:
             return
 
-         # Wait for queue to be empty and give worker time to process
+         # Wait for queue to be empty AND no in-flight tasks
         start_time = asyncio.get_event_loop().time()
         while asyncio.get_event_loop().time() - start_time < timeout:
-             if self._queue.empty():
-                 # Queue is empty, give worker a bit more time to finish any in-flight task
-                 await asyncio.sleep(0.3)
-                 # Check again - if still empty, we're done
-                 if self._queue.empty():
-                     return
-             else:
-                 # Queue not empty, wait a bit
-                 await asyncio.sleep(0.1)
+             async with self._in_flight_lock:
+                 in_flight = self._in_flight_count
+
+             if self._queue.empty() and in_flight == 0:
+                 # Queue is empty and no tasks in flight, we're done
+                 return
+
+             # Wait a bit before checking again
+             await asyncio.sleep(0.5)
 
     async def shutdown(self):
         """Shutdown the worker and drain the queue."""
@@ -215,6 +238,39 @@
         self._initialized = False
         logger.info("AsyncIOQueueBackend shutdown complete")
 
+     async def _execute_task_with_tracking(self, task_dict: dict[str, Any]):
+         """Execute a task and track its in-flight status."""
+         async with self._in_flight_lock:
+             self._in_flight_count += 1
+         try:
+             await self._execute_task(task_dict)
+         finally:
+             async with self._in_flight_lock:
+                 self._in_flight_count -= 1
+
+     async def _execute_task_no_tracking(self, task_dict: dict[str, Any]):
+         """Execute a task without in-flight tracking (tracking done at batch level)."""
+         await self._execute_task(task_dict)
+
+     def _get_queue_stats(self) -> tuple[int, dict[str, int]]:
+         """Get current queue size and bank_id distribution."""
+         queue_size = self._queue.qsize() if self._queue else 0
+         bank_distribution: dict[str, int] = {}
+
+         if queue_size > 0 and self._queue:
+             # Peek at queue items without removing them
+             # Note: This is a snapshot and may not be perfectly accurate due to concurrency
+             try:
+                 # Access internal deque for logging purposes only
+                 items = list(self._queue._queue)  # type: ignore[attr-defined]
+                 for item in items:
+                     bank_id = item.get("bank_id", "unknown")
+                     bank_distribution[bank_id] = bank_distribution.get(bank_id, 0) + 1
+             except Exception:
+                 pass  # Queue access failed, return empty distribution
+
+         return queue_size, bank_distribution
+
     async def _worker(self):
         """
         Background worker that processes tasks in batches.
@@ -232,17 +288,52 @@
                 try:
                     remaining_time = max(0.1, deadline - asyncio.get_event_loop().time())
                     task_dict = await asyncio.wait_for(self._queue.get(), timeout=remaining_time)
+                     # Track task as in-flight immediately when picked up from queue
+                     # This prevents wait_for_pending_tasks from returning too early
+                     async with self._in_flight_lock:
+                         self._in_flight_count += 1
                     tasks.append(task_dict)
                 except TimeoutError:
                     break
 
             # Process batch
             if tasks:
-                 # Execute tasks concurrently
+                 # Log batch start with queue stats
+                 queue_size, bank_distribution = self._get_queue_stats()
+
+                 # Summarize batch by task type and bank
+                 batch_summary: dict[str, dict[str, int]] = {}
+                 for task_dict in tasks:
+                     task_type = task_dict.get("type", "unknown")
+                     bank_id = task_dict.get("bank_id", "unknown")
+                     if task_type not in batch_summary:
+                         batch_summary[task_type] = {}
+                     batch_summary[task_type][bank_id] = batch_summary[task_type].get(bank_id, 0) + 1
+
+                 # Build log message
+                 batch_parts = []
+                 for task_type, banks in sorted(batch_summary.items()):
+                     bank_str = ", ".join(f"{b}:{c}" for b, c in sorted(banks.items()))
+                     batch_parts.append(f"{task_type}[{bank_str}]")
+                 batch_str = ", ".join(batch_parts)
+
+                 if queue_size > 0:
+                     pending_str = ", ".join(f"{k}:{v}" for k, v in sorted(bank_distribution.items()))
+                     logger.info(
+                         f"Processing {len(tasks)} tasks: {batch_str} (pending={queue_size} [{pending_str}])"
+                     )
+                 else:
+                     logger.info(f"Processing {len(tasks)} tasks: {batch_str}")
+
+                 # Execute tasks concurrently (in_flight already tracked when picked up)
                 await asyncio.gather(
-                     *[self._execute_task(task_dict) for task_dict in tasks], return_exceptions=True
+                     *[self._execute_task_no_tracking(task_dict) for task_dict in tasks], return_exceptions=True
                 )
 
+                 # Decrement in_flight count after all tasks complete
+                 async with self._in_flight_lock:
+                     self._in_flight_count -= len(tasks)
+
         except asyncio.CancelledError:
             break
@@ -49,7 +49,7 @@ async def extract_facts(
     if not text or not text.strip():
         return [], []
 
-     facts, chunks = await extract_facts_from_text(
+     facts, chunks, _ = await extract_facts_from_text(
         text,
         event_date,
         context=context,
@@ -96,7 +96,7 @@ class DefaultExtensionContext(ExtensionContext):
 
     async def run_migration(self, schema: str) -> None:
         """Run migrations for a specific schema."""
-         from hindsight_api.migrations import run_migrations
+         from hindsight_api.migrations import ensure_embedding_dimension, run_migrations
 
         # Prefer getting URL from memory engine (handles pg0 case where URL is set after init)
         db_url = self._database_url
@@ -107,6 +107,15 @@
 
         run_migrations(db_url, schema=schema)
 
+         # Ensure embedding column dimension matches the model's dimension
+         # This is needed because migrations create columns with default dimension
+         if self._memory_engine is not None:
+             embeddings = getattr(self._memory_engine, "embeddings", None)
+             if embeddings is not None:
+                 dimension = getattr(embeddings, "dimension", None)
+                 if dimension is not None:
+                     ensure_embedding_dimension(db_url, dimension, schema=schema)
+
     def get_memory_engine(self) -> "MemoryEngineInterface":
         """Get the memory engine interface."""
         if self._memory_engine is None:
hindsight_api/main.py CHANGED
@@ -23,7 +23,7 @@ import uvicorn
 from . import MemoryEngine
 from .api import create_app
 from .banner import print_banner
- from .config import HindsightConfig, get_config
+ from .config import DEFAULT_WORKERS, ENV_WORKERS, HindsightConfig, get_config
 from .daemon import (
     DEFAULT_DAEMON_PORT,
     DEFAULT_IDLE_TIMEOUT,
@@ -95,7 +95,12 @@ def main():
 
     # Development options
     parser.add_argument("--reload", action="store_true", help="Enable auto-reload on code changes (development only)")
-     parser.add_argument("--workers", type=int, default=1, help="Number of worker processes (default: 1)")
+     parser.add_argument(
+         "--workers",
+         type=int,
+         default=int(os.getenv(ENV_WORKERS, str(DEFAULT_WORKERS))),
+         help=f"Number of worker processes (env: {ENV_WORKERS}, default: {DEFAULT_WORKERS})",
+     )
 
     # Access log options
     parser.add_argument("--access-log", action="store_true", help="Enable access log")
@@ -171,21 +176,51 @@
         llm_base_url=config.llm_base_url,
         llm_max_concurrent=config.llm_max_concurrent,
         llm_timeout=config.llm_timeout,
+         retain_llm_provider=config.retain_llm_provider,
+         retain_llm_api_key=config.retain_llm_api_key,
+         retain_llm_model=config.retain_llm_model,
+         retain_llm_base_url=config.retain_llm_base_url,
+         reflect_llm_provider=config.reflect_llm_provider,
+         reflect_llm_api_key=config.reflect_llm_api_key,
+         reflect_llm_model=config.reflect_llm_model,
+         reflect_llm_base_url=config.reflect_llm_base_url,
         embeddings_provider=config.embeddings_provider,
         embeddings_local_model=config.embeddings_local_model,
         embeddings_tei_url=config.embeddings_tei_url,
+         embeddings_openai_base_url=config.embeddings_openai_base_url,
+         embeddings_cohere_base_url=config.embeddings_cohere_base_url,
         reranker_provider=config.reranker_provider,
         reranker_local_model=config.reranker_local_model,
         reranker_tei_url=config.reranker_tei_url,
+         reranker_tei_batch_size=config.reranker_tei_batch_size,
+         reranker_tei_max_concurrent=config.reranker_tei_max_concurrent,
+         reranker_max_candidates=config.reranker_max_candidates,
+         reranker_cohere_base_url=config.reranker_cohere_base_url,
         host=args.host,
         port=args.port,
         log_level=args.log_level,
         mcp_enabled=config.mcp_enabled,
         graph_retriever=config.graph_retriever,
+         mpfp_top_k_neighbors=config.mpfp_top_k_neighbors,
+         recall_max_concurrent=config.recall_max_concurrent,
+         recall_connection_budget=config.recall_connection_budget,
         observation_min_facts=config.observation_min_facts,
         observation_top_entities=config.observation_top_entities,
+         retain_max_completion_tokens=config.retain_max_completion_tokens,
+         retain_chunk_size=config.retain_chunk_size,
+         retain_extract_causal_links=config.retain_extract_causal_links,
+         retain_extraction_mode=config.retain_extraction_mode,
+         retain_observations_async=config.retain_observations_async,
         skip_llm_verification=config.skip_llm_verification,
         lazy_reranker=config.lazy_reranker,
+         run_migrations_on_startup=config.run_migrations_on_startup,
+         db_pool_min_size=config.db_pool_min_size,
+         db_pool_max_size=config.db_pool_max_size,
+         db_command_timeout=config.db_command_timeout,
+         db_acquire_timeout=config.db_acquire_timeout,
+         task_backend=config.task_backend,
+         task_backend_memory_batch_size=config.task_backend_memory_batch_size,
+         task_backend_memory_batch_interval=config.task_backend_memory_batch_interval,
     )
     config.configure_logging()
     if not args.daemon:
@@ -211,7 +246,11 @@ def main():
         logging.info(f"Loaded tenant extension: {tenant_extension.__class__.__name__}")
 
     # Create MemoryEngine (reads configuration from environment)
-     _memory = MemoryEngine(operation_validator=operation_validator, tenant_extension=tenant_extension)
+     _memory = MemoryEngine(
+         operation_validator=operation_validator,
+         tenant_extension=tenant_extension,
+         run_migrations=config.run_migrations_on_startup,
+     )
 
     # Set extension context on tenant extension (needed for schema provisioning)
     if tenant_extension:
@@ -238,14 +277,27 @@ def main():
         app = idle_middleware
 
     # Prepare uvicorn config
+     # When using workers or reload, we must use import string so each worker can import the app
+     use_import_string = args.workers > 1 or args.reload
+     # Check for uvloop availability
+     try:
+         import uvloop  # noqa: F401
+
+         loop_impl = "uvloop"
+         print("uvloop available, will use for event loop")
+     except ImportError:
+         loop_impl = "asyncio"
+         print("uvloop not installed, using default asyncio event loop")
+
     uvicorn_config = {
-         "app": app,
+         "app": "hindsight_api.server:app" if use_import_string else app,
         "host": args.host,
         "port": args.port,
         "log_level": args.log_level,
         "access_log": args.access_log,
         "proxy_headers": args.proxy_headers,
         "ws": "wsproto",  # Use wsproto instead of websockets to avoid deprecation warnings
+         "loop": loop_impl,  # Explicitly set event loop implementation
     }
 
     # Add optional parameters if provided
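
The import-string requirement is uvicorn's own rule: multiple workers (or --reload) need a module path they can re-import, not an app object. A minimal stand-alone sketch of the same pattern, with placeholder host and port values:

    import uvicorn

    # With more than one worker, pass the app as an import string so each
    # spawned process can re-import it; loop="auto" picks uvloop when installed.
    uvicorn.run(
        "hindsight_api.server:app",
        host="0.0.0.0",
        port=8888,
        workers=4,
        loop="auto",
    )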