aixtools 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic. Click here for more details.

Files changed (50)
  1. aixtools/_version.py +2 -2
  2. aixtools/a2a/app.py +1 -1
  3. aixtools/a2a/google_sdk/__init__.py +0 -0
  4. aixtools/a2a/google_sdk/card.py +27 -0
  5. aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
  6. aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
  7. aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
  8. aixtools/a2a/google_sdk/utils.py +59 -0
  9. aixtools/agents/prompt.py +97 -0
  10. aixtools/context.py +5 -0
  11. aixtools/google/client.py +25 -0
  12. aixtools/logging/logging_config.py +45 -0
  13. aixtools/mcp/client.py +274 -0
  14. aixtools/server/utils.py +3 -3
  15. aixtools/utils/config.py +13 -0
  16. aixtools/utils/files.py +17 -0
  17. aixtools/utils/utils.py +7 -0
  18. aixtools/vault/__init__.py +7 -0
  19. aixtools/vault/vault.py +73 -0
  20. aixtools-0.1.6.dist-info/METADATA +668 -0
  21. {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/RECORD +48 -13
  22. {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/top_level.txt +1 -0
  23. scripts/test.sh +23 -0
  24. tests/__init__.py +0 -0
  25. tests/unit/__init__.py +0 -0
  26. tests/unit/a2a/__init__.py +0 -0
  27. tests/unit/a2a/google_sdk/__init__.py +0 -0
  28. tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
  29. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +188 -0
  30. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +156 -0
  31. tests/unit/a2a/google_sdk/test_card.py +114 -0
  32. tests/unit/a2a/google_sdk/test_remote_agent_connection.py +413 -0
  33. tests/unit/a2a/google_sdk/test_utils.py +208 -0
  34. tests/unit/agents/__init__.py +0 -0
  35. tests/unit/agents/test_prompt.py +363 -0
  36. tests/unit/google/__init__.py +1 -0
  37. tests/unit/google/test_client.py +233 -0
  38. tests/unit/mcp/__init__.py +0 -0
  39. tests/unit/mcp/test_client.py +242 -0
  40. tests/unit/server/__init__.py +0 -0
  41. tests/unit/server/test_path.py +225 -0
  42. tests/unit/server/test_utils.py +362 -0
  43. tests/unit/utils/__init__.py +0 -0
  44. tests/unit/utils/test_files.py +146 -0
  45. tests/unit/vault/__init__.py +0 -0
  46. tests/unit/vault/test_vault.py +114 -0
  47. aixtools/a2a/__init__.py +0 -5
  48. aixtools-0.1.4.dist-info/METADATA +0 -355
  49. {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/WHEEL +0 -0
  50. {aixtools-0.1.4.dist-info → aixtools-0.1.6.dist-info}/entry_points.txt +0 -0
aixtools/mcp/client.py ADDED
@@ -0,0 +1,274 @@
1
+ """MCP server utilities with caching and robust error handling."""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import anyio
7
+ from cachebox import TTLCache
8
+ from mcp import types as mcp_types
9
+ from mcp.shared.exceptions import McpError
10
+ from pydantic_ai import RunContext, exceptions
11
+ from pydantic_ai.mcp import MCPServerStreamableHTTP, ToolResult
12
+ from pydantic_ai.toolsets.abstract import ToolsetTool
13
+
14
+ from aixtools.context import SessionIdTuple
15
+ from aixtools.logging.logging_config import get_logger
16
+
# Lifetime, in seconds, of each server's cached tool list (5 minutes).
MCP_TOOL_CACHE_TTL = 5 * 60
# Default timeout, in seconds, used for MCP server connections.
DEFAULT_MCP_CONNECTION_TIMEOUT = 30
# The single key under which a server's tool list is stored in its TTL cache.
CACHE_KEY = "TOOL_LIST"

logger = get_logger(__name__)
22
+
23
+
24
+ def get_mcp_headers(session_id_tuple: SessionIdTuple) -> dict[str, str] | None:
25
+ """
26
+ Generate headers for MCP server requests.
27
+
28
+ This function creates a dictionary of headers to be used in requests to
29
+ the MCP servers. If a `user_id` or `session_id` is provided, they are
30
+ included in the headers.
31
+
32
+ Args:
33
+ session_id_tuple (SessionIdTuple): user_id and session_id tuple
34
+ Returns:
35
+ dict[str, str] | None: A dictionary of headers for MCP server requests,
36
+ or None if neither user_id nor session_id is
37
+ provided. When None is returned, default headers
38
+ from the client or transport will be used.
39
+ """
40
+ headers = None
41
+ user_id, session_id = session_id_tuple
42
+ if session_id or user_id:
43
+ headers = {}
44
+ if session_id:
45
+ headers["session-id"] = session_id
46
+ if user_id:
47
+ headers["user-id"] = user_id
48
+ return headers
49
+
50
+
def get_configured_mcp_servers(
    session_id_tuple: SessionIdTuple, mcp_urls: list[str], timeout: int = DEFAULT_MCP_CONNECTION_TIMEOUT
) -> list["CachedMCPServerStreamableHTTP"]:
    """
    Create one cached, error-isolated MCP server instance per URL.

    Context values `user_id` and `session_id` are included in the headers for each
    server request (built once by `get_mcp_headers` and shared by all instances).

    Each URL gets its own `CachedMCPServerStreamableHTTP`. That class catches
    connection and call failures internally and returns fallback values instead of
    raising, so one failing server does not affect the others.

    Args:
        session_id_tuple (SessionIdTuple): A tuple containing (user_id, session_id).
        mcp_urls (list[str]): MCP server URLs; one server instance is created per URL.
            An empty list yields an empty result.
        timeout (int, optional): Timeout in seconds for MCP server connections.
            Defaults to DEFAULT_MCP_CONNECTION_TIMEOUT.

    Returns:
        list[CachedMCPServerStreamableHTTP]: The configured server instances, in the
            order of `mcp_urls`. If neither user_id nor session_id is provided, the
            instances use the default headers of the underlying HTTP implementation.
    """
    headers = get_mcp_headers(session_id_tuple)

    return [CachedMCPServerStreamableHTTP(url=url, headers=headers, timeout=timeout) for url in mcp_urls]
75
+
76
+
class CachedMCPServerStreamableHTTP(MCPServerStreamableHTTP):
    """StreamableHTTP MCP server with cachebox-based TTL caching and robust error handling.

    This class addresses the cancellation propagation issue by:
    1. Using complete task isolation to prevent CancelledError propagation
    2. Implementing comprehensive error handling for all MCP operations
    3. Using fallback mechanisms when servers become unavailable
    4. Overriding pydantic_ai methods to fix variable scoping bug

    In ``__aenter__``, ``__aexit__``, ``list_tools`` and ``call_tool``, failures are
    logged and turned into fallback values instead of being raised:
    - uninitialized server state;
    - an empty or last-known tool list;
    - an error string returned as the tool result.

    ``exceptions.ModelRetry`` is the only exception re-raised from those paths.
    """

    def __init__(self, **kwargs):
        """Create the server; all keyword arguments are forwarded to ``MCPServerStreamableHTTP``."""
        super().__init__(**kwargs)
        # At most one entry (CACHE_KEY -> tool list), expiring after MCP_TOOL_CACHE_TTL seconds.
        self._tools_cache = TTLCache(maxsize=1, ttl=MCP_TOOL_CACHE_TTL)
        # Last successfully fetched tool list; returned when a refresh fails or yields nothing.
        self._tools_list = None
        # Serializes __aenter__/__aexit__ so concurrent enters/exits do not interleave.
        self._isolation_lock = asyncio.Lock()
        # Strong references to tasks spawned by _run_direct_or_isolated. The event loop keeps
        # only weak references to tasks, so a task left running after its caller was cancelled
        # could otherwise be garbage-collected mid-execution.
        self._isolated_tasks: set[asyncio.Task] = set()

    async def _run_direct_or_isolated(self, func, fallback, timeout: float | None):
        """Run ``func()`` and turn any failure other than ``ModelRetry`` into ``fallback(exc)``.

        Args:
            func: Zero-argument callable returning the coroutine to run.
            fallback: Callable taking the caught exception and returning the value to use instead.
            timeout: Timeout in seconds. If None, the coroutine is awaited directly in the
                current task. Otherwise it runs in its own task, awaited through ``asyncio.wait``;
                that call does not cancel the task when the caller is cancelled.

        Returns:
            The coroutine's result on success, or ``fallback(exc)`` on any other exception
            (including timeout and cancellation of the caller).

        Raises:
            exceptions.ModelRetry: Re-raised unchanged so pydantic_ai can ask the model to retry.
        """
        try:
            if timeout is None:
                return await func()

            task = asyncio.create_task(func())
            self._isolated_tasks.add(task)
            task.add_done_callback(self._isolated_tasks.discard)

            # Use asyncio.wait to prevent cancellation propagation into the isolated task.
            done, pending = await asyncio.wait([task], timeout=timeout)

            if pending:
                # Timed out: cancel the isolated task and let it finish unwinding.
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):  # pylint: disable=broad-except
                    pass
                raise TimeoutError(f"Task timed out after {timeout} seconds")

            # result() re-raises the task's own exception (or CancelledError) if it failed.
            return done.pop().result()

        except exceptions.ModelRetry as exc:
            logger.warning("MCP %s: %s ModelRetry: %s", self.url, func.__name__, exc)
            raise
        except TimeoutError as exc:
            logger.warning("MCP %s: %s timed out: %s", self.url, func.__name__, exc)
            return fallback(exc)
        except asyncio.CancelledError as exc:
            logger.warning("MCP %s: %s was cancelled", self.url, func.__name__)
            return fallback(exc)
        except anyio.ClosedResourceError as exc:
            logger.warning("MCP %s: %s closed resource.", self.url, func.__name__)
            return fallback(exc)
        except Exception as exc:  # pylint: disable=broad-except
            if str(exc) == "Attempted to exit cancel scope in a different task than it was entered in":
                logger.warning("MCP %s: %s enter/exit cancel scope task mismatch.", self.url, func.__name__)
            else:
                logger.warning("MCP %s: %s exception %s: %s", self.url, func.__name__, type(exc), exc)
            return fallback(exc)

    async def __aenter__(self):
        """Enter the server context; on failure mark the server uninitialized instead of raising.

        Returns:
            Whatever the base ``__aenter__`` returns on success. On failure it sets
            ``self._client`` to None (so later calls take their "uninitialized" paths)
            and returns ``self``.
        """
        async with self._isolation_lock:

            async def direct_init():
                return await super(CachedMCPServerStreamableHTTP, self).__aenter__()  # pylint: disable=super-with-arguments

            def fallback(_exc):
                self._client = None
                return self

            return await self._run_direct_or_isolated(direct_init, fallback, timeout=None)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the server context without letting cancellation or cleanup errors escape.

        Returns True, suppressing any in-flight exception, in three cases:
        - the exit is due to ``asyncio.CancelledError``;
        - the server is uninitialized;
        - the base cleanup fails.

        Otherwise returns the base class's ``__aexit__`` result.
        """
        async with self._isolation_lock:
            # If we're being cancelled, just clean up
            if exc_type is asyncio.CancelledError:
                # NOTE(review): the base __aexit__ is skipped here, so whatever it would release
                # (transport/session) is presumably left open — confirm this is intended.
                logger.warning("MCP %s: __aexit__ called with cancellation - cleaning up", self.url)
                self._client = None
                return True

            # If client is already None, skip cleanup
            if not self._client:
                logger.warning("MCP %s: is uninitialized -> skipping cleanup", self.url)
                return True

            async def direct_cleanup():
                return await super(CachedMCPServerStreamableHTTP, self).__aexit__(exc_type, exc_val, exc_tb)  # pylint: disable=super-with-arguments

            def fallback(_exc):
                self._client = None
                return True  # Suppress exceptions to prevent propagation

            return await self._run_direct_or_isolated(direct_cleanup, fallback, timeout=None)

    async def list_tools(self) -> list[mcp_types.Tool]:
        """Return the server's tools, served from a TTL cache and isolated from cancellation.

        Returns an empty list when the server is uninitialized. If fetching fails,
        or takes longer than 5 seconds, it returns the last successfully fetched
        list (or an empty list).
        """
        # If client is not initialized, return empty list
        if not self._client:
            logger.warning("MCP %s: is uninitialized -> no tools", self.url)
            return []

        # First, check if we have a valid cached result
        if CACHE_KEY in self._tools_cache:
            logger.info("Using cached tools for %s", self.url)
            return self._tools_cache[CACHE_KEY]

        async def isolated_list_tools():
            """Fetch tools; `result` is pre-bound because __aexit__ may suppress an exception."""
            result = None  # Initialize to prevent UnboundLocalError
            async with self:  # Ensure server is running
                result = await self._client.list_tools()
            if result:
                self._tools_list = result.tools or []
                self._tools_cache[CACHE_KEY] = self._tools_list
                logger.info("MCP %s: list_tools returned %d tools", self.url, len(self._tools_list))
            else:
                logger.warning("MCP %s: list_tools returned no result", self.url)
            return self._tools_list or []

        def fallback(_exc):
            return self._tools_list or []

        return await self._run_direct_or_isolated(isolated_list_tools, fallback, timeout=5.0)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> ToolResult:
        """Call a tool in an isolated task (up to 3600 seconds), reporting failures as text.

        Returns the base ``call_tool`` result on success. It returns an error-message
        string instead of raising when the server is uninitialized or the call fails
        (including on timeout). ``exceptions.ModelRetry`` is re-raised.
        """
        logger.info("MCP %s: call_tool '%s' started.", self.url, name)

        # Early returns for uninitialized servers
        if not self._client:
            logger.warning("MCP %s: is uninitialized -> cannot call tool", self.url)
            return f"There was an error with calling tool '{name}': MCP connection is uninitialized."

        async def isolated_call_tool():
            """Delegate to the base call_tool (which dispatches to this class's direct_call_tool)."""
            return await super(CachedMCPServerStreamableHTTP, self).call_tool(name, tool_args, ctx, tool)  # pylint: disable=super-with-arguments

        def fallback(exc):
            return f"Exception {type(exc)} when calling tool '{name}': {exc}. Consider alternative approaches."

        result = await self._run_direct_or_isolated(isolated_call_tool, fallback, timeout=3600.0)
        logger.info("MCP %s: call_tool '%s' completed.", self.url, name)
        return result

    async def direct_call_tool(
        self, name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None
    ) -> ToolResult:
        """Send a ``tools/call`` request and map the result parts.

        Pre-binds ``result`` so that an exception suppressed by ``__aexit__`` yields a
        ModelRetry rather than ``UnboundLocalError``.

        Returns:
            The single mapped content part, or a list of parts when there are several.

        Raises:
            exceptions.ModelRetry: On an ``McpError``, when no result was obtained, or
                when the server flags the result as an error.
        """
        result = None  # Initialize to prevent UnboundLocalError
        async with self:  # Ensure server is running
            try:
                result = await self._client.send_request(
                    mcp_types.ClientRequest(
                        mcp_types.CallToolRequest(
                            method="tools/call",
                            params=mcp_types.CallToolRequestParams(
                                name=name,
                                arguments=args,
                                _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
                            ),
                        )
                    ),
                    mcp_types.CallToolResult,
                )
            except McpError as e:
                # Chain the original error so the MCP failure stays visible in tracebacks.
                raise exceptions.ModelRetry(e.error.message) from e

        if not result:
            raise exceptions.ModelRetry("No result from MCP server")

        content = [await self._map_tool_result_part(part) for part in result.content]

        if result.isError:
            text = "\n".join(str(part) for part in content)
            raise exceptions.ModelRetry(text)

        return content[0] if len(content) == 1 else content
aixtools/server/utils.py CHANGED
@@ -8,7 +8,7 @@ from functools import wraps
8
8
  from fastmcp import Context
9
9
  from fastmcp.server import dependencies
10
10
 
11
- from ..context import session_id_var, user_id_var
11
+ from ..context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, session_id_var, user_id_var
12
12
 
13
13
 
14
14
  def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
@@ -18,9 +18,9 @@ def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
18
18
  Returns: Tuple of (user_id, session_id).
19
19
  """
20
20
  user_id = get_user_id_from_request(ctx)
21
- user_id = user_id or user_id_var.get("default_user")
21
+ user_id = user_id or user_id_var.get(DEFAULT_USER_ID)
22
22
  session_id = get_session_id_from_request(ctx)
23
- session_id = session_id or session_id_var.get("default_session")
23
+ session_id = session_id or session_id_var.get(DEFAULT_SESSION_ID)
24
24
  return user_id, session_id
25
25
 
26
26
 
aixtools/utils/config.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
9
9
  from dotenv import dotenv_values, load_dotenv
10
10
 
11
11
  from aixtools.utils.config_util import find_env_file, get_project_root, get_variable_env
12
+ from aixtools.utils.utils import str2bool
12
13
 
13
14
  # Debug mode
14
15
  LOG_LEVEL = logging.DEBUG
@@ -116,3 +117,15 @@ BEDROCK_MODEL_NAME = get_variable_env("BEDROCK_MODEL_NAME", allow_empty=True)
# LogFire
LOGFIRE_TOKEN = get_variable_env("LOGFIRE_TOKEN", True, "")
LOGFIRE_TRACES_ENDPOINT = get_variable_env("LOGFIRE_TRACES_ENDPOINT", True, "")

# Google Vertex AI
# The default is the string "true" because str2bool expects a string value.
GOOGLE_GENAI_USE_VERTEXAI = str2bool(get_variable_env("GOOGLE_GENAI_USE_VERTEXAI", True, "true"))
GOOGLE_CLOUD_PROJECT = get_variable_env("GOOGLE_CLOUD_PROJECT", True)
GOOGLE_CLOUD_LOCATION = get_variable_env("GOOGLE_CLOUD_LOCATION", True)

# Vault parameters.
VAULT_ADDRESS = get_variable_env("VAULT_ADDRESS", default="http://localhost:8200")
# NOTE(review): hard-coded fallback token, presumably meant only for a local dev Vault;
# consider requiring VAULT_TOKEN to be set explicitly outside development.
VAULT_TOKEN = get_variable_env("VAULT_TOKEN", default="vault-token")
# Read from the generic ENV variable (not a VAULT_-prefixed one).
VAULT_ENV = get_variable_env("ENV", default="dev")
VAULT_MOUNT_POINT = get_variable_env("VAULT_MOUNT_POINT", default="secret")
VAULT_PATH_PREFIX = get_variable_env("VAULT_PATH_PREFIX", default="path")
aixtools/utils/files.py ADDED
@@ -0,0 +1,17 @@
1
+ """File utilities"""
2
+
3
+
def is_text_content(data: bytes, mime_type: str) -> bool:
    """Return True if ``data`` should be treated as text.

    The MIME type is checked first: ``text/*`` and the textual ``application/json``,
    ``application/xml`` and ``application/javascript`` types are accepted without
    inspecting the content. MIME types are compared case-insensitively, and any
    parameters (e.g. ``; charset=utf-8``) are ignored.

    Otherwise, including when ``mime_type`` is empty or None, the content counts
    as text if it decodes as UTF-8.
    """
    # Normalise e.g. "Application/JSON; charset=utf-8" -> "application/json".
    base_type = mime_type.split(";", 1)[0].strip().lower() if mime_type else ""
    if base_type.startswith("text/") or base_type in (
        "application/json",
        "application/xml",
        "application/javascript",
    ):
        return True

    # Try to decode as UTF-8 to check if it's text
    try:
        data.decode("utf-8")
    except UnicodeDecodeError:
        return False
    return True
aixtools/utils/utils.py CHANGED
@@ -154,6 +154,13 @@ def timestamp_uuid_tuple() -> tuple[str, str, str]:
154
154
  return (now.strftime("%Y-%m-%d"), now.strftime("%H:%M:%S"), str(uuid.uuid4()))
155
155
 
156
156
 
157
+ def str2bool(v: str | None) -> bool:
158
+ """Convert a string to a boolean value."""
159
+ if not v:
160
+ return False
161
+ return str(v).lower() in ("yes", "true", "on", "1")
162
+
163
+
157
164
  async def async_iter(items):
158
165
  """Asynchronously iterate over items."""
159
166
  for item in items:
aixtools/vault/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ Provides a Vault client for storing and retrieving user service api keys.
3
+ """
4
+
5
+ from .vault import VaultClient
6
+
7
+ __all__ = ["VaultClient"]
aixtools/vault/vault.py ADDED
@@ -0,0 +1,73 @@
1
+ # ruff: noqa: PLR0913
2
+ """
3
+ Provides a Vault client for storing and retrieving user service api keys.
4
+ """
5
+
6
+ import logging
7
+ from typing import Optional
8
+
9
+ import hvac
10
+ from hvac.exceptions import InvalidPath
11
+
12
+ from aixtools.utils import config
13
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
15
+
16
+
class VaultAuthError(Exception):
    """Raised when the Vault client fails to authenticate or to read/write a secret.

    Read and write failures wrap the underlying exception, which is chained as ``__cause__``.
    """
+ """Exception raised for vault authentication errors."""
19
+
20
+
class VaultClient:
    """Vault client for storing and retrieving user service API keys.

    Secrets are written to and read from the KV v2 engine:
    - mount point: ``config.VAULT_MOUNT_POINT``;
    - path: ``<VAULT_PATH_PREFIX>/<VAULT_ENV>/<user_id>/<service_name>``;
    - stored data: ``{"user-api-key": <key>}``.
    """

    def __init__(self):
        """Connect to ``config.VAULT_ADDRESS`` using ``config.VAULT_TOKEN``.

        Raises:
            VaultAuthError: If the client reports that it is not authenticated.
        """
        self.client = hvac.Client(url=config.VAULT_ADDRESS, token=config.VAULT_TOKEN)

        if not self.client.is_authenticated():
            raise VaultAuthError("Vault client authentication failed. Check vault_token.")

    @staticmethod
    def _secret_path(user_id: str, service_name: str) -> str:
        """Return the KV path ``<path_prefix>/<env>/<user_id>/<service_name>`` for a user's key."""
        # NOTE(review): user_id and service_name are interpolated verbatim. If either can come
        # from untrusted input, consider rejecting "/" and ".." so a caller cannot address
        # another user's subtree.
        return f"{config.VAULT_PATH_PREFIX}/{config.VAULT_ENV}/{user_id}/{service_name}"

    def store_user_service_api_key(self, *, user_id: str, service_name: str, user_api_key: str):
        """
        Store user's service api key in the Vault at the specified vault mount
        point, where the path is <path_prefix>/<env>/<user_id>/<service_name>.

        Raises:
            VaultAuthError: Wrapping any exception raised while writing the secret.
        """
        secret_path = self._secret_path(user_id, service_name)
        try:
            logger.debug("Writing secret to path %s", secret_path)
            secret_dict = {"user-api-key": user_api_key}
            self.client.secrets.kv.v2.create_or_update_secret(
                secret_path, secret=secret_dict, mount_point=config.VAULT_MOUNT_POINT
            )

            logger.info("Secret written to path %s", secret_path)
        except Exception as e:
            logger.error("Failed to write secret to path %s: %s", secret_path, str(e))
            raise VaultAuthError(e) from e

    def read_user_service_api_key(self, *, user_id: str, service_name: str) -> Optional[str]:
        """
        Read user's service api key from vault at the specified mount point,
        where the path is <path_prefix>/<env>/<user_id>/<service_name>.

        Returns:
            The stored key, or None when hvac raises ``InvalidPath`` for the path.

        Raises:
            VaultAuthError: Wrapping any other exception, including a missing
                "user-api-key" field in the stored secret.
        """
        secret_path = self._secret_path(user_id, service_name)

        try:
            logger.info("Reading secret from path %s", secret_path)
            response = self.client.secrets.kv.v2.read_secret_version(
                secret_path, mount_point=config.VAULT_MOUNT_POINT, raise_on_deleted_version=True
            )
            secret_data = response["data"]["data"]
            user_api_key = secret_data["user-api-key"]
            logger.info("Secret read from path %s ", secret_path)
            return user_api_key
        except InvalidPath:
            # Secret path does not exist
            logger.warning("Secret path does not exist %s ", secret_path)
            return None

        except Exception as e:
            logger.error("Failed to read secret from path %s: %s", secret_path, str(e))
            raise VaultAuthError(e) from e