aixtools 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic.
- aixtools/_version.py +2 -2
- aixtools/a2a/app.py +1 -1
- aixtools/a2a/google_sdk/__init__.py +0 -0
- aixtools/a2a/google_sdk/card.py +27 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
- aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
- aixtools/a2a/google_sdk/utils.py +59 -0
- aixtools/agents/prompt.py +97 -0
- aixtools/context.py +5 -0
- aixtools/google/client.py +25 -0
- aixtools/logging/logging_config.py +45 -0
- aixtools/mcp/client.py +274 -0
- aixtools/server/utils.py +3 -3
- aixtools/utils/config.py +13 -0
- aixtools/utils/files.py +17 -0
- aixtools/utils/utils.py +7 -0
- aixtools/vault/__init__.py +7 -0
- aixtools/vault/vault.py +73 -0
- aixtools-0.1.6.dist-info/METADATA +668 -0
- {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/RECORD +48 -13
- {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/top_level.txt +1 -0
- scripts/test.sh +23 -0
- tests/__init__.py +0 -0
- tests/unit/__init__.py +0 -0
- tests/unit/a2a/__init__.py +0 -0
- tests/unit/a2a/google_sdk/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +188 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +156 -0
- tests/unit/a2a/google_sdk/test_card.py +114 -0
- tests/unit/a2a/google_sdk/test_remote_agent_connection.py +413 -0
- tests/unit/a2a/google_sdk/test_utils.py +208 -0
- tests/unit/agents/__init__.py +0 -0
- tests/unit/agents/test_prompt.py +363 -0
- tests/unit/google/__init__.py +1 -0
- tests/unit/google/test_client.py +233 -0
- tests/unit/mcp/__init__.py +0 -0
- tests/unit/mcp/test_client.py +242 -0
- tests/unit/server/__init__.py +0 -0
- tests/unit/server/test_path.py +225 -0
- tests/unit/server/test_utils.py +362 -0
- tests/unit/utils/__init__.py +0 -0
- tests/unit/utils/test_files.py +146 -0
- tests/unit/vault/__init__.py +0 -0
- tests/unit/vault/test_vault.py +114 -0
- aixtools/a2a/__init__.py +0 -5
- aixtools-0.1.4.dist-info/METADATA +0 -355
- {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/WHEEL +0 -0
- {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/entry_points.txt +0 -0
aixtools/mcp/client.py
ADDED
@@ -0,0 +1,274 @@
+"""MCP server utilities with caching and robust error handling."""
+
+import asyncio
+from typing import Any
+
+import anyio
+from cachebox import TTLCache
+from mcp import types as mcp_types
+from mcp.shared.exceptions import McpError
+from pydantic_ai import RunContext, exceptions
+from pydantic_ai.mcp import MCPServerStreamableHTTP, ToolResult
+from pydantic_ai.toolsets.abstract import ToolsetTool
+
+from aixtools.context import SessionIdTuple
+from aixtools.logging.logging_config import get_logger
+
+MCP_TOOL_CACHE_TTL = 300  # 5 minutes
+DEFAULT_MCP_CONNECTION_TIMEOUT = 30
+CACHE_KEY = "TOOL_LIST"
+
+logger = get_logger(__name__)
+
+
+def get_mcp_headers(session_id_tuple: SessionIdTuple) -> dict[str, str] | None:
+    """
+    Generate headers for MCP server requests.
+
+    This function creates a dictionary of headers to be used in requests to
+    the MCP servers. If a `user_id` or `session_id` is provided, they are
+    included in the headers.
+
+    Args:
+        session_id_tuple (SessionIdTuple): user_id and session_id tuple
+    Returns:
+        dict[str, str] | None: A dictionary of headers for MCP server requests,
+                               or None if neither user_id nor session_id is
+                               provided. When None is returned, default headers
+                               from the client or transport will be used.
+    """
+    headers = None
+    user_id, session_id = session_id_tuple
+    if session_id or user_id:
+        headers = {}
+        if session_id:
+            headers["session-id"] = session_id
+        if user_id:
+            headers["user-id"] = user_id
+    return headers
+
+
+def get_configured_mcp_servers(
+    session_id_tuple: SessionIdTuple, mcp_urls: list[str], timeout: int = DEFAULT_MCP_CONNECTION_TIMEOUT
+):
+    """
+    Retrieve the configured MCP server instances with optional caching.
+
+    Context values `user_id` and `session_id` are included in the headers for each server request.
+
+    Each server is wrapped in a try-except block to isolate them from each other.
+    If one server fails, it won't affect the others.
+
+    Args:
+        session_id_tuple (SessionIdTuple): A tuple containing (user_id, session_id).
+        mcp_urls: (list[str], optional): A list of MCP server URLs to use.
+        timeout (int, optional): Timeout in seconds for MCP server connections. Defaults to 30 seconds.
+    Returns:
+        list[MCPServerStreamableHTTP]: A list of configured MCP server instances. If
+                                       neither user_id nor session_id is provided, the
+                                       server instances will use default headers defined
+                                       by the underlying HTTP implementation.
+    """
+    headers = get_mcp_headers(session_id_tuple)
+
+    return [CachedMCPServerStreamableHTTP(url=url, headers=headers, timeout=timeout) for url in mcp_urls]
+
+
+class CachedMCPServerStreamableHTTP(MCPServerStreamableHTTP):
+    """StreamableHTTP MCP server with cachebox-based TTL caching and robust error handling.
+
+    This class addresses the cancellation propagation issue by:
+    1. Using complete task isolation to prevent CancelledError propagation
+    2. Implementing comprehensive error handling for all MCP operations
+    3. Using fallback mechanisms when servers become unavailable
+    4. Overriding pydantic_ai methods to fix variable scoping bug
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._tools_cache = TTLCache(maxsize=1, ttl=MCP_TOOL_CACHE_TTL)
+        self._tools_list = None
+        self._isolation_lock = asyncio.Lock()  # Lock for critical operations
+
+    async def _run_direct_or_isolated(self, func, fallback, timeout: float | None):
+        """Run a coroutine in complete isolation to prevent cancellation propagation.
+
+        Args:
+            func: Function that returns a coroutine to run
+            fallback: Function that takes an exception and returns a fallback value
+            timeout: Timeout in seconds. If None, then direct run is performed
+
+        Returns:
+            The result of the coroutine on success, or fallback value on any exception
+        """
+        try:
+            if timeout is None:
+                return await func()
+
+            task = asyncio.create_task(func())
+
+            # Use asyncio.wait to prevent cancellation propagation
+            done, pending = await asyncio.wait([task], timeout=timeout)
+
+            if pending:
+                # Cancel pending tasks safely
+                for t in pending:
+                    t.cancel()
+                    try:
+                        await t
+                    except (asyncio.CancelledError, Exception):  # pylint: disable=broad-except
+                        pass
+                raise TimeoutError(f"Task timed out after {timeout} seconds")
+
+            # Get result from completed task
+            completed_task = done.pop()
+            if exc := completed_task.exception():
+                raise exc
+            return completed_task.result()
+
+        except exceptions.ModelRetry as exc:
+            logger.warning("MCP %s: %s ModelRetry: %s", self.url, func.__name__, exc)
+            raise
+        except TimeoutError as exc:
+            logger.warning("MCP %s: %s timed out: %s", self.url, func.__name__, exc)
+            return fallback(exc)
+        except asyncio.CancelledError as exc:
+            logger.warning("MCP %s: %s was cancelled", self.url, func.__name__)
+            return fallback(exc)
+        except anyio.ClosedResourceError as exc:
+            logger.warning("MCP %s: %s closed resource.", self.url, func.__name__)
+            return fallback(exc)
+        except Exception as exc:  # pylint: disable=broad-except
+            if str(exc) == "Attempted to exit cancel scope in a different task than it was entered in":
+                logger.warning("MCP %s: %s enter/exit cancel scope task mismatch.", self.url, func.__name__)
+            else:
+                logger.warning("MCP %s: %s exception %s: %s", self.url, func.__name__, type(exc), exc)
+            return fallback(exc)
+
+    async def __aenter__(self):
+        """Enter the context of the cached MCP server with complete cancellation isolation."""
+        async with self._isolation_lock:
+
+            async def direct_init():
+                return await super(CachedMCPServerStreamableHTTP, self).__aenter__()  # pylint: disable=super-with-arguments
+
+            def fallback(_exc):
+                self._client = None
+                return self
+
+            return await self._run_direct_or_isolated(direct_init, fallback, timeout=None)
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Exit the context of the cached MCP server with complete cancellation isolation."""
+        async with self._isolation_lock:
+            # If we're being cancelled, just clean up
+            if exc_type is asyncio.CancelledError:
+                logger.warning("MCP %s: __aexit__ called with cancellation - cleaning up", self.url)
+                self._client = None
+                return True
+
+            # If client is already None, skip cleanup
+            if not self._client:
+                logger.warning("MCP %s: is uninitialized -> skipping cleanup", self.url)
+                return True
+
+            async def direct_cleanup():
+                return await super(CachedMCPServerStreamableHTTP, self).__aexit__(exc_type, exc_val, exc_tb)  # pylint: disable=super-with-arguments
+
+            def fallback(_exc):
+                self._client = None
+                return True  # Suppress exceptions to prevent propagation
+
+            return await self._run_direct_or_isolated(direct_cleanup, fallback, timeout=None)
+
+    async def list_tools(self) -> list[mcp_types.Tool]:
+        """Override to fix variable scoping bug and add caching with cancellation isolation."""
+        # If client is not initialized, return empty list
+        if not self._client:
+            logger.warning("MCP %s: is uninitialized -> no tools", self.url)
+            return []
+
+        # First, check if we have a valid cached result
+        if CACHE_KEY in self._tools_cache:
+            logger.info("Using cached tools for %s", self.url)
+            return self._tools_cache[CACHE_KEY]
+
+        # Create isolated task to prevent cancellation propagation
+        async def isolated_list_tools():
+            """Isolated list_tools with variable scoping bug fix."""
+            result = None  # Initialize to prevent UnboundLocalError
+            async with self:  # Ensure server is running
+                result = await self._client.list_tools()
+            if result:
+                self._tools_list = result.tools or []
+                self._tools_cache[CACHE_KEY] = self._tools_list
+                logger.info("MCP %s: list_tools returned %d tools", self.url, len(self._tools_list))
+            else:
+                logger.warning("MCP %s: list_tools returned no result", self.url)
+            return self._tools_list or []
+
+        def fallback(_exc):
+            return self._tools_list or []
+
+        return await self._run_direct_or_isolated(isolated_list_tools, fallback, timeout=5.0)
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> ToolResult:
+        """Call tool with complete isolation from cancellation using patched pydantic_ai."""
+        logger.info("MCP %s: call_tool '%s' started.", self.url, name)
+
+        # Early returns for uninitialized servers
+        if not self._client:
+            logger.warning("MCP %s: is uninitialized -> cannot call tool", self.url)
+            return f"There was an error with calling tool '{name}': MCP connection is uninitialized."
+
+        # Create isolated task to prevent cancellation propagation
+        async def isolated_call_tool():
+            """Isolated call_tool using patched pydantic_ai methods."""
+            return await super(CachedMCPServerStreamableHTTP, self).call_tool(name, tool_args, ctx, tool)  # pylint: disable=super-with-arguments
+
+        def fallback(exc):
+            return f"Exception {type(exc)} when calling tool '{name}': {exc}. Consider alternative approaches."
+
+        result = await self._run_direct_or_isolated(isolated_call_tool, fallback, timeout=3600.0)
+        logger.info("MCP %s: call_tool '%s' completed.", self.url, name)
+        return result
+
+    async def direct_call_tool(
+        self, name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None
+    ) -> ToolResult:
+        """Override to fix variable scoping bug in direct_call_tool."""
+        result = None  # Initialize to prevent UnboundLocalError
+        async with self:  # Ensure server is running
+            try:
+                result = await self._client.send_request(
+                    mcp_types.ClientRequest(
+                        mcp_types.CallToolRequest(
+                            method="tools/call",
+                            params=mcp_types.CallToolRequestParams(
+                                name=name,
+                                arguments=args,
+                                _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
+                            ),
+                        )
+                    ),
+                    mcp_types.CallToolResult,
+                )
+            except McpError as e:
+                raise exceptions.ModelRetry(e.error.message)
+
+        if not result:
+            raise exceptions.ModelRetry("No result from MCP server")
+
+        content = [await self._map_tool_result_part(part) for part in result.content]
+
+        if result.isError:
+            text = "\n".join(str(part) for part in content)
+            raise exceptions.ModelRetry(text)
+
+        return content[0] if len(content) == 1 else content
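Usage note: a minimal sketch of how the new module might be wired up; the endpoint URL and the user/session ids are placeholders, not values taken from this release.

import asyncio

from aixtools.mcp.client import get_configured_mcp_servers


async def main() -> None:
    # The tuple is forwarded as "user-id" / "session-id" headers; the URL is illustrative.
    servers = get_configured_mcp_servers(("user-123", "session-456"), ["http://localhost:8080/mcp"])
    async with servers[0] as server:  # CachedMCPServerStreamableHTTP
        tools = await server.list_tools()  # cached for MCP_TOOL_CACHE_TTL seconds
        print([tool.name for tool in tools])


asyncio.run(main())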
aixtools/server/utils.py
CHANGED
@@ -8,7 +8,7 @@ from functools import wraps
 from fastmcp import Context
 from fastmcp.server import dependencies
 
-from ..context import session_id_var, user_id_var
+from ..context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, session_id_var, user_id_var
 
 
 def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
@@ -18,9 +18,9 @@ def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
     Returns: Tuple of (user_id, session_id).
     """
     user_id = get_user_id_from_request(ctx)
-    user_id = user_id or user_id_var.get()
+    user_id = user_id or user_id_var.get(DEFAULT_USER_ID)
     session_id = get_session_id_from_request(ctx)
-    session_id = session_id or session_id_var.get()
+    session_id = session_id or session_id_var.get(DEFAULT_SESSION_ID)
     return user_id, session_id
 
 
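The change passes explicit defaults (DEFAULT_USER_ID / DEFAULT_SESSION_ID) to ContextVar.get(). For reference, ContextVar.get(default) returns the supplied default when the variable has not been set in the current context; a standalone sketch (the variable name here is illustrative, not the one from aixtools.context):

from contextvars import ContextVar

user_id_var: ContextVar[str] = ContextVar("user_id")  # no default supplied at creation

print(user_id_var.get("default-user"))  # -> "default-user" (nothing set in this context)

user_id_var.set("alice")
print(user_id_var.get("default-user"))  # -> "alice"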
aixtools/utils/config.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from dotenv import dotenv_values, load_dotenv
 
 from aixtools.utils.config_util import find_env_file, get_project_root, get_variable_env
+from aixtools.utils.utils import str2bool
 
 # Debug mode
 LOG_LEVEL = logging.DEBUG
@@ -116,3 +117,15 @@ BEDROCK_MODEL_NAME = get_variable_env("BEDROCK_MODEL_NAME", allow_empty=True)
 # LogFire
 LOGFIRE_TOKEN = get_variable_env("LOGFIRE_TOKEN", True, "")
 LOGFIRE_TRACES_ENDPOINT = get_variable_env("LOGFIRE_TRACES_ENDPOINT", True, "")
+
+# Google Vertex AI
+GOOGLE_GENAI_USE_VERTEXAI = str2bool(get_variable_env("GOOGLE_GENAI_USE_VERTEXAI", True, True))
+GOOGLE_CLOUD_PROJECT = get_variable_env("GOOGLE_CLOUD_PROJECT", True)
+GOOGLE_CLOUD_LOCATION = get_variable_env("GOOGLE_CLOUD_LOCATION", True)
+
+# vault parameters.
+VAULT_ADDRESS = get_variable_env("VAULT_ADDRESS", default="http://localhost:8200")
+VAULT_TOKEN = get_variable_env("VAULT_TOKEN", default="vault-token")
+VAULT_ENV = get_variable_env("ENV", default="dev")
+VAULT_MOUNT_POINT = get_variable_env("VAULT_MOUNT_POINT", default="secret")
+VAULT_PATH_PREFIX = get_variable_env("VAULT_PATH_PREFIX", default="path")
aixtools/utils/files.py
ADDED
@@ -0,0 +1,17 @@
+"""File utilities"""
+
+
+def is_text_content(data: bytes, mime_type: str) -> bool:
+    """Check if content is text based on mime type and content analysis."""
+    # Check mime type first
+    if mime_type and (
+        mime_type.startswith("text/") or mime_type in ["application/json", "application/xml", "application/javascript"]
+    ):
+        return True
+
+    # Try to decode as UTF-8 to check if it's text
+    try:
+        data.decode("utf-8")
+        return True
+    except UnicodeDecodeError:
+        return False
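A few illustrative calls against the new helper, following its implementation above:

from aixtools.utils.files import is_text_content

print(is_text_content(b'{"ok": true}', "application/json"))  # True: allowed mime type
print(is_text_content("café".encode("utf-8"), ""))  # True: payload decodes cleanly as UTF-8
print(is_text_content(b"\x89PNG\r\n\x1a\n", "image/png"))  # False: binary payload, non-text mime type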
aixtools/utils/utils.py
CHANGED
@@ -154,6 +154,13 @@ def timestamp_uuid_tuple() -> tuple[str, str, str]:
     return (now.strftime("%Y-%m-%d"), now.strftime("%H:%M:%S"), str(uuid.uuid4()))
 
 
+def str2bool(v: str | None) -> bool:
+    """Convert a string to a boolean value."""
+    if not v:
+        return False
+    return str(v).lower() in ("yes", "true", "on", "1")
+
+
 async def async_iter(items):
     """Asynchronously iterate over items."""
     for item in items:
aixtools/vault/vault.py
ADDED
@@ -0,0 +1,73 @@
+# ruff: noqa: PLR0913
+"""
+Provides a Vault client for storing and retrieving user service api keys.
+"""
+
+import logging
+from typing import Optional
+
+import hvac
+from hvac.exceptions import InvalidPath
+
+from aixtools.utils import config
+
+logger = logging.getLogger(__name__)
+
+
+class VaultAuthError(Exception):
+    """Exception raised for vault authentication errors."""
+
+
+class VaultClient:
+    """Vault client for storing and retrieving user service api keys."""
+
+    def __init__(self):
+        self.client = hvac.Client(url=config.VAULT_ADDRESS, token=config.VAULT_TOKEN)
+
+        if not self.client.is_authenticated():
+            raise VaultAuthError("Vault client authentication failed. Check vault_token.")
+
+    def store_user_service_api_key(self, *, user_id: str, service_name: str, user_api_key: str):
+        """
+        Store user's service api key in the Vault at the specified vault mount
+        point, where the path is <path_prefix>/<env>/<user_id>/<service_name>.
+        """
+        secret_path = None
+        try:
+            secret_path = f"{config.VAULT_PATH_PREFIX}/{config.VAULT_ENV}/{user_id}/{service_name}"
+            print("secret_path", secret_path)
+            secret_dict = {"user-api-key": user_api_key}
+            self.client.secrets.kv.v2.create_or_update_secret(
+                secret_path, secret=secret_dict, mount_point=config.VAULT_MOUNT_POINT
+            )
+
+            logger.info("Secret written to path %s", secret_path)
+        except Exception as e:
+            logger.error("Failed to write secret to path %s: %s", secret_path, str(e))
+            raise VaultAuthError(e) from e
+
+    def read_user_service_api_key(self, *, user_id: str, service_name) -> Optional[str]:
+        """
+        Read user's service api key in from vault at the specified mount point,
+        where the path is <path_prefix>/<env>/<user_id>/<service_name>.
+        """
+        secret_path = None
+
+        try:
+            secret_path = f"{config.VAULT_PATH_PREFIX}/{config.VAULT_ENV}/{user_id}/{service_name}"
+            logger.info("Reading secret from path %s", secret_path)
+            response = self.client.secrets.kv.v2.read_secret_version(
+                secret_path, mount_point=config.VAULT_MOUNT_POINT, raise_on_deleted_version=True
+            )
+            secret_data = response["data"]["data"]
+            user_api_key = secret_data["user-api-key"]
+            logger.info("Secret read from path %s ", secret_path)
+            return user_api_key
+        except InvalidPath:
+            # Secret path does not exist
+            logger.warning("Secret path does not exist %s ", secret_path)
+            return None
+
+        except Exception as e:
+            logger.error("Failed to read secret from path %s: %s", secret_path, str(e))
+            raise VaultAuthError(e) from e
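A hypothetical round trip with the new client; the ids and the API key below are placeholders, and a reachable Vault instance configured via VAULT_ADDRESS / VAULT_TOKEN is assumed.

from aixtools.vault.vault import VaultClient

client = VaultClient()  # raises VaultAuthError if the token is not accepted
client.store_user_service_api_key(
    user_id="user-123",
    service_name="github",
    user_api_key="example-api-key",
)
key = client.read_user_service_api_key(user_id="user-123", service_name="github")
print(key)  # "example-api-key", or None if the secret path does not exist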