hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/api/mcp.py
CHANGED
|
@@ -4,27 +4,33 @@ import json
|
|
|
4
4
|
import logging
|
|
5
5
|
import os
|
|
6
6
|
from contextvars import ContextVar
|
|
7
|
-
from typing import Optional
|
|
8
7
|
|
|
9
8
|
from fastmcp import FastMCP
|
|
9
|
+
|
|
10
10
|
from hindsight_api import MemoryEngine
|
|
11
11
|
from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
|
|
12
12
|
|
|
13
13
|
# Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
|
|
14
14
|
_log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
|
|
15
|
-
_log_level_map = {
|
|
16
|
-
|
|
15
|
+
_log_level_map = {
|
|
16
|
+
"critical": logging.CRITICAL,
|
|
17
|
+
"error": logging.ERROR,
|
|
18
|
+
"warning": logging.WARNING,
|
|
19
|
+
"info": logging.INFO,
|
|
20
|
+
"debug": logging.DEBUG,
|
|
21
|
+
"trace": logging.DEBUG,
|
|
22
|
+
}
|
|
17
23
|
logging.basicConfig(
|
|
18
24
|
level=_log_level_map.get(_log_level_str, logging.INFO),
|
|
19
|
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
25
|
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
20
26
|
)
|
|
21
27
|
logger = logging.getLogger(__name__)
|
|
22
28
|
|
|
23
29
|
# Context variable to hold the current bank_id from the URL path
|
|
24
|
-
_current_bank_id: ContextVar[
|
|
30
|
+
_current_bank_id: ContextVar[str | None] = ContextVar("current_bank_id", default=None)
|
|
25
31
|
|
|
26
32
|
|
|
27
|
-
def get_current_bank_id() ->
|
|
33
|
+
def get_current_bank_id() -> str | None:
|
|
28
34
|
"""Get the current bank_id from context (set from URL path)."""
|
|
29
35
|
return _current_bank_id.get()
|
|
30
36
|
|
|
@@ -61,10 +67,7 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
|
|
|
61
67
|
"""
|
|
62
68
|
try:
|
|
63
69
|
bank_id = get_current_bank_id()
|
|
64
|
-
await memory.
|
|
65
|
-
bank_id=bank_id,
|
|
66
|
-
contents=[{"content": content, "context": context}]
|
|
67
|
-
)
|
|
70
|
+
await memory.retain_batch_async(bank_id=bank_id, contents=[{"content": content, "context": context}])
|
|
68
71
|
return "Memory stored successfully"
|
|
69
72
|
except Exception as e:
|
|
70
73
|
logger.error(f"Error storing memory: {e}", exc_info=True)
|
|
@@ -88,11 +91,9 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
|
|
|
88
91
|
try:
|
|
89
92
|
bank_id = get_current_bank_id()
|
|
90
93
|
from hindsight_api.engine.memory_engine import Budget
|
|
94
|
+
|
|
91
95
|
search_result = await memory.recall_async(
|
|
92
|
-
bank_id=bank_id,
|
|
93
|
-
query=query,
|
|
94
|
-
fact_type=list(VALID_RECALL_FACT_TYPES),
|
|
95
|
-
budget=Budget.LOW
|
|
96
|
+
bank_id=bank_id, query=query, fact_type=list(VALID_RECALL_FACT_TYPES), budget=Budget.LOW
|
|
96
97
|
)
|
|
97
98
|
|
|
98
99
|
results = [
|
|
@@ -121,11 +122,7 @@ class MCPMiddleware:
|
|
|
121
122
|
self.app = app
|
|
122
123
|
self.memory = memory
|
|
123
124
|
self.mcp_server = create_mcp_server(memory)
|
|
124
|
-
|
|
125
|
-
import warnings
|
|
126
|
-
with warnings.catch_warnings():
|
|
127
|
-
warnings.simplefilter("ignore", DeprecationWarning)
|
|
128
|
-
self.mcp_app = self.mcp_server.sse_app()
|
|
125
|
+
self.mcp_app = self.mcp_server.http_app()
|
|
129
126
|
|
|
130
127
|
async def __call__(self, scope, receive, send):
|
|
131
128
|
if scope["type"] != "http":
|
|
@@ -137,7 +134,7 @@ class MCPMiddleware:
|
|
|
137
134
|
# Strip any mount prefix (e.g., /mcp) that FastAPI might not have stripped
|
|
138
135
|
root_path = scope.get("root_path", "")
|
|
139
136
|
if root_path and path.startswith(root_path):
|
|
140
|
-
path = path[len(root_path):] or "/"
|
|
137
|
+
path = path[len(root_path) :] or "/"
|
|
141
138
|
|
|
142
139
|
# Also handle case where mount path wasn't stripped (e.g., /mcp/...)
|
|
143
140
|
if path.startswith("/mcp/"):
|
|
@@ -173,10 +170,7 @@ class MCPMiddleware:
|
|
|
173
170
|
body = message.get("body", b"")
|
|
174
171
|
if body and b"/messages" in body:
|
|
175
172
|
# Rewrite /messages to /{bank_id}/messages in SSE endpoint event
|
|
176
|
-
body = body.replace(
|
|
177
|
-
b"data: /messages",
|
|
178
|
-
f"data: /{bank_id}/messages".encode()
|
|
179
|
-
)
|
|
173
|
+
body = body.replace(b"data: /messages", f"data: /{bank_id}/messages".encode())
|
|
180
174
|
message = {**message, "body": body}
|
|
181
175
|
await send(message)
|
|
182
176
|
|
|
@@ -187,15 +181,19 @@ class MCPMiddleware:
|
|
|
187
181
|
async def _send_error(self, send, status: int, message: str):
|
|
188
182
|
"""Send an error response."""
|
|
189
183
|
body = json.dumps({"error": message}).encode()
|
|
190
|
-
await send(
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
184
|
+
await send(
|
|
185
|
+
{
|
|
186
|
+
"type": "http.response.start",
|
|
187
|
+
"status": status,
|
|
188
|
+
"headers": [(b"content-type", b"application/json")],
|
|
189
|
+
}
|
|
190
|
+
)
|
|
191
|
+
await send(
|
|
192
|
+
{
|
|
193
|
+
"type": "http.response.body",
|
|
194
|
+
"body": body,
|
|
195
|
+
}
|
|
196
|
+
)
|
|
199
197
|
|
|
200
198
|
|
|
201
199
|
def create_mcp_app(memory: MemoryEngine):
|
hindsight_api/banner.py
CHANGED
|
@@ -6,7 +6,7 @@ Shows the logo and tagline with gradient colors.
|
|
|
6
6
|
|
|
7
7
|
# Gradient colors: #0074d9 -> #009296
|
|
8
8
|
GRADIENT_START = (0, 116, 217) # #0074d9
|
|
9
|
-
GRADIENT_END = (0, 146, 150)
|
|
9
|
+
GRADIENT_END = (0, 146, 150) # #009296
|
|
10
10
|
|
|
11
11
|
# Pre-generated logo (generated by test-logo.py)
|
|
12
12
|
LOGO = """\
|
|
@@ -31,8 +31,8 @@ def gradient_text(text: str, start: tuple = GRADIENT_START, end: tuple = GRADIEN
|
|
|
31
31
|
result = []
|
|
32
32
|
length = len(text)
|
|
33
33
|
for i, char in enumerate(text):
|
|
34
|
-
if char ==
|
|
35
|
-
result.append(
|
|
34
|
+
if char == " ":
|
|
35
|
+
result.append(" ")
|
|
36
36
|
else:
|
|
37
37
|
t = i / max(length - 1, 1)
|
|
38
38
|
r, g, b = _interpolate_color(start, end, t)
|
|
@@ -74,9 +74,16 @@ def dim(text: str) -> str:
|
|
|
74
74
|
return f"\033[38;2;128;128;128m{text}\033[0m"
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def print_startup_info(
|
|
78
|
-
|
|
79
|
-
|
|
77
|
+
def print_startup_info(
|
|
78
|
+
host: str,
|
|
79
|
+
port: int,
|
|
80
|
+
database_url: str,
|
|
81
|
+
llm_provider: str,
|
|
82
|
+
llm_model: str,
|
|
83
|
+
embeddings_provider: str,
|
|
84
|
+
reranker_provider: str,
|
|
85
|
+
mcp_enabled: bool = False,
|
|
86
|
+
):
|
|
80
87
|
"""Print styled startup information."""
|
|
81
88
|
print(color_start("Starting Hindsight API..."))
|
|
82
89
|
print(f" {dim('URL:')} {color(f'http://{host}:{port}', 0.2)}")
|
hindsight_api/config.py
CHANGED
|
@@ -3,10 +3,10 @@ Centralized configuration for Hindsight API.
|
|
|
3
3
|
|
|
4
4
|
All environment variables and their defaults are defined here.
|
|
5
5
|
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
6
8
|
import os
|
|
7
9
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Optional
|
|
9
|
-
import logging
|
|
10
10
|
|
|
11
11
|
logger = logging.getLogger(__name__)
|
|
12
12
|
|
|
@@ -29,6 +29,8 @@ ENV_HOST = "HINDSIGHT_API_HOST"
|
|
|
29
29
|
ENV_PORT = "HINDSIGHT_API_PORT"
|
|
30
30
|
ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
|
|
31
31
|
ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
|
|
32
|
+
ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
|
|
33
|
+
ENV_MCP_LOCAL_BANK_ID = "HINDSIGHT_API_MCP_LOCAL_BANK_ID"
|
|
32
34
|
|
|
33
35
|
# Default values
|
|
34
36
|
DEFAULT_DATABASE_URL = "pg0"
|
|
@@ -45,6 +47,8 @@ DEFAULT_HOST = "0.0.0.0"
|
|
|
45
47
|
DEFAULT_PORT = 8888
|
|
46
48
|
DEFAULT_LOG_LEVEL = "info"
|
|
47
49
|
DEFAULT_MCP_ENABLED = True
|
|
50
|
+
DEFAULT_GRAPH_RETRIEVER = "bfs" # Options: "bfs", "mpfp"
|
|
51
|
+
DEFAULT_MCP_LOCAL_BANK_ID = "mcp"
|
|
48
52
|
|
|
49
53
|
# Required embedding dimension for database schema
|
|
50
54
|
EMBEDDING_DIMENSION = 384
|
|
@@ -59,19 +63,19 @@ class HindsightConfig:
|
|
|
59
63
|
|
|
60
64
|
# LLM
|
|
61
65
|
llm_provider: str
|
|
62
|
-
llm_api_key:
|
|
66
|
+
llm_api_key: str | None
|
|
63
67
|
llm_model: str
|
|
64
|
-
llm_base_url:
|
|
68
|
+
llm_base_url: str | None
|
|
65
69
|
|
|
66
70
|
# Embeddings
|
|
67
71
|
embeddings_provider: str
|
|
68
72
|
embeddings_local_model: str
|
|
69
|
-
embeddings_tei_url:
|
|
73
|
+
embeddings_tei_url: str | None
|
|
70
74
|
|
|
71
75
|
# Reranker
|
|
72
76
|
reranker_provider: str
|
|
73
77
|
reranker_local_model: str
|
|
74
|
-
reranker_tei_url:
|
|
78
|
+
reranker_tei_url: str | None
|
|
75
79
|
|
|
76
80
|
# Server
|
|
77
81
|
host: str
|
|
@@ -79,34 +83,35 @@ class HindsightConfig:
|
|
|
79
83
|
log_level: str
|
|
80
84
|
mcp_enabled: bool
|
|
81
85
|
|
|
86
|
+
# Recall
|
|
87
|
+
graph_retriever: str
|
|
88
|
+
|
|
82
89
|
@classmethod
|
|
83
90
|
def from_env(cls) -> "HindsightConfig":
|
|
84
91
|
"""Create configuration from environment variables."""
|
|
85
92
|
return cls(
|
|
86
93
|
# Database
|
|
87
94
|
database_url=os.getenv(ENV_DATABASE_URL, DEFAULT_DATABASE_URL),
|
|
88
|
-
|
|
89
95
|
# LLM
|
|
90
96
|
llm_provider=os.getenv(ENV_LLM_PROVIDER, DEFAULT_LLM_PROVIDER),
|
|
91
97
|
llm_api_key=os.getenv(ENV_LLM_API_KEY),
|
|
92
98
|
llm_model=os.getenv(ENV_LLM_MODEL, DEFAULT_LLM_MODEL),
|
|
93
99
|
llm_base_url=os.getenv(ENV_LLM_BASE_URL) or None,
|
|
94
|
-
|
|
95
100
|
# Embeddings
|
|
96
101
|
embeddings_provider=os.getenv(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER),
|
|
97
102
|
embeddings_local_model=os.getenv(ENV_EMBEDDINGS_LOCAL_MODEL, DEFAULT_EMBEDDINGS_LOCAL_MODEL),
|
|
98
103
|
embeddings_tei_url=os.getenv(ENV_EMBEDDINGS_TEI_URL),
|
|
99
|
-
|
|
100
104
|
# Reranker
|
|
101
105
|
reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
|
|
102
106
|
reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
|
|
103
107
|
reranker_tei_url=os.getenv(ENV_RERANKER_TEI_URL),
|
|
104
|
-
|
|
105
108
|
# Server
|
|
106
109
|
host=os.getenv(ENV_HOST, DEFAULT_HOST),
|
|
107
110
|
port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
|
|
108
111
|
log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
|
|
109
112
|
mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
|
|
113
|
+
# Recall
|
|
114
|
+
graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
|
|
110
115
|
)
|
|
111
116
|
|
|
112
117
|
def get_llm_base_url(self) -> str:
|
|
@@ -137,8 +142,7 @@ class HindsightConfig:
|
|
|
137
142
|
def configure_logging(self) -> None:
|
|
138
143
|
"""Configure Python logging based on the log level."""
|
|
139
144
|
logging.basicConfig(
|
|
140
|
-
level=self.get_python_log_level(),
|
|
141
|
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
145
|
+
level=self.get_python_log_level(), format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
|
142
146
|
)
|
|
143
147
|
|
|
144
148
|
def log_config(self) -> None:
|
|
@@ -147,6 +151,7 @@ class HindsightConfig:
|
|
|
147
151
|
logger.info(f"LLM: provider={self.llm_provider}, model={self.llm_model}")
|
|
148
152
|
logger.info(f"Embeddings: provider={self.embeddings_provider}")
|
|
149
153
|
logger.info(f"Reranker: provider={self.reranker_provider}")
|
|
154
|
+
logger.info(f"Graph retriever: {self.graph_retriever}")
|
|
150
155
|
|
|
151
156
|
|
|
152
157
|
def get_config() -> HindsightConfig:
|
hindsight_api/engine/__init__.py
CHANGED
|
@@ -7,24 +7,24 @@ This package contains all the implementation details of the memory engine:
|
|
|
7
7
|
- Supporting modules: embeddings, cross_encoder, entity_resolver, etc.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
from .
|
|
10
|
+
from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
|
|
11
11
|
from .db_utils import acquire_with_retry
|
|
12
12
|
from .embeddings import Embeddings, LocalSTEmbeddings, RemoteTEIEmbeddings
|
|
13
|
-
from .
|
|
13
|
+
from .llm_wrapper import LLMConfig
|
|
14
|
+
from .memory_engine import MemoryEngine
|
|
15
|
+
from .response_models import MemoryFact, RecallResult, ReflectResult
|
|
14
16
|
from .search.trace import (
|
|
15
|
-
SearchTrace,
|
|
16
|
-
QueryInfo,
|
|
17
17
|
EntryPoint,
|
|
18
|
-
NodeVisit,
|
|
19
|
-
WeightComponents,
|
|
20
18
|
LinkInfo,
|
|
19
|
+
NodeVisit,
|
|
21
20
|
PruningDecision,
|
|
22
|
-
|
|
21
|
+
QueryInfo,
|
|
23
22
|
SearchPhaseMetrics,
|
|
23
|
+
SearchSummary,
|
|
24
|
+
SearchTrace,
|
|
25
|
+
WeightComponents,
|
|
24
26
|
)
|
|
25
27
|
from .search.tracer import SearchTracer
|
|
26
|
-
from .llm_wrapper import LLMConfig
|
|
27
|
-
from .response_models import RecallResult, ReflectResult, MemoryFact
|
|
28
28
|
|
|
29
29
|
__all__ = [
|
|
30
30
|
"MemoryEngine",
|
|
@@ -5,19 +5,19 @@ Provides an interface for reranking with different backends.
|
|
|
5
5
|
|
|
6
6
|
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
7
7
|
"""
|
|
8
|
-
|
|
9
|
-
from typing import List, Tuple, Optional
|
|
8
|
+
|
|
10
9
|
import logging
|
|
11
10
|
import os
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
12
|
|
|
13
13
|
import httpx
|
|
14
14
|
|
|
15
15
|
from ..config import (
|
|
16
|
-
|
|
16
|
+
DEFAULT_RERANKER_LOCAL_MODEL,
|
|
17
|
+
DEFAULT_RERANKER_PROVIDER,
|
|
17
18
|
ENV_RERANKER_LOCAL_MODEL,
|
|
19
|
+
ENV_RERANKER_PROVIDER,
|
|
18
20
|
ENV_RERANKER_TEI_URL,
|
|
19
|
-
DEFAULT_RERANKER_PROVIDER,
|
|
20
|
-
DEFAULT_RERANKER_LOCAL_MODEL,
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
logger = logging.getLogger(__name__)
|
|
@@ -47,7 +47,7 @@ class CrossEncoderModel(ABC):
|
|
|
47
47
|
pass
|
|
48
48
|
|
|
49
49
|
@abstractmethod
|
|
50
|
-
def predict(self, pairs:
|
|
50
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
51
51
|
"""
|
|
52
52
|
Score query-document pairs for relevance.
|
|
53
53
|
|
|
@@ -72,7 +72,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
72
72
|
- Trained for passage re-ranking
|
|
73
73
|
"""
|
|
74
74
|
|
|
75
|
-
def __init__(self, model_name:
|
|
75
|
+
def __init__(self, model_name: str | None = None):
|
|
76
76
|
"""
|
|
77
77
|
Initialize local SentenceTransformers cross-encoder.
|
|
78
78
|
|
|
@@ -101,15 +101,10 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
101
101
|
)
|
|
102
102
|
|
|
103
103
|
logger.info(f"Reranker: initializing local provider with model {self.model_name}")
|
|
104
|
-
|
|
105
|
-
# Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
|
|
106
|
-
self._model = CrossEncoder(
|
|
107
|
-
self.model_name,
|
|
108
|
-
model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
|
|
109
|
-
)
|
|
104
|
+
self._model = CrossEncoder(self.model_name)
|
|
110
105
|
logger.info("Reranker: local provider initialized")
|
|
111
106
|
|
|
112
|
-
def predict(self, pairs:
|
|
107
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
113
108
|
"""
|
|
114
109
|
Score query-document pairs for relevance.
|
|
115
110
|
|
|
@@ -122,7 +117,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
|
|
|
122
117
|
if self._model is None:
|
|
123
118
|
raise RuntimeError("Reranker not initialized. Call initialize() first.")
|
|
124
119
|
scores = self._model.predict(pairs, show_progress_bar=False)
|
|
125
|
-
return scores.tolist() if hasattr(scores,
|
|
120
|
+
return scores.tolist() if hasattr(scores, "tolist") else list(scores)
|
|
126
121
|
|
|
127
122
|
|
|
128
123
|
class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
@@ -158,8 +153,8 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
158
153
|
self.batch_size = batch_size
|
|
159
154
|
self.max_retries = max_retries
|
|
160
155
|
self.retry_delay = retry_delay
|
|
161
|
-
self._client:
|
|
162
|
-
self._model_id:
|
|
156
|
+
self._client: httpx.Client | None = None
|
|
157
|
+
self._model_id: str | None = None
|
|
163
158
|
|
|
164
159
|
@property
|
|
165
160
|
def provider_name(self) -> str:
|
|
@@ -168,6 +163,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
168
163
|
def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
|
|
169
164
|
"""Make an HTTP request with automatic retries on transient errors."""
|
|
170
165
|
import time
|
|
166
|
+
|
|
171
167
|
last_error = None
|
|
172
168
|
delay = self.retry_delay
|
|
173
169
|
|
|
@@ -182,14 +178,18 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
182
178
|
except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
|
|
183
179
|
last_error = e
|
|
184
180
|
if attempt < self.max_retries:
|
|
185
|
-
logger.warning(
|
|
181
|
+
logger.warning(
|
|
182
|
+
f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
183
|
+
)
|
|
186
184
|
time.sleep(delay)
|
|
187
185
|
delay *= 2 # Exponential backoff
|
|
188
186
|
except httpx.HTTPStatusError as e:
|
|
189
187
|
# Retry on 5xx server errors
|
|
190
188
|
if e.response.status_code >= 500 and attempt < self.max_retries:
|
|
191
189
|
last_error = e
|
|
192
|
-
logger.warning(
|
|
190
|
+
logger.warning(
|
|
191
|
+
f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
192
|
+
)
|
|
193
193
|
time.sleep(delay)
|
|
194
194
|
delay *= 2
|
|
195
195
|
else:
|
|
@@ -214,7 +214,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
214
214
|
except httpx.HTTPError as e:
|
|
215
215
|
raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
|
|
216
216
|
|
|
217
|
-
def predict(self, pairs:
|
|
217
|
+
def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
|
|
218
218
|
"""
|
|
219
219
|
Score query-document pairs using the remote TEI reranker.
|
|
220
220
|
|
|
@@ -234,7 +234,7 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
|
|
|
234
234
|
|
|
235
235
|
# Process in batches
|
|
236
236
|
for i in range(0, len(pairs), self.batch_size):
|
|
237
|
-
batch = pairs[i:i + self.batch_size]
|
|
237
|
+
batch = pairs[i : i + self.batch_size]
|
|
238
238
|
|
|
239
239
|
# TEI rerank endpoint expects query and texts separately
|
|
240
240
|
# All pairs in a batch should have the same query for optimal performance
|
|
@@ -292,15 +292,11 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
|
|
|
292
292
|
if provider == "tei":
|
|
293
293
|
url = os.environ.get(ENV_RERANKER_TEI_URL)
|
|
294
294
|
if not url:
|
|
295
|
-
raise ValueError(
|
|
296
|
-
f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
|
|
297
|
-
)
|
|
295
|
+
raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
|
|
298
296
|
return RemoteTEICrossEncoder(base_url=url)
|
|
299
297
|
elif provider == "local":
|
|
300
298
|
model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
|
|
301
299
|
model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
|
|
302
300
|
return LocalSTCrossEncoder(model_name=model_name)
|
|
303
301
|
else:
|
|
304
|
-
raise ValueError(
|
|
305
|
-
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'"
|
|
306
|
-
)
|
|
302
|
+
raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
|
hindsight_api/engine/db_utils.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Database utility functions for connection management with retry logic.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import asyncio
|
|
5
6
|
import logging
|
|
6
7
|
from contextlib import asynccontextmanager
|
|
8
|
+
|
|
7
9
|
import asyncpg
|
|
8
10
|
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
@@ -54,16 +56,14 @@ async def retry_with_backoff(
|
|
|
54
56
|
except retryable_exceptions as e:
|
|
55
57
|
last_exception = e
|
|
56
58
|
if attempt < max_retries:
|
|
57
|
-
delay = min(base_delay * (2
|
|
59
|
+
delay = min(base_delay * (2**attempt), max_delay)
|
|
58
60
|
logger.warning(
|
|
59
61
|
f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
|
|
60
62
|
f"Retrying in {delay:.1f}s..."
|
|
61
63
|
)
|
|
62
64
|
await asyncio.sleep(delay)
|
|
63
65
|
else:
|
|
64
|
-
logger.error(
|
|
65
|
-
f"Database operation failed after {max_retries + 1} attempts: {e}"
|
|
66
|
-
)
|
|
66
|
+
logger.error(f"Database operation failed after {max_retries + 1} attempts: {e}")
|
|
67
67
|
raise last_exception
|
|
68
68
|
|
|
69
69
|
|
|
@@ -83,6 +83,7 @@ async def acquire_with_retry(pool: asyncpg.Pool, max_retries: int = DEFAULT_MAX_
|
|
|
83
83
|
Yields:
|
|
84
84
|
An asyncpg connection
|
|
85
85
|
"""
|
|
86
|
+
|
|
86
87
|
async def acquire():
|
|
87
88
|
return await pool.acquire()
|
|
88
89
|
|
|
@@ -8,20 +8,20 @@ the database schema (pgvector column defined as vector(384)).
|
|
|
8
8
|
|
|
9
9
|
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
10
10
|
"""
|
|
11
|
-
|
|
12
|
-
from typing import List, Optional
|
|
11
|
+
|
|
13
12
|
import logging
|
|
14
13
|
import os
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
15
|
|
|
16
16
|
import httpx
|
|
17
17
|
|
|
18
18
|
from ..config import (
|
|
19
|
-
ENV_EMBEDDINGS_PROVIDER,
|
|
20
|
-
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
21
|
-
ENV_EMBEDDINGS_TEI_URL,
|
|
22
|
-
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
23
19
|
DEFAULT_EMBEDDINGS_LOCAL_MODEL,
|
|
20
|
+
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
24
21
|
EMBEDDING_DIMENSION,
|
|
22
|
+
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
23
|
+
ENV_EMBEDDINGS_PROVIDER,
|
|
24
|
+
ENV_EMBEDDINGS_TEI_URL,
|
|
25
25
|
)
|
|
26
26
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
@@ -52,7 +52,7 @@ class Embeddings(ABC):
|
|
|
52
52
|
pass
|
|
53
53
|
|
|
54
54
|
@abstractmethod
|
|
55
|
-
def encode(self, texts:
|
|
55
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
56
56
|
"""
|
|
57
57
|
Generate 384-dimensional embeddings for a list of texts.
|
|
58
58
|
|
|
@@ -75,7 +75,7 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
75
75
|
embeddings matching the database schema.
|
|
76
76
|
"""
|
|
77
77
|
|
|
78
|
-
def __init__(self, model_name:
|
|
78
|
+
def __init__(self, model_name: str | None = None):
|
|
79
79
|
"""
|
|
80
80
|
Initialize local SentenceTransformers embeddings.
|
|
81
81
|
|
|
@@ -123,7 +123,7 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
123
123
|
|
|
124
124
|
logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
|
|
125
125
|
|
|
126
|
-
def encode(self, texts:
|
|
126
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
127
127
|
"""
|
|
128
128
|
Generate 384-dimensional embeddings for a list of texts.
|
|
129
129
|
|
|
@@ -172,8 +172,8 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
172
172
|
self.batch_size = batch_size
|
|
173
173
|
self.max_retries = max_retries
|
|
174
174
|
self.retry_delay = retry_delay
|
|
175
|
-
self._client:
|
|
176
|
-
self._model_id:
|
|
175
|
+
self._client: httpx.Client | None = None
|
|
176
|
+
self._model_id: str | None = None
|
|
177
177
|
|
|
178
178
|
@property
|
|
179
179
|
def provider_name(self) -> str:
|
|
@@ -182,6 +182,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
182
182
|
def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
|
|
183
183
|
"""Make an HTTP request with automatic retries on transient errors."""
|
|
184
184
|
import time
|
|
185
|
+
|
|
185
186
|
last_error = None
|
|
186
187
|
delay = self.retry_delay
|
|
187
188
|
|
|
@@ -196,14 +197,18 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
196
197
|
except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
|
|
197
198
|
last_error = e
|
|
198
199
|
if attempt < self.max_retries:
|
|
199
|
-
logger.warning(
|
|
200
|
+
logger.warning(
|
|
201
|
+
f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
202
|
+
)
|
|
200
203
|
time.sleep(delay)
|
|
201
204
|
delay *= 2 # Exponential backoff
|
|
202
205
|
except httpx.HTTPStatusError as e:
|
|
203
206
|
# Retry on 5xx server errors
|
|
204
207
|
if e.response.status_code >= 500 and attempt < self.max_retries:
|
|
205
208
|
last_error = e
|
|
206
|
-
logger.warning(
|
|
209
|
+
logger.warning(
|
|
210
|
+
f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
|
|
211
|
+
)
|
|
207
212
|
time.sleep(delay)
|
|
208
213
|
delay *= 2
|
|
209
214
|
else:
|
|
@@ -228,7 +233,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
228
233
|
except httpx.HTTPError as e:
|
|
229
234
|
raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
|
|
230
235
|
|
|
231
|
-
def encode(self, texts:
|
|
236
|
+
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
232
237
|
"""
|
|
233
238
|
Generate embeddings using the remote TEI server.
|
|
234
239
|
|
|
@@ -248,7 +253,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
248
253
|
|
|
249
254
|
# Process in batches
|
|
250
255
|
for i in range(0, len(texts), self.batch_size):
|
|
251
|
-
batch = texts[i:i + self.batch_size]
|
|
256
|
+
batch = texts[i : i + self.batch_size]
|
|
252
257
|
|
|
253
258
|
try:
|
|
254
259
|
response = self._request_with_retry(
|
|
@@ -278,15 +283,11 @@ def create_embeddings_from_env() -> Embeddings:
|
|
|
278
283
|
if provider == "tei":
|
|
279
284
|
url = os.environ.get(ENV_EMBEDDINGS_TEI_URL)
|
|
280
285
|
if not url:
|
|
281
|
-
raise ValueError(
|
|
282
|
-
f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'"
|
|
283
|
-
)
|
|
286
|
+
raise ValueError(f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'")
|
|
284
287
|
return RemoteTEIEmbeddings(base_url=url)
|
|
285
288
|
elif provider == "local":
|
|
286
289
|
model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
|
|
287
290
|
model_name = model or DEFAULT_EMBEDDINGS_LOCAL_MODEL
|
|
288
291
|
return LocalSTEmbeddings(model_name=model_name)
|
|
289
292
|
else:
|
|
290
|
-
raise ValueError(
|
|
291
|
-
f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'"
|
|
292
|
-
)
|
|
293
|
+
raise ValueError(f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei'")
|