aixtools-0.2.5-py3-none-any.whl → aixtools-0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of aixtools might be problematic; see the registry listing for more details.

aixtools/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.5'
32
- __version_tuple__ = version_tuple = (0, 2, 5)
31
+ __version__ = version = '0.2.7'
32
+ __version_tuple__ = version_tuple = (0, 2, 7)
33
33
 
34
34
  __commit_id__ = commit_id = None
aixtools/a2a/google_sdk/utils.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Utilities for handling A2A SDK agent cards and connections."""
2
+
1
3
  import asyncio
2
4
 
3
5
  import httpx
@@ -8,13 +10,34 @@ from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PAT
8
10
 
9
11
  from aixtools.a2a.google_sdk.remote_agent_connection import RemoteAgentConnection
10
12
  from aixtools.context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, SessionIdTuple
13
+ from aixtools.logging.logging_config import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+ DEFAULT_A2A_TIMEOUT = 60.0
11
18
 
12
19
 
13
20
  class AgentCardLoadFailedError(Exception):
14
21
  pass
15
22
 
16
23
 
24
+ async def get_agent_card(client: httpx.AsyncClient, address: str) -> AgentCard:
25
+ """Retrieve the agent card from the given agent address."""
26
+ for card_path in [AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH]:
27
+ try:
28
+ card_resolver = A2ACardResolver(client, address, card_path)
29
+ card = await card_resolver.get_agent_card()
30
+ card.url = address
31
+ return card
32
+ except Exception as e:
33
+ logger.warning(f"Error retrieving agent card from {address} at path {card_path}: {e}")
34
+
35
+ raise AgentCardLoadFailedError(f"Failed to load agent card from {address}")
36
+
37
+
17
38
  class _AgentCardResolver:
39
+ """Helper class to resolve and manage agent cards and their connections."""
40
+
18
41
  def __init__(self, client: httpx.AsyncClient):
19
42
  self._httpx_client = client
20
43
  self._a2a_client_factory = ClientFactory(ClientConfig(httpx_client=self._httpx_client))
@@ -25,17 +48,13 @@ class _AgentCardResolver:
25
48
  self.clients[card.name] = remote_connection
26
49
 
27
50
  async def retrieve_card(self, address: str):
28
- for card_path in [AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH]:
29
- try:
30
- card_resolver = A2ACardResolver(self._httpx_client, address, card_path)
31
- card = await card_resolver.get_agent_card()
32
- card.url = address
33
- self.register_agent_card(card)
34
- return
35
- except Exception as e:
36
- print(f"Error retrieving agent card from {address} at path {card_path}: {e}")
37
-
38
- raise AgentCardLoadFailedError(f"Failed to load agent card from {address}")
51
+ try:
52
+ card = await get_agent_card(self._httpx_client, address)
53
+ self.register_agent_card(card)
54
+ return
55
+ except Exception as e:
56
+ logger.error(f"Error retrieving agent card from {address}: {e}")
57
+ return
39
58
 
40
59
  async def get_a2a_clients(self, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
41
60
  async with asyncio.TaskGroup() as task_group:
@@ -45,15 +64,19 @@ class _AgentCardResolver:
45
64
  return self.clients
46
65
 
47
66
 
48
- async def get_a2a_clients(ctx: SessionIdTuple, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
67
+ async def get_a2a_clients(
68
+ ctx: SessionIdTuple, agent_hosts: list[str], *, timeout: float = DEFAULT_A2A_TIMEOUT
69
+ ) -> dict[str, RemoteAgentConnection]:
70
+ """Get A2A clients for all agents defined in the configuration."""
49
71
  headers = {
50
72
  "user-id": ctx[0],
51
73
  "session-id": ctx[1],
52
74
  }
53
- httpx_client = httpx.AsyncClient(headers=headers, timeout=60.0)
75
+ httpx_client = httpx.AsyncClient(headers=headers, timeout=timeout, follow_redirects=True)
54
76
  return await _AgentCardResolver(httpx_client).get_a2a_clients(agent_hosts)
55
77
 
56
78
 
57
79
  def get_session_id_tuple(context: RequestContext) -> SessionIdTuple:
80
+ """Get the user_id, session_id tuple from the request context."""
58
81
  headers = context.call_context.state.get("headers", {})
59
82
  return headers.get("user-id", DEFAULT_USER_ID), headers.get("session-id", DEFAULT_SESSION_ID)
aixtools/auth/auth.py CHANGED
@@ -2,21 +2,60 @@
2
2
  Module that manages OAuth2 functions for authentication
3
3
  """
4
4
 
5
+ import enum
5
6
  import logging
6
7
 
7
8
  import jwt
8
- from jwt import ExpiredSignatureError, InvalidAudienceError, InvalidIssuerError, PyJWKClient
9
+ from fastapi import HTTPException
10
+ from jwt import ExpiredSignatureError, InvalidAudienceError, InvalidIssuerError, InvalidSignatureError, PyJWKClient
9
11
 
10
12
  from aixtools.utils import config
11
13
 
12
14
  logger = logging.getLogger(__name__)
13
15
 
14
16
 
17
+ class AuthTokenErrorCode(str, enum.Enum):
18
+ """Enum for error codes returned by the AuthTokenError exception."""
19
+
20
+ TOKEN_EXPIRED = "Token expired"
21
+ INVALID_AUDIENCE = "Token not for expected audience"
22
+ INVALID_ISSUER = "Token not for expected issuer"
23
+ INVALID_SIGNATURE = "Token signature error"
24
+ INVALID_TOKEN = "Invalid token"
25
+ JWT_ERROR = "Generic JWT error"
26
+ MISSING_GROUPS_ERROR = "Missing authorized groups"
27
+ INVALID_TOKEN_SCOPE = "Token scope does not match configured scope"
28
+
29
+
15
30
  class AuthTokenError(Exception):
16
31
  """Exception raised for authentication token errors."""
17
32
 
33
+ def __init__(self, error_code: AuthTokenErrorCode, msg: str = None):
34
+ self.error_code = error_code
35
+ error_msg = error_code.value if msg is None else msg
36
+ super().__init__(error_msg)
37
+
38
+ def to_http_exception(self, required_scope: str = None, realm: str = "MCP") -> HTTPException:
39
+ """
40
+ Returns an HTTPException with 401 status for all AuthTokenErrorCode,
41
+ including MCP JSON body and WWW-Authenticate header.
42
+ """
43
+ status_code = 401
44
+ www_error = (
45
+ "insufficient_scope" if self.error_code == AuthTokenErrorCode.INVALID_TOKEN_SCOPE else "invalid_token"
46
+ )
47
+
48
+ header_value = f'Bearer realm="{realm}", error="{www_error}", error_description="{self.error_code.value}"'
49
+ if self.error_code == AuthTokenErrorCode.INVALID_TOKEN_SCOPE and required_scope:
50
+ header_value += f', scope="{required_scope}"'
51
+
52
+ detail = {"error": {"code": self.error_code.name, "message": self.error_code.value}}
53
+ if self.error_code == AuthTokenErrorCode.INVALID_TOKEN_SCOPE and required_scope:
54
+ detail["error"]["required_scope"] = required_scope
55
+
56
+ return HTTPException(status_code=status_code, detail=detail, headers={"WWW-Authenticate": header_value})
57
+
18
58
 
19
- # pylint: disable=too-few-public-methods
20
59
  class AccessTokenVerifier:
21
60
  """
22
61
  Verifies Microsoft SSO JWT token against the configured Tenant ID, Audience, API ID and Issuer URL.
@@ -25,7 +64,12 @@ class AccessTokenVerifier:
25
64
  def __init__(self):
26
65
  tenant_id = config.APP_TENANT_ID
27
66
  self.api_id = config.APP_API_ID
28
- self.issuer_url = config.APP_ISSUER_URL
67
+ self.issuer_url = f"https://sts.windows.net/{tenant_id}/"
68
+
69
+ self.authorized_groups = set(config.APP_AUTHORIZED_GROUPS.split(",")) if config.APP_AUTHORIZED_GROUPS else set()
70
+ if not self.authorized_groups:
71
+ logger.warning("No authorized groups configured")
72
+
29
73
  # Azure AD endpoints
30
74
  jwks_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
31
75
  self.jwks_client = PyJWKClient(
@@ -48,7 +92,7 @@ class AccessTokenVerifier:
48
92
  """
49
93
  try:
50
94
  signing_key = self.jwks_client.get_signing_key_from_jwt(token)
51
-
95
+ logger.info("Verifying JWT token")
52
96
  claims = jwt.decode(
53
97
  token,
54
98
  signing_key.key,
@@ -58,13 +102,48 @@ class AccessTokenVerifier:
58
102
  # ensure audience verification is carried out
59
103
  options={"verify_aud": True},
60
104
  )
105
+ logger.info("Verified JWT token")
61
106
  return claims
62
107
 
63
108
  except ExpiredSignatureError as e:
64
- raise AuthTokenError("Token expired") from e
109
+ raise AuthTokenError(AuthTokenErrorCode.TOKEN_EXPIRED) from e
65
110
  except InvalidAudienceError as e:
66
- raise AuthTokenError(f"Token not for expected audience: {e}") from e
111
+ raise AuthTokenError(AuthTokenErrorCode.INVALID_AUDIENCE) from e
67
112
  except InvalidIssuerError as e:
68
- raise AuthTokenError(f"Token not for expected issuer: {e}") from e
113
+ raise AuthTokenError(AuthTokenErrorCode.INVALID_ISSUER) from e
114
+ except InvalidSignatureError as e:
115
+ raise AuthTokenError(AuthTokenErrorCode.INVALID_SIGNATURE) from e
69
116
  except jwt.exceptions.PyJWTError as e:
70
- raise AuthTokenError(f"Invalid token: {e}") from e
117
+ raise AuthTokenError(AuthTokenErrorCode.JWT_ERROR) from e
118
+
119
+ def authorize_claims(self, claims: dict, expected_scope: str):
120
+ """
121
+ Authorize claims based on token scope, expected scope and authorized groups
122
+ claims: decoded JWT claims
123
+ expected_scope: expected scope for the token
124
+ Raises AuthTokenError if authorization fails.
125
+ """
126
+ logger.info("Checking JWT token claims")
127
+ if expected_scope:
128
+ token_scopes = claims.get("scp", "").split()
129
+ if expected_scope not in token_scopes:
130
+ logger.error("Expected token scope: %s, got: %s", expected_scope, token_scopes)
131
+ raise AuthTokenError(
132
+ AuthTokenErrorCode.INVALID_TOKEN_SCOPE,
133
+ f"Expected token scope: {expected_scope}, got: {token_scopes}",
134
+ )
135
+
136
+ if not self.authorized_groups:
137
+ logger.info("Authorized JWT token, no authorized groups configured")
138
+ return
139
+
140
+ groups = claims.get("groups", [])
141
+ if self.authorized_groups & set(groups):
142
+ logger.info("Authorized JWT token, against %s", groups)
143
+ return
144
+
145
+ logger.error("Could not find any group in JWT token, matching: %s", self.authorized_groups)
146
+ raise AuthTokenError(
147
+ AuthTokenErrorCode.MISSING_GROUPS_ERROR,
148
+ f"Could not find any group in JWT token, matching: {self.authorized_groups}",
149
+ )
aixtools/utils/config.py CHANGED
@@ -135,5 +135,6 @@ APP_CLIENT_ID = get_variable_env("APP_CLIENT_ID")
135
135
  # used for token audience check
136
136
  APP_API_ID = get_variable_env("APP_API_ID")
137
137
  APP_TENANT_ID = get_variable_env("APP_TENANT_ID")
138
- # used for token issuer check
139
- APP_ISSUER_URL = get_variable_env("APP_ISSUER_URL")
138
+
139
+ # used for token authorization check
140
+ APP_AUTHORIZED_GROUPS = get_variable_env("APP_AUTHORIZED_GROUPS", allow_empty=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aixtools
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: Tools for AI exploration and debugging
5
5
  Requires-Python: >=3.11.2
6
6
  Description-Content-Type: text/markdown
@@ -1,5 +1,5 @@
1
1
  aixtools/__init__.py,sha256=9NGHm7LjsQmsvjTZvw6QFJexSvAU4bCoN_KBk9SCa00,260
2
- aixtools/_version.py,sha256=9wrJ_4Dlc0arUzKiaIqvTY85rMJma3eb1nNlF3uHAxU,704
2
+ aixtools/_version.py,sha256=yXzK2akXKIKUAfJk0WCQothqygqvndys6GBuXxo-wk0,704
3
3
  aixtools/app.py,sha256=JzQ0nrv_bjDQokllIlGHOV0HEb-V8N6k_nGQH-TEsVU,5227
4
4
  aixtools/chainlit.md,sha256=yC37Ly57vjKyiIvK4oUvf4DYxZCwH7iocTlx7bLeGLU,761
5
5
  aixtools/context.py,sha256=I_MD40ZnvRm5WPKAKqBUAdXIf8YaurkYUUHSVVy-QvU,598
@@ -22,7 +22,7 @@ aixtools/a2a/utils.py,sha256=EHr3IyyBJn23ni-JcfAf6i3VpQmPs0g1TSnAZazvY_8,4039
22
22
  aixtools/a2a/google_sdk/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  aixtools/a2a/google_sdk/card.py,sha256=P0L3bKbm28HaRkcIxIvjuSGUKOOc0ymyRAFHKm3a5GQ,996
24
24
  aixtools/a2a/google_sdk/remote_agent_connection.py,sha256=oDCRSN3gfONY1Ibp8BrtysIVqfQQ-lWe5N7lr1ymHxY,2819
25
- aixtools/a2a/google_sdk/utils.py,sha256=hjrNRZywJEUxHHaOttJFQU0FLzteg0Ggtm3qAeXMSVw,2430
25
+ aixtools/a2a/google_sdk/utils.py,sha256=kag7KWNRatpw9bf9ThqBdDlyn0QuBx8-Pja1id1Gk30,3242
26
26
  aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py,sha256=MMbhbEnUL6NwSYnisJrDdHW8zJoSyJ3Pzzkt8jqwNdI,7066
27
27
  aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py,sha256=nGoVL7MPoZJW7iVR71laqpUYP308yFKZIifJtvUgpiU,878
28
28
  aixtools/agents/__init__.py,sha256=MAW196S2_G7uGqv-VNjvlOETRfuV44WlU1leO7SiR0A,282
@@ -31,7 +31,7 @@ aixtools/agents/agent_batch.py,sha256=0Zu9yNCRPAQZPjXQ-dIUAmP1uGTVbxVt7xvnMpoJMj
31
31
  aixtools/agents/print_nodes.py,sha256=wVTngNfqM0As845WTRz6G3Rei_Gr3HuBlvu-G_eXuig,1665
32
32
  aixtools/agents/prompt.py,sha256=p9OYnyJ4-MyGXwHPrQeJBhZ2a3RV2HqhtdUUCrTMsAQ,3361
33
33
  aixtools/auth/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
- aixtools/auth/auth.py,sha256=aKYCKJRjSNrVZmIWN2h2p1zYqkhMLLBXBfk_Qy5NKik,2365
34
+ aixtools/auth/auth.py,sha256=4vNcljdpFzvIBXqjK7i3xAy-CC8Hc7ii11t-11z6ofY,6041
35
35
  aixtools/compliance/__init__.py,sha256=vnw0zEdySIJWvDAJ8DCRRaWmY_agEOz1qlpAdhmtiuo,191
36
36
  aixtools/compliance/private_data.py,sha256=OOM9mIp3_w0fNgj3VAEWBl7-jrPc19_Ls1pC5dfF5UY,5323
37
37
  aixtools/db/__init__.py,sha256=b8vRhme3egV-aUZbAntnOaDkSXB8UT0Xy5oqQhU_z0Q,399
@@ -79,7 +79,7 @@ aixtools/tools/doctor/mcp_tool_doctor.py,sha256=sX2q5GfNkmUYxnXrqMpeGIwGfeL1LpYJ
79
79
  aixtools/tools/doctor/tool_doctor.py,sha256=EY1pshjLGLD0j6cc1ZFtbc0G19I5IbOZwHFDqypE49Q,2661
80
80
  aixtools/tools/doctor/tool_recommendation.py,sha256=LYyVOSXdAorWiY4P-ucSA1vLlV5BTEfX4GzBXNE_X0M,1569
81
81
  aixtools/utils/__init__.py,sha256=xT6almZBQYMfj4h7Hq9QXDHyVXbOOTxqLsmJsxYYnSw,757
82
- aixtools/utils/config.py,sha256=t32731F53Cv1YYoX95wksoreE0Zn0B8UKyEiKWne4ec,5147
82
+ aixtools/utils/config.py,sha256=fCATakEfeASo6eFMh5Me_idwMQgiITS3LMYc2Z1bGzU,5187
83
83
  aixtools/utils/config_util.py,sha256=3Ya4Qqhj1RJ1qtTTykQ6iayf5uxlpigPXgEJlTi1wn4,2229
84
84
  aixtools/utils/enum_with_description.py,sha256=zjSzWxG74eR4x7dpmb74pLTYCWNSMvauHd7_9LpDYIw,1088
85
85
  aixtools/utils/files.py,sha256=8JnxwHJRJcjWCdFpjzWmo0po2fRg8esj4H7sOxElYXU,517
@@ -89,8 +89,8 @@ aixtools/utils/chainlit/cl_agent_show.py,sha256=vaRuowp4BRvhxEr5hw0zHEJ7iaSF_5bo
89
89
  aixtools/utils/chainlit/cl_utils.py,sha256=fxaxdkcZg6uHdM8uztxdPowg3a2f7VR7B26VPY4t-3c,5738
90
90
  aixtools/vault/__init__.py,sha256=fsr_NuX3GZ9WZ7dGfe0gp_5-z3URxAfwVRXw7Xyc0dU,141
91
91
  aixtools/vault/vault.py,sha256=9dZLWdZQk9qN_Q9Djkofw9LUKnJqnrX5H0fGusVLBhA,6037
92
- aixtools-0.2.5.dist-info/METADATA,sha256=BHPUgnHXs7ET3BvwAkPxYRkkXnxLdptFwbYNDkoBMbw,27229
93
- aixtools-0.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
94
- aixtools-0.2.5.dist-info/entry_points.txt,sha256=q8412TG4T0S8K0SKeWp2vkVPIDYQs0jNoHqcQ7qxOiA,155
95
- aixtools-0.2.5.dist-info/top_level.txt,sha256=wBn-rw9bCtxrR4AYEYgjilNCUVmKY0LWby9Zan2PRJM,9
96
- aixtools-0.2.5.dist-info/RECORD,,
92
+ aixtools-0.2.7.dist-info/METADATA,sha256=V3-mduY2Z8lURHq7Vhe4KVyieG1snJAHeSMmvVL2k5k,27229
93
+ aixtools-0.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
94
+ aixtools-0.2.7.dist-info/entry_points.txt,sha256=q8412TG4T0S8K0SKeWp2vkVPIDYQs0jNoHqcQ7qxOiA,155
95
+ aixtools-0.2.7.dist-info/top_level.txt,sha256=wBn-rw9bCtxrR4AYEYgjilNCUVmKY0LWby9Zan2PRJM,9
96
+ aixtools-0.2.7.dist-info/RECORD,,