hindsight-api 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +30 -28
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +9 -13
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +22 -21
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +61 -79
- hindsight_api/engine/memory_engine.py +603 -625
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +5 -5
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +22 -23
- hindsight_api/engine/search/mpfp_retrieval.py +76 -92
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +87 -66
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -39
- hindsight_api/engine/search/tracer.py +44 -35
- hindsight_api/engine/search/types.py +20 -17
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +22 -23
- hindsight_api/server.py +3 -6
- hindsight_api-0.1.7.dist-info/METADATA +178 -0
- hindsight_api-0.1.7.dist-info/RECORD +64 -0
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.5.dist-info/METADATA +0 -42
- hindsight_api-0.1.5.dist-info/RECORD +0 -63
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/WHEEL +0 -0
hindsight_api/api/mcp.py
CHANGED
|
@@ -4,27 +4,33 @@ import json
|
|
|
4
4
|
import logging
|
|
5
5
|
import os
|
|
6
6
|
from contextvars import ContextVar
|
|
7
|
-
from typing import Optional
|
|
8
7
|
|
|
9
8
|
from fastmcp import FastMCP
|
|
9
|
+
|
|
10
10
|
from hindsight_api import MemoryEngine
|
|
11
11
|
from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
|
|
12
12
|
|
|
13
13
|
# Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
|
|
14
14
|
_log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
|
|
15
|
-
_log_level_map = {
|
|
16
|
-
|
|
15
|
+
_log_level_map = {
|
|
16
|
+
"critical": logging.CRITICAL,
|
|
17
|
+
"error": logging.ERROR,
|
|
18
|
+
"warning": logging.WARNING,
|
|
19
|
+
"info": logging.INFO,
|
|
20
|
+
"debug": logging.DEBUG,
|
|
21
|
+
"trace": logging.DEBUG,
|
|
22
|
+
}
|
|
17
23
|
logging.basicConfig(
|
|
18
24
|
level=_log_level_map.get(_log_level_str, logging.INFO),
|
|
19
|
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
25
|
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
20
26
|
)
|
|
21
27
|
logger = logging.getLogger(__name__)
|
|
22
28
|
|
|
23
29
|
# Context variable to hold the current bank_id from the URL path
|
|
24
|
-
_current_bank_id: ContextVar[
|
|
30
|
+
_current_bank_id: ContextVar[str | None] = ContextVar("current_bank_id", default=None)
|
|
25
31
|
|
|
26
32
|
|
|
27
|
-
def get_current_bank_id() ->
|
|
33
|
+
def get_current_bank_id() -> str | None:
|
|
28
34
|
"""Get the current bank_id from context (set from URL path)."""
|
|
29
35
|
return _current_bank_id.get()
|
|
30
36
|
|
|
@@ -61,10 +67,7 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
|
|
|
61
67
|
"""
|
|
62
68
|
try:
|
|
63
69
|
bank_id = get_current_bank_id()
|
|
64
|
-
await memory.
|
|
65
|
-
bank_id=bank_id,
|
|
66
|
-
contents=[{"content": content, "context": context}]
|
|
67
|
-
)
|
|
70
|
+
await memory.retain_batch_async(bank_id=bank_id, contents=[{"content": content, "context": context}])
|
|
68
71
|
return "Memory stored successfully"
|
|
69
72
|
except Exception as e:
|
|
70
73
|
logger.error(f"Error storing memory: {e}", exc_info=True)
|
|
@@ -88,11 +91,9 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
|
|
|
88
91
|
try:
|
|
89
92
|
bank_id = get_current_bank_id()
|
|
90
93
|
from hindsight_api.engine.memory_engine import Budget
|
|
94
|
+
|
|
91
95
|
search_result = await memory.recall_async(
|
|
92
|
-
bank_id=bank_id,
|
|
93
|
-
query=query,
|
|
94
|
-
fact_type=list(VALID_RECALL_FACT_TYPES),
|
|
95
|
-
budget=Budget.LOW
|
|
96
|
+
bank_id=bank_id, query=query, fact_type=list(VALID_RECALL_FACT_TYPES), budget=Budget.LOW
|
|
96
97
|
)
|
|
97
98
|
|
|
98
99
|
results = [
|
|
@@ -133,7 +134,7 @@ class MCPMiddleware:
|
|
|
133
134
|
# Strip any mount prefix (e.g., /mcp) that FastAPI might not have stripped
|
|
134
135
|
root_path = scope.get("root_path", "")
|
|
135
136
|
if root_path and path.startswith(root_path):
|
|
136
|
-
path = path[len(root_path):] or "/"
|
|
137
|
+
path = path[len(root_path) :] or "/"
|
|
137
138
|
|
|
138
139
|
# Also handle case where mount path wasn't stripped (e.g., /mcp/...)
|
|
139
140
|
if path.startswith("/mcp/"):
|
|
@@ -169,10 +170,7 @@ class MCPMiddleware:
|
|
|
169
170
|
body = message.get("body", b"")
|
|
170
171
|
if body and b"/messages" in body:
|
|
171
172
|
# Rewrite /messages to /{bank_id}/messages in SSE endpoint event
|
|
172
|
-
body = body.replace(
|
|
173
|
-
b"data: /messages",
|
|
174
|
-
f"data: /{bank_id}/messages".encode()
|
|
175
|
-
)
|
|
173
|
+
body = body.replace(b"data: /messages", f"data: /{bank_id}/messages".encode())
|
|
176
174
|
message = {**message, "body": body}
|
|
177
175
|
await send(message)
|
|
178
176
|
|
|
@@ -183,15 +181,19 @@ class MCPMiddleware:
|
|
|
183
181
|
async def _send_error(self, send, status: int, message: str):
|
|
184
182
|
"""Send an error response."""
|
|
185
183
|
body = json.dumps({"error": message}).encode()
|
|
186
|
-
await send(
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
184
|
+
await send(
|
|
185
|
+
{
|
|
186
|
+
"type": "http.response.start",
|
|
187
|
+
"status": status,
|
|
188
|
+
"headers": [(b"content-type", b"application/json")],
|
|
189
|
+
}
|
|
190
|
+
)
|
|
191
|
+
await send(
|
|
192
|
+
{
|
|
193
|
+
"type": "http.response.body",
|
|
194
|
+
"body": body,
|
|
195
|
+
}
|
|
196
|
+
)
|
|
195
197
|
|
|
196
198
|
|
|
197
199
|
def create_mcp_app(memory: MemoryEngine):
|
hindsight_api/banner.py
CHANGED
|
@@ -6,7 +6,7 @@ Shows the logo and tagline with gradient colors.
|
|
|
6
6
|
|
|
7
7
|
# Gradient colors: #0074d9 -> #009296
|
|
8
8
|
GRADIENT_START = (0, 116, 217) # #0074d9
|
|
9
|
-
GRADIENT_END = (0, 146, 150)
|
|
9
|
+
GRADIENT_END = (0, 146, 150) # #009296
|
|
10
10
|
|
|
11
11
|
# Pre-generated logo (generated by test-logo.py)
|
|
12
12
|
LOGO = """\
|
|
@@ -31,8 +31,8 @@ def gradient_text(text: str, start: tuple = GRADIENT_START, end: tuple = GRADIEN
|
|
|
31
31
|
result = []
|
|
32
32
|
length = len(text)
|
|
33
33
|
for i, char in enumerate(text):
|
|
34
|
-
if char ==
|
|
35
|
-
result.append(
|
|
34
|
+
if char == " ":
|
|
35
|
+
result.append(" ")
|
|
36
36
|
else:
|
|
37
37
|
t = i / max(length - 1, 1)
|
|
38
38
|
r, g, b = _interpolate_color(start, end, t)
|
|
@@ -74,9 +74,16 @@ def dim(text: str) -> str:
|
|
|
74
74
|
return f"\033[38;2;128;128;128m{text}\033[0m"
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def print_startup_info(
|
|
78
|
-
|
|
79
|
-
|
|
77
|
+
def print_startup_info(
|
|
78
|
+
host: str,
|
|
79
|
+
port: int,
|
|
80
|
+
database_url: str,
|
|
81
|
+
llm_provider: str,
|
|
82
|
+
llm_model: str,
|
|
83
|
+
embeddings_provider: str,
|
|
84
|
+
reranker_provider: str,
|
|
85
|
+
mcp_enabled: bool = False,
|
|
86
|
+
):
|
|
80
87
|
"""Print styled startup information."""
|
|
81
88
|
print(color_start("Starting Hindsight API..."))
|
|
82
89
|
print(f" {dim('URL:')} {color(f'http://{host}:{port}', 0.2)}")
|
hindsight_api/config.py
CHANGED
|
@@ -3,10 +3,10 @@ Centralized configuration for Hindsight API.
|
|
|
3
3
|
|
|
4
4
|
All environment variables and their defaults are defined here.
|
|
5
5
|
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
6
8
|
import os
|
|
7
9
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Optional
|
|
9
|
-
import logging
|
|
10
10
|
|
|
11
11
|
logger = logging.getLogger(__name__)
|
|
12
12
|
|
|
@@ -30,6 +30,7 @@ ENV_PORT = "HINDSIGHT_API_PORT"
|
|
|
30
30
|
ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
|
|
31
31
|
ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
|
|
32
32
|
ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
|
|
33
|
+
ENV_MCP_LOCAL_BANK_ID = "HINDSIGHT_API_MCP_LOCAL_BANK_ID"
|
|
33
34
|
|
|
34
35
|
# Default values
|
|
35
36
|
DEFAULT_DATABASE_URL = "pg0"
|
|
@@ -47,6 +48,7 @@ DEFAULT_PORT = 8888
|
|
|
47
48
|
DEFAULT_LOG_LEVEL = "info"
|
|
48
49
|
DEFAULT_MCP_ENABLED = True
|
|
49
50
|
DEFAULT_GRAPH_RETRIEVER = "bfs" # Options: "bfs", "mpfp"
|
|
51
|
+
DEFAULT_MCP_LOCAL_BANK_ID = "mcp"
|
|
50
52
|
|
|
51
53
|
# Required embedding dimension for database schema
|
|
52
54
|
EMBEDDING_DIMENSION = 384
|
|
@@ -61,19 +63,19 @@ class HindsightConfig:
|
|
|
61
63
|
|
|
62
64
|
# LLM
|
|
63
65
|
llm_provider: str
|
|
64
|
-
llm_api_key:
|
|
66
|
+
llm_api_key: str | None
|
|
65
67
|
llm_model: str
|
|
66
|
-
llm_base_url:
|
|
68
|
+
llm_base_url: str | None
|
|
67
69
|
|
|
68
70
|
# Embeddings
|
|
69
71
|
embeddings_provider: str
|
|
70
72
|
embeddings_local_model: str
|
|
71
|
-
embeddings_tei_url:
|
|
73
|
+
embeddings_tei_url: str | None
|
|
72
74
|
|
|
73
75
|
# Reranker
|
|
74
76
|
reranker_provider: str
|
|
75
77
|
reranker_local_model: str
|
|
76
|
-
reranker_tei_url:
|
|
78
|
+
reranker_tei_url: str | None
|
|
77
79
|
|
|
78
80
|
# Server
|
|
79
81
|
host: str
|
|
@@ -90,29 +92,24 @@ class HindsightConfig:
|
|
|
90
92
|
return cls(
|
|
91
93
|
# Database
|
|
92
94
|
database_url=os.getenv(ENV_DATABASE_URL, DEFAULT_DATABASE_URL),
|
|
93
|
-
|
|
94
95
|
# LLM
|
|
95
96
|
llm_provider=os.getenv(ENV_LLM_PROVIDER, DEFAULT_LLM_PROVIDER),
|
|
96
97
|
llm_api_key=os.getenv(ENV_LLM_API_KEY),
|
|
97
98
|
llm_model=os.getenv(ENV_LLM_MODEL, DEFAULT_LLM_MODEL),
|
|
98
99
|
llm_base_url=os.getenv(ENV_LLM_BASE_URL) or None,
|
|
99
|
-
|
|
100
100
|
# Embeddings
|
|
101
101
|
embeddings_provider=os.getenv(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER),
|
|
102
102
|
embeddings_local_model=os.getenv(ENV_EMBEDDINGS_LOCAL_MODEL, DEFAULT_EMBEDDINGS_LOCAL_MODEL),
|
|
103
103
|
embeddings_tei_url=os.getenv(ENV_EMBEDDINGS_TEI_URL),
|
|
104
|
-
|
|
105
104
|
# Reranker
|
|
106
105
|
reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
|
|
107
106
|
reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
|
|
108
107
|
reranker_tei_url=os.getenv(ENV_RERANKER_TEI_URL),
|
|
109
|
-
|
|
110
108
|
# Server
|
|
111
109
|
host=os.getenv(ENV_HOST, DEFAULT_HOST),
|
|
112
110
|
port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
|
|
113
111
|
log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
|
|
114
112
|
mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
|
|
115
|
-
|
|
116
113
|
# Recall
|
|
117
114
|
graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
|
|
118
115
|
)
|
|
@@ -145,8 +142,7 @@ class HindsightConfig:
|
|
|
145
142
|
def configure_logging(self) -> None:
|
|
146
143
|
"""Configure Python logging based on the log level."""
|
|
147
144
|
logging.basicConfig(
|
|
148
|
-
level=self.get_python_log_level(),
|
|
149
|
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
145
|
+
level=self.get_python_log_level(), format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
150
146
|
)
|
|
151
147
|
|
|
152
148
|
def log_config(self) -> None:
|
hindsight_api/engine/__init__.py
CHANGED
|
@@ -7,24 +7,24 @@ This package contains all the implementation details of the memory engine:
|
|
|
7
7
|
- Supporting modules: embeddings, cross_encoder, entity_resolver, etc.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
from .
|
|
10
|
+
from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
|
|
11
11
|
from .db_utils import acquire_with_retry
|
|
12
12
|
from .embeddings import Embeddings, LocalSTEmbeddings, RemoteTEIEmbeddings
|
|
13
|
-
from .
|
|
13
|
+
from .llm_wrapper import LLMConfig
|
|
14
|
+
from .memory_engine import MemoryEngine
|
|
15
|
+
from .response_models import MemoryFact, RecallResult, ReflectResult
|
|
14
16
|
from .search.trace import (
|
|
15
|
-
SearchTrace,
|
|
16
|
-
QueryInfo,
|
|
17
17
|
EntryPoint,
|
|
18
|
-
NodeVisit,
|
|
19
|
-
WeightComponents,
|
|
20
18
|
LinkInfo,
|
|
19
|
+
NodeVisit,
|
|
21
20
|
PruningDecision,
|
|
22
|
-
|
|
21
|
+
QueryInfo,
|
|
23
22
|
SearchPhaseMetrics,
|
|
23
|
+
SearchSummary,
|
|
24
|
+
SearchTrace,
|
|
25
|
+
WeightComponents,
|
|
24
26
|
)
|
|
25
27
|
from .search.tracer import SearchTracer
|
|
26
|
-
from .llm_wrapper import LLMConfig
|
|
27
|
-
from .response_models import RecallResult, ReflectResult, MemoryFact
|
|
28
28
|
|
|
29
29
|
__all__ = [
|
|
30
30
|
"MemoryEngine",
|
|
@@ -5,19 +5,19 @@ Provides an interface for reranking with different backends.
|
|
|
5
5
|
|
|
6
6
|
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
7
7
|
"""
|
|
8
|
-
|
|
9
|
-
from typing import List, Tuple, Optional
|
|
8
|
+
|
|
10
9
|
import logging
|
|
11
10
|
import os
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
12
|
|
|
13
13
|
import httpx
|
|
14
14
|
|
|
15
15
|
from ..config import (
|
|
16
|
-
|
|
16
|
+
DEFAULT_RERANKER_LOCAL_MODEL,
|
|
17
|
+
DEFAULT_RERANKER_PROVIDER,
|
|
17
18
|
ENV_RERANKER_LOCAL_MODEL,
|
|
19
|
+
ENV_RERANKER_PROVIDER,
|
|
18
20
|
ENV_RERANKER_TEI_URL,
|
|
19
|
-
DEFAULT_RERANKER_PROVIDER,
|
|
20
|
-
DEFAULT_RERANKER_LOCAL_MODEL,
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
logger = logging.getLogger(__name__)
|
|
@@ -47,7 +47,7 @@ class CrossEncoderModel(ABC):
|
|
|
47
47
|
pass
|
|
48
48
|
|
|
49
49
|
@abstractmethod
|
|
50
|
-
def predict(self, pairs:
|
|
50
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
51
51
|
"""
|
|
52
52
|
Score query-document pairs for relevance.
|
|
53
53
|
|
|
@@ -72,7 +72,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
72
72
|
- Trained for passage re-ranking
|
|
73
73
|
"""
|
|
74
74
|
|
|
75
|
-
def __init__(self, model_name:
|
|
75
|
+
def __init__(self, model_name: str | None = None):
|
|
76
76
|
"""
|
|
77
77
|
Initialize local SentenceTransformers cross-encoder.
|
|
78
78
|
|
|
@@ -104,7 +104,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
104
104
|
self._model = CrossEncoder(self.model_name)
|
|
105
105
|
logger.info("Reranker: local provider initialized")
|
|
106
106
|
|
|
107
|
-
def predict(self, pairs:
|
|
107
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
108
108
|
"""
|
|
109
109
|
Score query-document pairs for relevance.
|
|
110
110
|
|
|
@@ -117,7 +117,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
117
117
|
if self._model is None:
|
|
118
118
|
raise RuntimeError("Reranker not initialized. Call initialize() first.")
|
|
119
119
|
scores = self._model.predict(pairs, show_progress_bar=False)
|
|
120
|
-
return scores.tolist() if hasattr(scores,
|
|
120
|
+
return scores.tolist() if hasattr(scores, "tolist") else list(scores)
|
|
121
121
|
|
|
122
122
|
|
|
123
123
|
class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
@@ -153,8 +153,8 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
153
153
|
self.batch_size = batch_size
|
|
154
154
|
self.max_retries = max_retries
|
|
155
155
|
self.retry_delay = retry_delay
|
|
156
|
-
self._client:
|
|
157
|
-
self._model_id:
|
|
156
|
+
self._client: httpx.Client | None = None
|
|
157
|
+
self._model_id: str | None = None
|
|
158
158
|
|
|
159
159
|
@property
|
|
160
160
|
def provider_name(self) -> str:
|
|
@@ -163,6 +163,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
163
163
|
def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
|
|
164
164
|
"""Make an HTTP request with automatic retries on transient errors."""
|
|
165
165
|
import time
|
|
166
|
+
|
|
166
167
|
last_error = None
|
|
167
168
|
delay = self.retry_delay
|
|
168
169
|
|
|
@@ -177,14 +178,18 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
177
178
|
except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
|
|
178
179
|
last_error = e
|
|
179
180
|
if attempt < self.max_retries:
|
|
180
|
-
logger.warning(
|
|
181
|
+
logger.warning(
|
|
182
|
+
f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
183
|
+
)
|
|
181
184
|
time.sleep(delay)
|
|
182
185
|
delay *= 2 # Exponential backoff
|
|
183
186
|
except httpx.HTTPStatusError as e:
|
|
184
187
|
# Retry on 5xx server errors
|
|
185
188
|
if e.response.status_code >= 500 and attempt < self.max_retries:
|
|
186
189
|
last_error = e
|
|
187
|
-
logger.warning(
|
|
190
|
+
logger.warning(
|
|
191
|
+
f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
192
|
+
)
|
|
188
193
|
time.sleep(delay)
|
|
189
194
|
delay *= 2
|
|
190
195
|
else:
|
|
@@ -209,7 +214,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
209
214
|
except httpx.HTTPError as e:
|
|
210
215
|
raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
|
|
211
216
|
|
|
212
|
-
def predict(self, pairs:
|
|
217
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
213
218
|
"""
|
|
214
219
|
Score query-document pairs using the remote TEI reranker.
|
|
215
220
|
|
|
@@ -229,7 +234,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
229
234
|
|
|
230
235
|
# Process in batches
|
|
231
236
|
for i in range(0, len(pairs), self.batch_size):
|
|
232
|
-
batch = pairs[i:i + self.batch_size]
|
|
237
|
+
batch = pairs[i : i + self.batch_size]
|
|
233
238
|
|
|
234
239
|
# TEI rerank endpoint expects query and texts separately
|
|
235
240
|
# All pairs in a batch should have the same query for optimal performance
|
|
@@ -287,15 +292,11 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
|
|
|
287
292
|
if provider == "tei":
|
|
288
293
|
url = os.environ.get(ENV_RERANKER_TEI_URL)
|
|
289
294
|
if not url:
|
|
290
|
-
raise ValueError(
|
|
291
|
-
f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
|
|
292
|
-
)
|
|
295
|
+
raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
|
|
293
296
|
return RemoteTEICrossEncoder(base_url=url)
|
|
294
297
|
elif provider == "local":
|
|
295
298
|
model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
|
|
296
299
|
model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
|
|
297
300
|
return LocalSTCrossEncoder(model_name=model_name)
|
|
298
301
|
else:
|
|
299
|
-
raise ValueError(
|
|
300
|
-
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'"
|
|
301
|
-
)
|
|
302
|
+
raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
|
hindsight_api/engine/db_utils.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Database utility functions for connection management with retry logic.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import asyncio
|
|
5
6
|
import logging
|
|
6
7
|
from contextlib import asynccontextmanager
|
|
8
|
+
|
|
7
9
|
import asyncpg
|
|
8
10
|
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
@@ -54,16 +56,14 @@ async def retry_with_backoff(
|
|
|
54
56
|
except retryable_exceptions as e:
|
|
55
57
|
last_exception = e
|
|
56
58
|
if attempt < max_retries:
|
|
57
|
-
delay = min(base_delay * (2
|
|
59
|
+
delay = min(base_delay * (2**attempt), max_delay)
|
|
58
60
|
logger.warning(
|
|
59
61
|
f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
|
|
60
62
|
f"Retrying in {delay:.1f}s..."
|
|
61
63
|
)
|
|
62
64
|
await asyncio.sleep(delay)
|
|
63
65
|
else:
|
|
64
|
-
logger.error(
|
|
65
|
-
f"Database operation failed after {max_retries + 1} attempts: {e}"
|
|
66
|
-
)
|
|
66
|
+
logger.error(f"Database operation failed after {max_retries + 1} attempts: {e}")
|
|
67
67
|
raise last_exception
|
|
68
68
|
|
|
69
69
|
|
|
@@ -83,6 +83,7 @@ async def acquire_with_retry(pool: asyncpg.Pool, max_retries: int = DEFAULT_MAX_
|
|
|
83
83
|
Yields:
|
|
84
84
|
An asyncpg connection
|
|
85
85
|
"""
|
|
86
|
+
|
|
86
87
|
async def acquire():
|
|
87
88
|
return await pool.acquire()
|
|
88
89
|
|
|
@@ -8,20 +8,20 @@ the database schema (pgvector column defined as vector(384)).
|
|
|
8
8
|
|
|
9
9
|
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
10
10
|
"""
|
|
11
|
-
|
|
12
|
-
from typing import List, Optional
|
|
11
|
+
|
|
13
12
|
import logging
|
|
14
13
|
import os
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
15
|
|
|
16
16
|
import httpx
|
|
17
17
|
|
|
18
18
|
from ..config import (
|
|
19
|
-
ENV_EMBEDDINGS_PROVIDER,
|
|
20
|
-
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
21
|
-
ENV_EMBEDDINGS_TEI_URL,
|
|
22
|
-
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
23
19
|
DEFAULT_EMBEDDINGS_LOCAL_MODEL,
|
|
20
|
+
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
24
21
|
EMBEDDING_DIMENSION,
|
|
22
|
+
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
23
|
+
ENV_EMBEDDINGS_PROVIDER,
|
|
24
|
+
ENV_EMBEDDINGS_TEI_URL,
|
|
25
25
|
)
|
|
26
26
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
@@ -52,7 +52,7 @@ class Embeddings(ABC):
|
|
|
52
52
|
pass
|
|
53
53
|
|
|
54
54
|
@abstractmethod
|
|
55
|
-
def encode(self, texts:
|
|
55
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
56
56
|
"""
|
|
57
57
|
Generate 384-dimensional embeddings for a list of texts.
|
|
58
58
|
|
|
@@ -75,7 +75,7 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
75
75
|
embeddings matching the database schema.
|
|
76
76
|
"""
|
|
77
77
|
|
|
78
|
-
def __init__(self, model_name:
|
|
78
|
+
def __init__(self, model_name: str | None = None):
|
|
79
79
|
"""
|
|
80
80
|
Initialize local SentenceTransformers embeddings.
|
|
81
81
|
|
|
@@ -123,7 +123,7 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
123
123
|
|
|
124
124
|
logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
|
|
125
125
|
|
|
126
|
-
def encode(self, texts:
|
|
126
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
127
127
|
"""
|
|
128
128
|
Generate 384-dimensional embeddings for a list of texts.
|
|
129
129
|
|
|
@@ -172,8 +172,8 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
172
172
|
self.batch_size = batch_size
|
|
173
173
|
self.max_retries = max_retries
|
|
174
174
|
self.retry_delay = retry_delay
|
|
175
|
-
self._client:
|
|
176
|
-
self._model_id:
|
|
175
|
+
self._client: httpx.Client | None = None
|
|
176
|
+
self._model_id: str | None = None
|
|
177
177
|
|
|
178
178
|
@property
|
|
179
179
|
def provider_name(self) -> str:
|
|
@@ -182,6 +182,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
182
182
|
def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
|
|
183
183
|
"""Make an HTTP request with automatic retries on transient errors."""
|
|
184
184
|
import time
|
|
185
|
+
|
|
185
186
|
last_error = None
|
|
186
187
|
delay = self.retry_delay
|
|
187
188
|
|
|
@@ -196,14 +197,18 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
196
197
|
except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
|
|
197
198
|
last_error = e
|
|
198
199
|
if attempt < self.max_retries:
|
|
199
|
-
logger.warning(
|
|
200
|
+
logger.warning(
|
|
201
|
+
f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
202
|
+
)
|
|
200
203
|
time.sleep(delay)
|
|
201
204
|
delay *= 2 # Exponential backoff
|
|
202
205
|
except httpx.HTTPStatusError as e:
|
|
203
206
|
# Retry on 5xx server errors
|
|
204
207
|
if e.response.status_code >= 500 and attempt < self.max_retries:
|
|
205
208
|
last_error = e
|
|
206
|
-
logger.warning(
|
|
209
|
+
logger.warning(
|
|
210
|
+
f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
211
|
+
)
|
|
207
212
|
time.sleep(delay)
|
|
208
213
|
delay *= 2
|
|
209
214
|
else:
|
|
@@ -228,7 +233,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
228
233
|
except httpx.HTTPError as e:
|
|
229
234
|
raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
|
|
230
235
|
|
|
231
|
-
def encode(self, texts:
|
|
236
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
232
237
|
"""
|
|
233
238
|
Generate embeddings using the remote TEI server.
|
|
234
239
|
|
|
@@ -248,7 +253,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
248
253
|
|
|
249
254
|
# Process in batches
|
|
250
255
|
for i in range(0, len(texts), self.batch_size):
|
|
251
|
-
batch = texts[i:i + self.batch_size]
|
|
256
|
+
batch = texts[i : i + self.batch_size]
|
|
252
257
|
|
|
253
258
|
try:
|
|
254
259
|
response = self._request_with_retry(
|
|
@@ -278,15 +283,11 @@ def create_embeddings_from_env() -> Embeddings:
|
|
|
278
283
|
if provider == "tei":
|
|
279
284
|
url = os.environ.get(ENV_EMBEDDINGS_TEI_URL)
|
|
280
285
|
if not url:
|
|
281
|
-
raise ValueError(
|
|
282
|
-
f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'"
|
|
283
|
-
)
|
|
286
|
+
raise ValueError(f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'")
|
|
284
287
|
return RemoteTEIEmbeddings(base_url=url)
|
|
285
288
|
elif provider == "local":
|
|
286
289
|
model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
|
|
287
290
|
model_name = model or DEFAULT_EMBEDDINGS_LOCAL_MODEL
|
|
288
291
|
return LocalSTEmbeddings(model_name=model_name)
|
|
289
292
|
else:
|
|
290
|
-
raise ValueError(
|
|
291
|
-
f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'"
|
|
292
|
-
)
|
|
293
|
+
raise ValueError(f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'")
|