hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +31 -33
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +17 -12
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +23 -27
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +74 -88
  20. hindsight_api/engine/memory_engine.py +663 -673
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +15 -1
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +234 -0
  40. hindsight_api/engine/search/mpfp_retrieval.py +438 -0
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +388 -193
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -38
  48. hindsight_api/engine/search/tracer.py +49 -35
  49. hindsight_api/engine/search/types.py +22 -16
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +64 -337
  58. hindsight_api/server.py +3 -6
  59. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
  60. hindsight_api-0.1.6.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.4.dist-info/RECORD +0 -61
  63. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/api/mcp.py CHANGED
@@ -4,27 +4,33 @@ import json
4
4
  import logging
5
5
  import os
6
6
  from contextvars import ContextVar
7
- from typing import Optional
8
7
 
9
8
  from fastmcp import FastMCP
9
+
10
10
  from hindsight_api import MemoryEngine
11
11
  from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
12
12
 
13
13
  # Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
14
14
  _log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
15
- _log_level_map = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning": logging.WARNING,
16
- "info": logging.INFO, "debug": logging.DEBUG, "trace": logging.DEBUG}
15
+ _log_level_map = {
16
+ "critical": logging.CRITICAL,
17
+ "error": logging.ERROR,
18
+ "warning": logging.WARNING,
19
+ "info": logging.INFO,
20
+ "debug": logging.DEBUG,
21
+ "trace": logging.DEBUG,
22
+ }
17
23
  logging.basicConfig(
18
24
  level=_log_level_map.get(_log_level_str, logging.INFO),
19
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
25
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
20
26
  )
21
27
  logger = logging.getLogger(__name__)
22
28
 
23
29
  # Context variable to hold the current bank_id from the URL path
24
- _current_bank_id: ContextVar[Optional[str]] = ContextVar("current_bank_id", default=None)
30
+ _current_bank_id: ContextVar[str | None] = ContextVar("current_bank_id", default=None)
25
31
 
26
32
 
27
- def get_current_bank_id() -> Optional[str]:
33
+ def get_current_bank_id() -> str | None:
28
34
  """Get the current bank_id from context (set from URL path)."""
29
35
  return _current_bank_id.get()
30
36
 
@@ -61,10 +67,7 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
61
67
  """
62
68
  try:
63
69
  bank_id = get_current_bank_id()
64
- await memory.put_batch_async(
65
- bank_id=bank_id,
66
- contents=[{"content": content, "context": context}]
67
- )
70
+ await memory.retain_batch_async(bank_id=bank_id, contents=[{"content": content, "context": context}])
68
71
  return "Memory stored successfully"
69
72
  except Exception as e:
70
73
  logger.error(f"Error storing memory: {e}", exc_info=True)
@@ -88,11 +91,9 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
88
91
  try:
89
92
  bank_id = get_current_bank_id()
90
93
  from hindsight_api.engine.memory_engine import Budget
94
+
91
95
  search_result = await memory.recall_async(
92
- bank_id=bank_id,
93
- query=query,
94
- fact_type=list(VALID_RECALL_FACT_TYPES),
95
- budget=Budget.LOW
96
+ bank_id=bank_id, query=query, fact_type=list(VALID_RECALL_FACT_TYPES), budget=Budget.LOW
96
97
  )
97
98
 
98
99
  results = [
@@ -121,11 +122,7 @@ class MCPMiddleware:
121
122
  self.app = app
122
123
  self.memory = memory
123
124
  self.mcp_server = create_mcp_server(memory)
124
- # Use sse_app - http_app requires lifespan management that's complex with middleware
125
- import warnings
126
- with warnings.catch_warnings():
127
- warnings.simplefilter("ignore", DeprecationWarning)
128
- self.mcp_app = self.mcp_server.sse_app()
125
+ self.mcp_app = self.mcp_server.http_app()
129
126
 
130
127
  async def __call__(self, scope, receive, send):
131
128
  if scope["type"] != "http":
@@ -137,7 +134,7 @@ class MCPMiddleware:
137
134
  # Strip any mount prefix (e.g., /mcp) that FastAPI might not have stripped
138
135
  root_path = scope.get("root_path", "")
139
136
  if root_path and path.startswith(root_path):
140
- path = path[len(root_path):] or "/"
137
+ path = path[len(root_path) :] or "/"
141
138
 
142
139
  # Also handle case where mount path wasn't stripped (e.g., /mcp/...)
143
140
  if path.startswith("/mcp/"):
@@ -173,10 +170,7 @@ class MCPMiddleware:
173
170
  body = message.get("body", b"")
174
171
  if body and b"/messages" in body:
175
172
  # Rewrite /messages to /{bank_id}/messages in SSE endpoint event
176
- body = body.replace(
177
- b"data: /messages",
178
- f"data: /{bank_id}/messages".encode()
179
- )
173
+ body = body.replace(b"data: /messages", f"data: /{bank_id}/messages".encode())
180
174
  message = {**message, "body": body}
181
175
  await send(message)
182
176
 
@@ -187,15 +181,19 @@ class MCPMiddleware:
187
181
  async def _send_error(self, send, status: int, message: str):
188
182
  """Send an error response."""
189
183
  body = json.dumps({"error": message}).encode()
190
- await send({
191
- "type": "http.response.start",
192
- "status": status,
193
- "headers": [(b"content-type", b"application/json")],
194
- })
195
- await send({
196
- "type": "http.response.body",
197
- "body": body,
198
- })
184
+ await send(
185
+ {
186
+ "type": "http.response.start",
187
+ "status": status,
188
+ "headers": [(b"content-type", b"application/json")],
189
+ }
190
+ )
191
+ await send(
192
+ {
193
+ "type": "http.response.body",
194
+ "body": body,
195
+ }
196
+ )
199
197
 
200
198
 
201
199
  def create_mcp_app(memory: MemoryEngine):
hindsight_api/banner.py CHANGED
@@ -6,7 +6,7 @@ Shows the logo and tagline with gradient colors.
6
6
 
7
7
  # Gradient colors: #0074d9 -> #009296
8
8
  GRADIENT_START = (0, 116, 217) # #0074d9
9
- GRADIENT_END = (0, 146, 150) # #009296
9
+ GRADIENT_END = (0, 146, 150) # #009296
10
10
 
11
11
  # Pre-generated logo (generated by test-logo.py)
12
12
  LOGO = """\
@@ -31,8 +31,8 @@ def gradient_text(text: str, start: tuple = GRADIENT_START, end: tuple = GRADIEN
31
31
  result = []
32
32
  length = len(text)
33
33
  for i, char in enumerate(text):
34
- if char == ' ':
35
- result.append(' ')
34
+ if char == " ":
35
+ result.append(" ")
36
36
  else:
37
37
  t = i / max(length - 1, 1)
38
38
  r, g, b = _interpolate_color(start, end, t)
@@ -74,9 +74,16 @@ def dim(text: str) -> str:
74
74
  return f"\033[38;2;128;128;128m{text}\033[0m"
75
75
 
76
76
 
77
- def print_startup_info(host: str, port: int, database_url: str, llm_provider: str,
78
- llm_model: str, embeddings_provider: str, reranker_provider: str,
79
- mcp_enabled: bool = False):
77
+ def print_startup_info(
78
+ host: str,
79
+ port: int,
80
+ database_url: str,
81
+ llm_provider: str,
82
+ llm_model: str,
83
+ embeddings_provider: str,
84
+ reranker_provider: str,
85
+ mcp_enabled: bool = False,
86
+ ):
80
87
  """Print styled startup information."""
81
88
  print(color_start("Starting Hindsight API..."))
82
89
  print(f" {dim('URL:')} {color(f'http://{host}:{port}', 0.2)}")
hindsight_api/config.py CHANGED
@@ -3,10 +3,10 @@ Centralized configuration for Hindsight API.
3
3
 
4
4
  All environment variables and their defaults are defined here.
5
5
  """
6
+
7
+ import logging
6
8
  import os
7
9
  from dataclasses import dataclass
8
- from typing import Optional
9
- import logging
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
@@ -29,6 +29,8 @@ ENV_HOST = "HINDSIGHT_API_HOST"
29
29
  ENV_PORT = "HINDSIGHT_API_PORT"
30
30
  ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
31
31
  ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
32
+ ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
33
+ ENV_MCP_LOCAL_BANK_ID = "HINDSIGHT_API_MCP_LOCAL_BANK_ID"
32
34
 
33
35
  # Default values
34
36
  DEFAULT_DATABASE_URL = "pg0"
@@ -45,6 +47,8 @@ DEFAULT_HOST = "0.0.0.0"
45
47
  DEFAULT_PORT = 8888
46
48
  DEFAULT_LOG_LEVEL = "info"
47
49
  DEFAULT_MCP_ENABLED = True
50
+ DEFAULT_GRAPH_RETRIEVER = "bfs" # Options: "bfs", "mpfp"
51
+ DEFAULT_MCP_LOCAL_BANK_ID = "mcp"
48
52
 
49
53
  # Required embedding dimension for database schema
50
54
  EMBEDDING_DIMENSION = 384
@@ -59,19 +63,19 @@ class HindsightConfig:
59
63
 
60
64
  # LLM
61
65
  llm_provider: str
62
- llm_api_key: Optional[str]
66
+ llm_api_key: str | None
63
67
  llm_model: str
64
- llm_base_url: Optional[str]
68
+ llm_base_url: str | None
65
69
 
66
70
  # Embeddings
67
71
  embeddings_provider: str
68
72
  embeddings_local_model: str
69
- embeddings_tei_url: Optional[str]
73
+ embeddings_tei_url: str | None
70
74
 
71
75
  # Reranker
72
76
  reranker_provider: str
73
77
  reranker_local_model: str
74
- reranker_tei_url: Optional[str]
78
+ reranker_tei_url: str | None
75
79
 
76
80
  # Server
77
81
  host: str
@@ -79,34 +83,35 @@ class HindsightConfig:
79
83
  log_level: str
80
84
  mcp_enabled: bool
81
85
 
86
+ # Recall
87
+ graph_retriever: str
88
+
82
89
  @classmethod
83
90
  def from_env(cls) -> "HindsightConfig":
84
91
  """Create configuration from environment variables."""
85
92
  return cls(
86
93
  # Database
87
94
  database_url=os.getenv(ENV_DATABASE_URL, DEFAULT_DATABASE_URL),
88
-
89
95
  # LLM
90
96
  llm_provider=os.getenv(ENV_LLM_PROVIDER, DEFAULT_LLM_PROVIDER),
91
97
  llm_api_key=os.getenv(ENV_LLM_API_KEY),
92
98
  llm_model=os.getenv(ENV_LLM_MODEL, DEFAULT_LLM_MODEL),
93
99
  llm_base_url=os.getenv(ENV_LLM_BASE_URL) or None,
94
-
95
100
  # Embeddings
96
101
  embeddings_provider=os.getenv(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER),
97
102
  embeddings_local_model=os.getenv(ENV_EMBEDDINGS_LOCAL_MODEL, DEFAULT_EMBEDDINGS_LOCAL_MODEL),
98
103
  embeddings_tei_url=os.getenv(ENV_EMBEDDINGS_TEI_URL),
99
-
100
104
  # Reranker
101
105
  reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
102
106
  reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
103
107
  reranker_tei_url=os.getenv(ENV_RERANKER_TEI_URL),
104
-
105
108
  # Server
106
109
  host=os.getenv(ENV_HOST, DEFAULT_HOST),
107
110
  port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
108
111
  log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
109
112
  mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
113
+ # Recall
114
+ graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
110
115
  )
111
116
 
112
117
  def get_llm_base_url(self) -> str:
@@ -137,8 +142,7 @@ class HindsightConfig:
137
142
  def configure_logging(self) -> None:
138
143
  """Configure Python logging based on the log level."""
139
144
  logging.basicConfig(
140
- level=self.get_python_log_level(),
141
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
145
+ level=self.get_python_log_level(), format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
142
146
  )
143
147
 
144
148
  def log_config(self) -> None:
@@ -147,6 +151,7 @@ class HindsightConfig:
147
151
  logger.info(f"LLM: provider={self.llm_provider}, model={self.llm_model}")
148
152
  logger.info(f"Embeddings: provider={self.embeddings_provider}")
149
153
  logger.info(f"Reranker: provider={self.reranker_provider}")
154
+ logger.info(f"Graph retriever: {self.graph_retriever}")
150
155
 
151
156
 
152
157
  def get_config() -> HindsightConfig:
@@ -7,24 +7,24 @@ This package contains all the implementation details of the memory engine:
7
7
  - Supporting modules: embeddings, cross_encoder, entity_resolver, etc.
8
8
  """
9
9
 
10
- from .memory_engine import MemoryEngine
10
+ from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
11
11
  from .db_utils import acquire_with_retry
12
12
  from .embeddings import Embeddings, LocalSTEmbeddings, RemoteTEIEmbeddings
13
- from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
13
+ from .llm_wrapper import LLMConfig
14
+ from .memory_engine import MemoryEngine
15
+ from .response_models import MemoryFact, RecallResult, ReflectResult
14
16
  from .search.trace import (
15
- SearchTrace,
16
- QueryInfo,
17
17
  EntryPoint,
18
- NodeVisit,
19
- WeightComponents,
20
18
  LinkInfo,
19
+ NodeVisit,
21
20
  PruningDecision,
22
- SearchSummary,
21
+ QueryInfo,
23
22
  SearchPhaseMetrics,
23
+ SearchSummary,
24
+ SearchTrace,
25
+ WeightComponents,
24
26
  )
25
27
  from .search.tracer import SearchTracer
26
- from .llm_wrapper import LLMConfig
27
- from .response_models import RecallResult, ReflectResult, MemoryFact
28
28
 
29
29
  __all__ = [
30
30
  "MemoryEngine",
@@ -5,19 +5,19 @@ Provides an interface for reranking with different backends.
5
5
 
6
6
  Configuration via environment variables - see hindsight_api.config for all env var names.
7
7
  """
8
- from abc import ABC, abstractmethod
9
- from typing import List, Tuple, Optional
8
+
10
9
  import logging
11
10
  import os
11
+ from abc import ABC, abstractmethod
12
12
 
13
13
  import httpx
14
14
 
15
15
  from ..config import (
16
- ENV_RERANKER_PROVIDER,
16
+ DEFAULT_RERANKER_LOCAL_MODEL,
17
+ DEFAULT_RERANKER_PROVIDER,
17
18
  ENV_RERANKER_LOCAL_MODEL,
19
+ ENV_RERANKER_PROVIDER,
18
20
  ENV_RERANKER_TEI_URL,
19
- DEFAULT_RERANKER_PROVIDER,
20
- DEFAULT_RERANKER_LOCAL_MODEL,
21
21
  )
22
22
 
23
23
  logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class CrossEncoderModel(ABC):
47
47
  pass
48
48
 
49
49
  @abstractmethod
50
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
50
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
51
51
  """
52
52
  Score query-document pairs for relevance.
53
53
 
@@ -72,7 +72,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
72
72
  - Trained for passage re-ranking
73
73
  """
74
74
 
75
- def __init__(self, model_name: Optional[str] = None):
75
+ def __init__(self, model_name: str | None = None):
76
76
  """
77
77
  Initialize local SentenceTransformers cross-encoder.
78
78
 
@@ -101,15 +101,10 @@ class LocalSTCrossEncoder(CrossEncoderModel):
101
101
  )
102
102
 
103
103
  logger.info(f"Reranker: initializing local provider with model {self.model_name}")
104
- # Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
105
- # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
106
- self._model = CrossEncoder(
107
- self.model_name,
108
- model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
109
- )
104
+ self._model = CrossEncoder(self.model_name)
110
105
  logger.info("Reranker: local provider initialized")
111
106
 
112
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
107
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
113
108
  """
114
109
  Score query-document pairs for relevance.
115
110
 
@@ -122,7 +117,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
122
117
  if self._model is None:
123
118
  raise RuntimeError("Reranker not initialized. Call initialize() first.")
124
119
  scores = self._model.predict(pairs, show_progress_bar=False)
125
- return scores.tolist() if hasattr(scores, 'tolist') else list(scores)
120
+ return scores.tolist() if hasattr(scores, "tolist") else list(scores)
126
121
 
127
122
 
128
123
  class RemoteTEICrossEncoder(CrossEncoderModel):
@@ -158,8 +153,8 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
158
153
  self.batch_size = batch_size
159
154
  self.max_retries = max_retries
160
155
  self.retry_delay = retry_delay
161
- self._client: Optional[httpx.Client] = None
162
- self._model_id: Optional[str] = None
156
+ self._client: httpx.Client | None = None
157
+ self._model_id: str | None = None
163
158
 
164
159
  @property
165
160
  def provider_name(self) -> str:
@@ -168,6 +163,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
168
163
  def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
169
164
  """Make an HTTP request with automatic retries on transient errors."""
170
165
  import time
166
+
171
167
  last_error = None
172
168
  delay = self.retry_delay
173
169
 
@@ -182,14 +178,18 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
182
178
  except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
183
179
  last_error = e
184
180
  if attempt < self.max_retries:
185
- logger.warning(f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
181
+ logger.warning(
182
+ f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
183
+ )
186
184
  time.sleep(delay)
187
185
  delay *= 2 # Exponential backoff
188
186
  except httpx.HTTPStatusError as e:
189
187
  # Retry on 5xx server errors
190
188
  if e.response.status_code >= 500 and attempt < self.max_retries:
191
189
  last_error = e
192
- logger.warning(f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
190
+ logger.warning(
191
+ f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
192
+ )
193
193
  time.sleep(delay)
194
194
  delay *= 2
195
195
  else:
@@ -214,7 +214,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
214
214
  except httpx.HTTPError as e:
215
215
  raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
216
216
 
217
- def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
217
+ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
218
218
  """
219
219
  Score query-document pairs using the remote TEI reranker.
220
220
 
@@ -234,7 +234,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
234
234
 
235
235
  # Process in batches
236
236
  for i in range(0, len(pairs), self.batch_size):
237
- batch = pairs[i:i + self.batch_size]
237
+ batch = pairs[i : i + self.batch_size]
238
238
 
239
239
  # TEI rerank endpoint expects query and texts separately
240
240
  # All pairs in a batch should have the same query for optimal performance
@@ -292,15 +292,11 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
292
292
  if provider == "tei":
293
293
  url = os.environ.get(ENV_RERANKER_TEI_URL)
294
294
  if not url:
295
- raise ValueError(
296
- f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
297
- )
295
+ raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
298
296
  return RemoteTEICrossEncoder(base_url=url)
299
297
  elif provider == "local":
300
298
  model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
301
299
  model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
302
300
  return LocalSTCrossEncoder(model_name=model_name)
303
301
  else:
304
- raise ValueError(
305
- f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'"
306
- )
302
+ raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
@@ -1,9 +1,11 @@
1
1
  """
2
2
  Database utility functions for connection management with retry logic.
3
3
  """
4
+
4
5
  import asyncio
5
6
  import logging
6
7
  from contextlib import asynccontextmanager
8
+
7
9
  import asyncpg
8
10
 
9
11
  logger = logging.getLogger(__name__)
@@ -54,16 +56,14 @@ async def retry_with_backoff(
54
56
  except retryable_exceptions as e:
55
57
  last_exception = e
56
58
  if attempt < max_retries:
57
- delay = min(base_delay * (2 ** attempt), max_delay)
59
+ delay = min(base_delay * (2**attempt), max_delay)
58
60
  logger.warning(
59
61
  f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
60
62
  f"Retrying in {delay:.1f}s..."
61
63
  )
62
64
  await asyncio.sleep(delay)
63
65
  else:
64
- logger.error(
65
- f"Database operation failed after {max_retries + 1} attempts: {e}"
66
- )
66
+ logger.error(f"Database operation failed after {max_retries + 1} attempts: {e}")
67
67
  raise last_exception
68
68
 
69
69
 
@@ -83,6 +83,7 @@ async def acquire_with_retry(pool: asyncpg.Pool, max_retries: int = DEFAULT_MAX_
83
83
  Yields:
84
84
  An asyncpg connection
85
85
  """
86
+
86
87
  async def acquire():
87
88
  return await pool.acquire()
88
89
 
@@ -8,20 +8,20 @@ the database schema (pgvector column defined as vector(384)).
8
8
 
9
9
  Configuration via environment variables - see hindsight_api.config for all env var names.
10
10
  """
11
- from abc import ABC, abstractmethod
12
- from typing import List, Optional
11
+
13
12
  import logging
14
13
  import os
14
+ from abc import ABC, abstractmethod
15
15
 
16
16
  import httpx
17
17
 
18
18
  from ..config import (
19
- ENV_EMBEDDINGS_PROVIDER,
20
- ENV_EMBEDDINGS_LOCAL_MODEL,
21
- ENV_EMBEDDINGS_TEI_URL,
22
- DEFAULT_EMBEDDINGS_PROVIDER,
23
19
  DEFAULT_EMBEDDINGS_LOCAL_MODEL,
20
+ DEFAULT_EMBEDDINGS_PROVIDER,
24
21
  EMBEDDING_DIMENSION,
22
+ ENV_EMBEDDINGS_LOCAL_MODEL,
23
+ ENV_EMBEDDINGS_PROVIDER,
24
+ ENV_EMBEDDINGS_TEI_URL,
25
25
  )
26
26
 
27
27
  logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class Embeddings(ABC):
52
52
  pass
53
53
 
54
54
  @abstractmethod
55
- def encode(self, texts: List[str]) -> List[List[float]]:
55
+ def encode(self, texts: list[str]) -> list[list[float]]:
56
56
  """
57
57
  Generate 384-dimensional embeddings for a list of texts.
58
58
 
@@ -75,7 +75,7 @@ class LocalSTEmbeddings(Embeddings):
75
75
  embeddings matching the database schema.
76
76
  """
77
77
 
78
- def __init__(self, model_name: Optional[str] = None):
78
+ def __init__(self, model_name: str | None = None):
79
79
  """
80
80
  Initialize local SentenceTransformers embeddings.
81
81
 
@@ -123,7 +123,7 @@ class LocalSTEmbeddings(Embeddings):
123
123
 
124
124
  logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
125
125
 
126
- def encode(self, texts: List[str]) -> List[List[float]]:
126
+ def encode(self, texts: list[str]) -> list[list[float]]:
127
127
  """
128
128
  Generate 384-dimensional embeddings for a list of texts.
129
129
 
@@ -172,8 +172,8 @@ class RemoteTEIEmbeddings(Embeddings):
172
172
  self.batch_size = batch_size
173
173
  self.max_retries = max_retries
174
174
  self.retry_delay = retry_delay
175
- self._client: Optional[httpx.Client] = None
176
- self._model_id: Optional[str] = None
175
+ self._client: httpx.Client | None = None
176
+ self._model_id: str | None = None
177
177
 
178
178
  @property
179
179
  def provider_name(self) -> str:
@@ -182,6 +182,7 @@ class RemoteTEIEmbeddings(Embeddings):
182
182
  def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
183
183
  """Make an HTTP request with automatic retries on transient errors."""
184
184
  import time
185
+
185
186
  last_error = None
186
187
  delay = self.retry_delay
187
188
 
@@ -196,14 +197,18 @@ class RemoteTEIEmbeddings(Embeddings):
196
197
  except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
197
198
  last_error = e
198
199
  if attempt < self.max_retries:
199
- logger.warning(f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
200
+ logger.warning(
201
+ f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
202
+ )
200
203
  time.sleep(delay)
201
204
  delay *= 2 # Exponential backoff
202
205
  except httpx.HTTPStatusError as e:
203
206
  # Retry on 5xx server errors
204
207
  if e.response.status_code >= 500 and attempt < self.max_retries:
205
208
  last_error = e
206
- logger.warning(f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
209
+ logger.warning(
210
+ f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
211
+ )
207
212
  time.sleep(delay)
208
213
  delay *= 2
209
214
  else:
@@ -228,7 +233,7 @@ class RemoteTEIEmbeddings(Embeddings):
228
233
  except httpx.HTTPError as e:
229
234
  raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
230
235
 
231
- def encode(self, texts: List[str]) -> List[List[float]]:
236
+ def encode(self, texts: list[str]) -> list[list[float]]:
232
237
  """
233
238
  Generate embeddings using the remote TEI server.
234
239
 
@@ -248,7 +253,7 @@ class RemoteTEIEmbeddings(Embeddings):
248
253
 
249
254
  # Process in batches
250
255
  for i in range(0, len(texts), self.batch_size):
251
- batch = texts[i:i + self.batch_size]
256
+ batch = texts[i : i + self.batch_size]
252
257
 
253
258
  try:
254
259
  response = self._request_with_retry(
@@ -278,15 +283,11 @@ def create_embeddings_from_env() -> Embeddings:
278
283
  if provider == "tei":
279
284
  url = os.environ.get(ENV_EMBEDDINGS_TEI_URL)
280
285
  if not url:
281
- raise ValueError(
282
- f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'"
283
- )
286
+ raise ValueError(f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'")
284
287
  return RemoteTEIEmbeddings(base_url=url)
285
288
  elif provider == "local":
286
289
  model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
287
290
  model_name = model or DEFAULT_EMBEDDINGS_LOCAL_MODEL
288
291
  return LocalSTEmbeddings(model_name=model_name)
289
292
  else:
290
- raise ValueError(
291
- f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'"
292
- )
293
+ raise ValueError(f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'")