hindsight-api 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +30 -28
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +9 -13
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +22 -21
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +61 -79
  20. hindsight_api/engine/memory_engine.py +603 -625
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +5 -5
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +22 -23
  40. hindsight_api/engine/search/mpfp_retrieval.py +76 -92
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +87 -66
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -39
  48. hindsight_api/engine/search/tracer.py +44 -35
  49. hindsight_api/engine/search/types.py +20 -17
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +22 -23
  58. hindsight_api/server.py +3 -6
  59. hindsight_api-0.1.7.dist-info/METADATA +178 -0
  60. hindsight_api-0.1.7.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.5.dist-info/METADATA +0 -42
  63. hindsight_api-0.1.5.dist-info/RECORD +0 -63
  64. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/WHEEL +0 -0
hindsight_api/api/mcp.py CHANGED
@@ -4,27 +4,33 @@ import json
4
4
  import logging
5
5
  import os
6
6
  from contextvars import ContextVar
7
- from typing import Optional
8
7
 
9
8
  from fastmcp import FastMCP
9
+
10
10
  from hindsight_api import MemoryEngine
11
11
  from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
12
12
 
13
13
  # Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
14
14
  _log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
15
- _log_level_map = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning": logging.WARNING,
16
- "info": logging.INFO, "debug": logging.DEBUG, "trace": logging.DEBUG}
15
+ _log_level_map = {
16
+ "critical": logging.CRITICAL,
17
+ "error": logging.ERROR,
18
+ "warning": logging.WARNING,
19
+ "info": logging.INFO,
20
+ "debug": logging.DEBUG,
21
+ "trace": logging.DEBUG,
22
+ }
17
23
  logging.basicConfig(
18
24
  level=_log_level_map.get(_log_level_str, logging.INFO),
19
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
25
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
20
26
  )
21
27
  logger = logging.getLogger(__name__)
22
28
 
23
29
  # Context variable to hold the current bank_id from the URL path
24
- _current_bank_id: ContextVar[Optional[str]] = ContextVar("current_bank_id", default=None)
30
+ _current_bank_id: ContextVar[str | None] = ContextVar("current_bank_id", default=None)
25
31
 
26
32
 
27
- def get_current_bank_id() -> Optional[str]:
33
+ def get_current_bank_id() -> str | None:
28
34
  """Get the current bank_id from context (set from URL path)."""
29
35
  return _current_bank_id.get()
30
36
 
@@ -61,10 +67,7 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
61
67
  """
62
68
  try:
63
69
  bank_id = get_current_bank_id()
64
- await memory.put_batch_async(
65
- bank_id=bank_id,
66
- contents=[{"content": content, "context": context}]
67
- )
70
+ await memory.retain_batch_async(bank_id=bank_id, contents=[{"content": content, "context": context}])
68
71
  return "Memory stored successfully"
69
72
  except Exception as e:
70
73
  logger.error(f"Error storing memory: {e}", exc_info=True)
@@ -88,11 +91,9 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
88
91
  try:
89
92
  bank_id = get_current_bank_id()
90
93
  from hindsight_api.engine.memory_engine import Budget
94
+
91
95
  search_result = await memory.recall_async(
92
- bank_id=bank_id,
93
- query=query,
94
- fact_type=list(VALID_RECALL_FACT_TYPES),
95
- budget=Budget.LOW
96
+ bank_id=bank_id, query=query, fact_type=list(VALID_RECALL_FACT_TYPES), budget=Budget.LOW
96
97
  )
97
98
 
98
99
  results = [
@@ -133,7 +134,7 @@ class MCPMiddleware:
133
134
  # Strip any mount prefix (e.g., /mcp) that FastAPI might not have stripped
134
135
  root_path = scope.get("root_path", "")
135
136
  if root_path and path.startswith(root_path):
136
- path = path[len(root_path):] or "/"
137
+ path = path[len(root_path) :] or "/"
137
138
 
138
139
  # Also handle case where mount path wasn't stripped (e.g., /mcp/...)
139
140
  if path.startswith("/mcp/"):
@@ -169,10 +170,7 @@ class MCPMiddleware:
169
170
  body = message.get("body", b"")
170
171
  if body and b"/messages" in body:
171
172
  # Rewrite /messages to /{bank_id}/messages in SSE endpoint event
172
- body = body.replace(
173
- b"data: /messages",
174
- f"data: /{bank_id}/messages".encode()
175
- )
173
+ body = body.replace(b"data: /messages", f"data: /{bank_id}/messages".encode())
176
174
  message = {**message, "body": body}
177
175
  await send(message)
178
176
 
@@ -183,15 +181,19 @@ class MCPMiddleware:
183
181
  async def _send_error(self, send, status: int, message: str):
184
182
  """Send an error response."""
185
183
  body = json.dumps({"error": message}).encode()
186
- await send({
187
- "type": "http.response.start",
188
- "status": status,
189
- "headers": [(b"content-type", b"application/json")],
190
- })
191
- await send({
192
- "type": "http.response.body",
193
- "body": body,
194
- })
184
+ await send(
185
+ {
186
+ "type": "http.response.start",
187
+ "status": status,
188
+ "headers": [(b"content-type", b"application/json")],
189
+ }
190
+ )
191
+ await send(
192
+ {
193
+ "type": "http.response.body",
194
+ "body": body,
195
+ }
196
+ )
195
197
 
196
198
 
197
199
  def create_mcp_app(memory: MemoryEngine):
hindsight_api/banner.py CHANGED
@@ -6,7 +6,7 @@ Shows the logo and tagline with gradient colors.
6
6
 
7
7
  # Gradient colors: #0074d9 -> #009296
8
8
  GRADIENT_START = (0, 116, 217) # #0074d9
9
- GRADIENT_END = (0, 146, 150) # #009296
9
+ GRADIENT_END = (0, 146, 150) # #009296
10
10
 
11
11
  # Pre-generated logo (generated by test-logo.py)
12
12
  LOGO = """\
@@ -31,8 +31,8 @@ def gradient_text(text: str, start: tuple = GRADIENT_START, end: tuple = GRADIEN
31
31
  result = []
32
32
  length = len(text)
33
33
  for i, char in enumerate(text):
34
- if char == ' ':
35
- result.append(' ')
34
+ if char == " ":
35
+ result.append(" ")
36
36
  else:
37
37
  t = i / max(length - 1, 1)
38
38
  r, g, b = _interpolate_color(start, end, t)
@@ -74,9 +74,16 @@ def dim(text: str) -> str:
74
74
  return f"\033[38;2;128;128;128m{text}\033[0m"
75
75
 
76
76
 
77
- def print_startup_info(host: str, port: int, database_url: str, llm_provider: str,
78
- llm_model: str, embeddings_provider: str, reranker_provider: str,
79
- mcp_enabled: bool = False):
77
+ def print_startup_info(
78
+ host: str,
79
+ port: int,
80
+ database_url: str,
81
+ llm_provider: str,
82
+ llm_model: str,
83
+ embeddings_provider: str,
84
+ reranker_provider: str,
85
+ mcp_enabled: bool = False,
86
+ ):
80
87
  """Print styled startup information."""
81
88
  print(color_start("Starting Hindsight API..."))
82
89
  print(f" {dim('URL:')} {color(f'http://{host}:{port}', 0.2)}")
hindsight_api/config.py CHANGED
@@ -3,10 +3,10 @@ Centralized configuration for Hindsight API.
3
3
 
4
4
  All environment variables and their defaults are defined here.
5
5
  """
6
+
7
+ import logging
6
8
  import os
7
9
  from dataclasses import dataclass
8
- from typing import Optional
9
- import logging
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
@@ -30,6 +30,7 @@ ENV_PORT = "HINDSIGHT_API_PORT"
30
30
  ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
31
31
  ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
32
32
  ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
33
+ ENV_MCP_LOCAL_BANK_ID = "HINDSIGHT_API_MCP_LOCAL_BANK_ID"
33
34
 
34
35
  # Default values
35
36
  DEFAULT_DATABASE_URL = "pg0"
@@ -47,6 +48,7 @@ DEFAULT_PORT = 8888
47
48
  DEFAULT_LOG_LEVEL = "info"
48
49
  DEFAULT_MCP_ENABLED = True
49
50
  DEFAULT_GRAPH_RETRIEVER = "bfs" # Options: "bfs", "mpfp"
51
+ DEFAULT_MCP_LOCAL_BANK_ID = "mcp"
50
52
 
51
53
  # Required embedding dimension for database schema
52
54
  EMBEDDING_DIMENSION = 384
@@ -61,19 +63,19 @@ class HindsightConfig:
61
63
 
62
64
  # LLM
63
65
  llm_provider: str
64
- llm_api_key: Optional[str]
66
+ llm_api_key: str | None
65
67
  llm_model: str
66
- llm_base_url: Optional[str]
68
+ llm_base_url: str | None
67
69
 
68
70
  # Embeddings
69
71
  embeddings_provider: str
70
72
  embeddings_local_model: str
71
- embeddings_tei_url: Optional[str]
73
+ embeddings_tei_url: str | None
72
74
 
73
75
  # Reranker
74
76
  reranker_provider: str
75
77
  reranker_local_model: str
76
- reranker_tei_url: Optional[str]
78
+ reranker_tei_url: str | None
77
79
 
78
80
  # Server
79
81
  host: str
@@ -90,29 +92,24 @@ class HindsightConfig:
90
92
  return cls(
91
93
  # Database
92
94
  database_url=os.getenv(ENV_DATABASE_URL, DEFAULT_DATABASE_URL),
93
-
94
95
  # LLM
95
96
  llm_provider=os.getenv(ENV_LLM_PROVIDER, DEFAULT_LLM_PROVIDER),
96
97
  llm_api_key=os.getenv(ENV_LLM_API_KEY),
97
98
  llm_model=os.getenv(ENV_LLM_MODEL, DEFAULT_LLM_MODEL),
98
99
  llm_base_url=os.getenv(ENV_LLM_BASE_URL) or None,
99
-
100
100
  # Embeddings
101
101
  embeddings_provider=os.getenv(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER),
102
102
  embeddings_local_model=os.getenv(ENV_EMBEDDINGS_LOCAL_MODEL, DEFAULT_EMBEDDINGS_LOCAL_MODEL),
103
103
  embeddings_tei_url=os.getenv(ENV_EMBEDDINGS_TEI_URL),
104
-
105
104
  # Reranker
106
105
  reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
107
106
  reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
108
107
  reranker_tei_url=os.getenv(ENV_RERANKER_TEI_URL),
109
-
110
108
  # Server
111
109
  host=os.getenv(ENV_HOST, DEFAULT_HOST),
112
110
  port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
113
111
  log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
114
112
  mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
115
-
116
113
  # Recall
117
114
  graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
118
115
  )
@@ -145,8 +142,7 @@ class HindsightConfig:
145
142
  def configure_logging(self) -> None:
146
143
  """Configure Python logging based on the log level."""
147
144
  logging.basicConfig(
148
- level=self.get_python_log_level(),
149
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
145
+ level=self.get_python_log_level(), format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
150
146
  )
151
147
 
152
148
  def log_config(self) -> None:
@@ -7,24 +7,24 @@ This package contains all the implementation details of the memory engine:
7
7
  - Supporting modules: embeddings, cross_encoder, entity_resolver, etc.
8
8
  """
9
9
 
10
- from .memory_engine import MemoryEngine
10
+ from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
11
11
  from .db_utils import acquire_with_retry
12
12
  from .embeddings import Embeddings, LocalSTEmbeddings, RemoteTEIEmbeddings
13
- from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
13
+ from .llm_wrapper import LLMConfig
14
+ from .memory_engine import MemoryEngine
15
+ from .response_models import MemoryFact, RecallResult, ReflectResult
14
16
  from .search.trace import (
15
- SearchTrace,
16
- QueryInfo,
17
17
  EntryPoint,
18
- NodeVisit,
19
- WeightComponents,
20
18
  LinkInfo,
19
+ NodeVisit,
21
20
  PruningDecision,
22
- SearchSummary,
21
+ QueryInfo,
23
22
  SearchPhaseMetrics,
23
+ SearchSummary,
24
+ SearchTrace,
25
+ WeightComponents,
24
26
  )
25
27
  from .search.tracer import SearchTracer
26
- from .llm_wrapper import LLMConfig
27
- from .response_models import RecallResult, ReflectResult, MemoryFact
28
28
 
29
29
  __all__ = [
30
30
  "MemoryEngine",
@@ -5,19 +5,19 @@ Provides an interface for reranking with different backends.
5
5
 
6
6
  Configuration via environment variables - see hindsight_api.config for all env var names.
7
7
  """
8
- from abc import ABC, abstractmethod
9
- from typing import List, Tuple, Optional
8
+
10
9
  import logging
11
10
  import os
11
+ from abc import ABC, abstractmethod
12
12
 
13
13
  import httpx
14
14
 
15
15
  from ..config import (
16
- ENV_RERANKER_PROVIDER,
16
+ DEFAULT_RERANKER_LOCAL_MODEL,
17
+ DEFAULT_RERANKER_PROVIDER,
17
18
  ENV_RERANKER_LOCAL_MODEL,
19
+ ENV_RERANKER_PROVIDER,
18
20
  ENV_RERANKER_TEI_URL,
19
- DEFAULT_RERANKER_PROVIDER,
20
- DEFAULT_RERANKER_LOCAL_MODEL,
21
21
  )
22
22
 
23
23
  logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class CrossEncoderModel(ABC):
47
47
  pass
48
48
 
49
49
  @abstractmethod
50
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
50
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
51
51
  """
52
52
  Score query-document pairs for relevance.
53
53
 
@@ -72,7 +72,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
72
72
  - Trained for passage re-ranking
73
73
  """
74
74
 
75
- def __init__(self, model_name: Optional[str] = None):
75
+ def __init__(self, model_name: str | None = None):
76
76
  """
77
77
  Initialize local SentenceTransformers cross-encoder.
78
78
 
@@ -104,7 +104,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
104
104
  self._model = CrossEncoder(self.model_name)
105
105
  logger.info("Reranker: local provider initialized")
106
106
 
107
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
107
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
108
108
  """
109
109
  Score query-document pairs for relevance.
110
110
 
@@ -117,7 +117,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
117
117
  if self._model is None:
118
118
  raise RuntimeError("Reranker not initialized. Call initialize() first.")
119
119
  scores = self._model.predict(pairs, show_progress_bar=False)
120
- return scores.tolist() if hasattr(scores, 'tolist') else list(scores)
120
+ return scores.tolist() if hasattr(scores, "tolist") else list(scores)
121
121
 
122
122
 
123
123
  class RemoteTEICrossEncoder(CrossEncoderModel):
@@ -153,8 +153,8 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
153
153
  self.batch_size = batch_size
154
154
  self.max_retries = max_retries
155
155
  self.retry_delay = retry_delay
156
- self._client: Optional[httpx.Client] = None
157
- self._model_id: Optional[str] = None
156
+ self._client: httpx.Client | None = None
157
+ self._model_id: str | None = None
158
158
 
159
159
  @property
160
160
  def provider_name(self) -> str:
@@ -163,6 +163,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
163
163
  def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
164
164
  """Make an HTTP request with automatic retries on transient errors."""
165
165
  import time
166
+
166
167
  last_error = None
167
168
  delay = self.retry_delay
168
169
 
@@ -177,14 +178,18 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
177
178
  except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
178
179
  last_error = e
179
180
  if attempt < self.max_retries:
180
- logger.warning(f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
181
+ logger.warning(
182
+ f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
183
+ )
181
184
  time.sleep(delay)
182
185
  delay *= 2 # Exponential backoff
183
186
  except httpx.HTTPStatusError as e:
184
187
  # Retry on 5xx server errors
185
188
  if e.response.status_code >= 500 and attempt < self.max_retries:
186
189
  last_error = e
187
- logger.warning(f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
190
+ logger.warning(
191
+ f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
192
+ )
188
193
  time.sleep(delay)
189
194
  delay *= 2
190
195
  else:
@@ -209,7 +214,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
209
214
  except httpx.HTTPError as e:
210
215
  raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
211
216
 
212
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
217
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
213
218
  """
214
219
  Score query-document pairs using the remote TEI reranker.
215
220
 
@@ -229,7 +234,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
229
234
 
230
235
  # Process in batches
231
236
  for i in range(0, len(pairs), self.batch_size):
232
- batch = pairs[i:i + self.batch_size]
237
+ batch = pairs[i : i + self.batch_size]
233
238
 
234
239
  # TEI rerank endpoint expects query and texts separately
235
240
  # All pairs in a batch should have the same query for optimal performance
@@ -287,15 +292,11 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
287
292
  if provider == "tei":
288
293
  url = os.environ.get(ENV_RERANKER_TEI_URL)
289
294
  if not url:
290
- raise ValueError(
291
- f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
292
- )
295
+ raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
293
296
  return RemoteTEICrossEncoder(base_url=url)
294
297
  elif provider == "local":
295
298
  model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
296
299
  model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
297
300
  return LocalSTCrossEncoder(model_name=model_name)
298
301
  else:
299
- raise ValueError(
300
- f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'"
301
- )
302
+ raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
@@ -1,9 +1,11 @@
1
1
  """
2
2
  Database utility functions for connection management with retry logic.
3
3
  """
4
+
4
5
  import asyncio
5
6
  import logging
6
7
  from contextlib import asynccontextmanager
8
+
7
9
  import asyncpg
8
10
 
9
11
  logger = logging.getLogger(__name__)
@@ -54,16 +56,14 @@ async def retry_with_backoff(
54
56
  except retryable_exceptions as e:
55
57
  last_exception = e
56
58
  if attempt < max_retries:
57
- delay = min(base_delay * (2 ** attempt), max_delay)
59
+ delay = min(base_delay * (2**attempt), max_delay)
58
60
  logger.warning(
59
61
  f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
60
62
  f"Retrying in {delay:.1f}s..."
61
63
  )
62
64
  await asyncio.sleep(delay)
63
65
  else:
64
- logger.error(
65
- f"Database operation failed after {max_retries + 1} attempts: {e}"
66
- )
66
+ logger.error(f"Database operation failed after {max_retries + 1} attempts: {e}")
67
67
  raise last_exception
68
68
 
69
69
 
@@ -83,6 +83,7 @@ async def acquire_with_retry(pool: asyncpg.Pool, max_retries: int = DEFAULT_MAX_
83
83
  Yields:
84
84
  An asyncpg connection
85
85
  """
86
+
86
87
  async def acquire():
87
88
  return await pool.acquire()
88
89
 
@@ -8,20 +8,20 @@ the database schema (pgvector column defined as vector(384)).
8
8
 
9
9
  Configuration via environment variables - see hindsight_api.config for all env var names.
10
10
  """
11
- from abc import ABC, abstractmethod
12
- from typing import List, Optional
11
+
13
12
  import logging
14
13
  import os
14
+ from abc import ABC, abstractmethod
15
15
 
16
16
  import httpx
17
17
 
18
18
  from ..config import (
19
- ENV_EMBEDDINGS_PROVIDER,
20
- ENV_EMBEDDINGS_LOCAL_MODEL,
21
- ENV_EMBEDDINGS_TEI_URL,
22
- DEFAULT_EMBEDDINGS_PROVIDER,
23
19
  DEFAULT_EMBEDDINGS_LOCAL_MODEL,
20
+ DEFAULT_EMBEDDINGS_PROVIDER,
24
21
  EMBEDDING_DIMENSION,
22
+ ENV_EMBEDDINGS_LOCAL_MODEL,
23
+ ENV_EMBEDDINGS_PROVIDER,
24
+ ENV_EMBEDDINGS_TEI_URL,
25
25
  )
26
26
 
27
27
  logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class Embeddings(ABC):
52
52
  pass
53
53
 
54
54
  @abstractmethod
55
- def encode(self, texts: List[str]) -> List[List[float]]:
55
+ def encode(self, texts: list[str]) -> list[list[float]]:
56
56
  """
57
57
  Generate 384-dimensional embeddings for a list of texts.
58
58
 
@@ -75,7 +75,7 @@ class LocalSTEmbeddings(Embeddings):
75
75
  embeddings matching the database schema.
76
76
  """
77
77
 
78
- def __init__(self, model_name: Optional[str] = None):
78
+ def __init__(self, model_name: str | None = None):
79
79
  """
80
80
  Initialize local SentenceTransformers embeddings.
81
81
 
@@ -123,7 +123,7 @@ class LocalSTEmbeddings(Embeddings):
123
123
 
124
124
  logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
125
125
 
126
- def encode(self, texts: List[str]) -> List[List[float]]:
126
+ def encode(self, texts: list[str]) -> list[list[float]]:
127
127
  """
128
128
  Generate 384-dimensional embeddings for a list of texts.
129
129
 
@@ -172,8 +172,8 @@ class RemoteTEIEmbeddings(Embeddings):
172
172
  self.batch_size = batch_size
173
173
  self.max_retries = max_retries
174
174
  self.retry_delay = retry_delay
175
- self._client: Optional[httpx.Client] = None
176
- self._model_id: Optional[str] = None
175
+ self._client: httpx.Client | None = None
176
+ self._model_id: str | None = None
177
177
 
178
178
  @property
179
179
  def provider_name(self) -> str:
@@ -182,6 +182,7 @@ class RemoteTEIEmbeddings(Embeddings):
182
182
  def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
183
183
  """Make an HTTP request with automatic retries on transient errors."""
184
184
  import time
185
+
185
186
  last_error = None
186
187
  delay = self.retry_delay
187
188
 
@@ -196,14 +197,18 @@ class RemoteTEIEmbeddings(Embeddings):
196
197
  except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
197
198
  last_error = e
198
199
  if attempt < self.max_retries:
199
- logger.warning(f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
200
+ logger.warning(
201
+ f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
202
+ )
200
203
  time.sleep(delay)
201
204
  delay *= 2 # Exponential backoff
202
205
  except httpx.HTTPStatusError as e:
203
206
  # Retry on 5xx server errors
204
207
  if e.response.status_code >= 500 and attempt < self.max_retries:
205
208
  last_error = e
206
- logger.warning(f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
209
+ logger.warning(
210
+ f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
211
+ )
207
212
  time.sleep(delay)
208
213
  delay *= 2
209
214
  else:
@@ -228,7 +233,7 @@ class RemoteTEIEmbeddings(Embeddings):
228
233
  except httpx.HTTPError as e:
229
234
  raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
230
235
 
231
- def encode(self, texts: List[str]) -> List[List[float]]:
236
+ def encode(self, texts: list[str]) -> list[list[float]]:
232
237
  """
233
238
  Generate embeddings using the remote TEI server.
234
239
 
@@ -248,7 +253,7 @@ class RemoteTEIEmbeddings(Embeddings):
248
253
 
249
254
  # Process in batches
250
255
  for i in range(0, len(texts), self.batch_size):
251
- batch = texts[i:i + self.batch_size]
256
+ batch = texts[i : i + self.batch_size]
252
257
 
253
258
  try:
254
259
  response = self._request_with_retry(
@@ -278,15 +283,11 @@ def create_embeddings_from_env() -> Embeddings:
278
283
  if provider == "tei":
279
284
  url = os.environ.get(ENV_EMBEDDINGS_TEI_URL)
280
285
  if not url:
281
- raise ValueError(
282
- f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'"
283
- )
286
+ raise ValueError(f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'")
284
287
  return RemoteTEIEmbeddings(base_url=url)
285
288
  elif provider == "local":
286
289
  model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
287
290
  model_name = model or DEFAULT_EMBEDDINGS_LOCAL_MODEL
288
291
  return LocalSTEmbeddings(model_name=model_name)
289
292
  else:
290
- raise ValueError(
291
- f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'"
292
- )
293
+ raise ValueError(f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'")