agentstack-sdk 0.5.0rc5__py3-none-any.whl → 0.5.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28):
  1. agentstack_sdk/a2a/extensions/__init__.py +1 -0
  2. agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +16 -9
  3. agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +12 -6
  4. agentstack_sdk/a2a/extensions/base.py +20 -11
  5. agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
  6. agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
  7. agentstack_sdk/a2a/extensions/services/embedding.py +10 -3
  8. agentstack_sdk/a2a/extensions/services/llm.py +6 -4
  9. agentstack_sdk/a2a/extensions/services/mcp.py +8 -4
  10. agentstack_sdk/a2a/extensions/services/platform.py +34 -16
  11. agentstack_sdk/a2a/extensions/ui/__init__.py +1 -0
  12. agentstack_sdk/a2a/extensions/ui/canvas.py +6 -3
  13. agentstack_sdk/a2a/extensions/ui/error.py +5 -4
  14. agentstack_sdk/a2a/extensions/ui/form_request.py +6 -3
  15. agentstack_sdk/a2a/types.py +2 -11
  16. agentstack_sdk/platform/client.py +10 -8
  17. agentstack_sdk/platform/context.py +15 -1
  18. agentstack_sdk/server/agent.py +19 -9
  19. agentstack_sdk/server/app.py +7 -0
  20. agentstack_sdk/server/dependencies.py +13 -8
  21. agentstack_sdk/server/exceptions.py +3 -0
  22. agentstack_sdk/server/middleware/__init__.py +3 -0
  23. agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
  24. agentstack_sdk/server/server.py +19 -5
  25. agentstack_sdk/types.py +15 -0
  26. {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc3.dist-info}/METADATA +3 -1
  27. {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc3.dist-info}/RECORD +28 -22
  28. {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc3.dist-info}/WHEEL +0 -0
@@ -1,7 +1,7 @@
1
1
  # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  import uuid
4
- from typing import TYPE_CHECKING, Literal, TypeAlias
4
+ from typing import Literal, TypeAlias
5
5
 
6
6
  from a2a.types import (
7
7
  Artifact,
@@ -20,16 +20,7 @@ from a2a.types import (
20
20
  )
21
21
  from pydantic import Field, model_validator
22
22
 
23
- if TYPE_CHECKING:
24
- JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
25
- JsonDict: TypeAlias = dict[str, JsonValue]
26
- else:
27
- from typing import Union
28
-
29
- from typing_extensions import TypeAliasType
30
-
31
- JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
32
- JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")
23
+ from agentstack_sdk.types import JsonDict, JsonValue
33
24
 
34
25
 
35
26
  class Metadata(dict[str, JsonValue]): ...
@@ -5,7 +5,7 @@ import contextlib
5
5
  import os
6
6
  import ssl
7
7
  import typing
8
- from collections.abc import AsyncIterator
8
+ from collections.abc import AsyncIterator, Mapping
9
9
  from types import TracebackType
10
10
 
11
11
  import httpx
@@ -14,6 +14,7 @@ from httpx._client import EventHook
14
14
  from httpx._config import DEFAULT_LIMITS, DEFAULT_MAX_REDIRECTS, Limits
15
15
  from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes, ProxyTypes, QueryParamTypes, TimeoutTypes
16
16
  from pydantic import Secret
17
+ from typing_extensions import override
17
18
 
18
19
  from agentstack_sdk.util import resource_context
19
20
 
@@ -26,7 +27,7 @@ class PlatformClient(httpx.AsyncClient):
26
27
  def __init__(
27
28
  self,
28
29
  context_id: str | None = None, # Enter context scope
29
- auth_token: str | Secret | None = None,
30
+ auth_token: str | Secret[str] | None = None,
30
31
  *,
31
32
  auth: AuthTypes | None = None,
32
33
  params: QueryParamTypes | None = None,
@@ -37,12 +38,12 @@ class PlatformClient(httpx.AsyncClient):
37
38
  http1: bool = True,
38
39
  http2: bool = False,
39
40
  proxy: ProxyTypes | None = None,
40
- mounts: None | (typing.Mapping[str, AsyncBaseTransport | None]) = None,
41
+ mounts: None | (Mapping[str, AsyncBaseTransport | None]) = None,
41
42
  timeout: TimeoutTypes = DEFAULT_SDK_TIMEOUT,
42
43
  follow_redirects: bool = False,
43
44
  limits: Limits = DEFAULT_LIMITS,
44
45
  max_redirects: int = DEFAULT_MAX_REDIRECTS,
45
- event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None,
46
+ event_hooks: None | (Mapping[str, list[EventHook]]) = None,
46
47
  base_url: URL | str = "",
47
48
  transport: AsyncBaseTransport | None = None,
48
49
  trust_env: bool = True,
@@ -74,16 +75,18 @@ class PlatformClient(httpx.AsyncClient):
74
75
  self.context_id = context_id
75
76
  if auth_token:
76
77
  self.headers["Authorization"] = f"Bearer {auth_token}"
77
- self._ref_count = 0
78
- self._context_manager_lock = asyncio.Lock()
78
+ self._ref_count: int = 0
79
+ self._context_manager_lock: asyncio.Lock = asyncio.Lock()
79
80
 
81
+ @override
80
82
  async def __aenter__(self) -> typing.Self:
81
83
  async with self._context_manager_lock:
82
84
  self._ref_count += 1
83
85
  if self._ref_count == 1:
84
- await super().__aenter__()
86
+ _ = await super().__aenter__()
85
87
  return self
86
88
 
89
+ @override
87
90
  async def __aexit__(
88
91
  self,
89
92
  exc_type: type[BaseException] | None = None,
@@ -94,7 +97,6 @@ class PlatformClient(httpx.AsyncClient):
94
97
  self._ref_count -= 1
95
98
  if self._ref_count == 0:
96
99
  await super().__aexit__(exc_type, exc_value, traceback)
97
- self._resource = None
98
100
 
99
101
 
100
102
  get_platform_client, set_platform_client = resource_context(factory=PlatformClient, default_factory=PlatformClient)
@@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, SerializeAsAny
13
13
 
14
14
  from agentstack_sdk.platform.client import PlatformClient, get_platform_client
15
15
  from agentstack_sdk.platform.common import PaginatedResult
16
+ from agentstack_sdk.platform.provider import Provider
16
17
  from agentstack_sdk.platform.types import Metadata, MetadataPatch
17
18
  from agentstack_sdk.util.utils import filter_dict
18
19
 
@@ -40,7 +41,7 @@ class ContextPermissions(pydantic.BaseModel):
40
41
  class Permissions(ContextPermissions):
41
42
  llm: set[Literal["*"] | str] = set()
42
43
  embeddings: set[Literal["*"] | str] = set()
43
- a2a_proxy: set[Literal["*"]] = set()
44
+ a2a_proxy: set[Literal["*"] | str] = set()
44
45
  model_providers: set[Literal["read", "write", "*"]] = set()
45
46
  variables: SerializeAsAny[set[Literal["read", "write", "*"]]] = set()
46
47
 
@@ -179,6 +180,7 @@ class Context(pydantic.BaseModel):
179
180
  async def generate_token(
180
181
  self: Context | str,
181
182
  *,
183
+ providers: list[str] | list[Provider] | None = None,
182
184
  client: PlatformClient | None = None,
183
185
  grant_global_permissions: Permissions | None = None,
184
186
  grant_context_permissions: ContextPermissions | None = None,
@@ -193,6 +195,18 @@ class Context(pydantic.BaseModel):
193
195
  context_id = self if isinstance(self, str) else self.id
194
196
  grant_global_permissions = grant_global_permissions or Permissions()
195
197
  grant_context_permissions = grant_context_permissions or Permissions()
198
+
199
+ if isinstance(self, Context) and self.metadata and (provider_id := self.metadata.get("provider_id", None)):
200
+ providers = providers or [provider_id]
201
+
202
+ if "*" not in grant_global_permissions.a2a_proxy and not grant_global_permissions.a2a_proxy:
203
+ if not providers:
204
+ raise ValueError(
205
+ "Invalid audience: You must specify providers or use '*' in grant_global_permissions.a2a_proxy."
206
+ )
207
+
208
+ grant_global_permissions.a2a_proxy |= {p.id if isinstance(p, Provider) else p for p in providers}
209
+
196
210
  async with client or get_platform_client() as client:
197
211
  token_response = (
198
212
  (
@@ -46,7 +46,7 @@ from agentstack_sdk.a2a.extensions.ui.error import (
46
46
  from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
47
47
  from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
48
48
  from agentstack_sdk.server.context import RunContext
49
- from agentstack_sdk.server.dependencies import Depends, extract_dependencies
49
+ from agentstack_sdk.server.dependencies import Dependency, Depends, extract_dependencies
50
50
  from agentstack_sdk.server.store.context_store import ContextStore
51
51
  from agentstack_sdk.server.utils import cancel_task
52
52
  from agentstack_sdk.util.logging import logger
@@ -234,6 +234,7 @@ class AgentRun:
234
234
  self.last_invocation: datetime = datetime.now()
235
235
  self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
236
236
  self._run_context: RunContext | None = None
237
+ self._request_context: RequestContext | None = None
237
238
  self._task_updater: TaskUpdater | None = None
238
239
  self._context_store: ContextStore = context_store
239
240
  self._lock: asyncio.Lock = asyncio.Lock()
@@ -246,6 +247,12 @@ class AgentRun:
246
247
  raise RuntimeError("Accessing run context for run that has not been started")
247
248
  return self._run_context
248
249
 
250
+ @property
251
+ def request_context(self) -> RequestContext:
252
+ if not self._request_context:
253
+ raise RuntimeError("Accessing request context for run that has not been started")
254
+ return self._request_context
255
+
249
256
  @property
250
257
  def task_updater(self) -> TaskUpdater:
251
258
  if not self._task_updater:
@@ -261,7 +268,6 @@ class AgentRun:
261
268
  self._on_finish()
262
269
 
263
270
  async def start(self, request_context: RequestContext, event_queue: EventQueue):
264
- # These are incorrectly typed in a2a
265
271
  async with self._lock:
266
272
  if self._working or self.done:
267
273
  raise RuntimeError("Attempting to start a run that is already executing or done")
@@ -274,6 +280,7 @@ class AgentRun:
274
280
  current_task=request_context.current_task,
275
281
  related_tasks=request_context.related_tasks,
276
282
  )
283
+ self._request_context = request_context
277
284
  self._task_updater = TaskUpdater(event_queue, task_id, context_id)
278
285
  if not request_context.current_task:
279
286
  await self._task_updater.submit()
@@ -288,11 +295,12 @@ class AgentRun:
288
295
  raise RuntimeError("Attempting to resume a run that is already executing or done")
289
296
  task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
290
297
  assert task_id and context_id and message
298
+ self._request_context = request_context
291
299
  self._task_updater = TaskUpdater(event_queue, task_id, context_id)
292
300
 
293
301
  for dependency in self._agent.dependencies.values():
294
302
  if dependency.extension:
295
- dependency.extension.handle_incoming_message(message, self.run_context)
303
+ dependency.extension.handle_incoming_message(message, self.run_context, request_context)
296
304
 
297
305
  self._working = True
298
306
  await self.resume_queue.put(message)
@@ -311,15 +319,15 @@ class AgentRun:
311
319
  await cancel_task(self._task)
312
320
 
313
321
  @asynccontextmanager
314
- async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Depends]]:
322
+ async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Dependency]]:
315
323
  async with AsyncExitStack() as stack:
316
- dependency_args: dict[str, Depends] = {}
324
+ dependency_args: dict[str, Dependency] = {}
317
325
  initialize_deps_exceptions: list[Exception] = []
318
326
  for pname, depends in self._agent.dependencies.items():
319
327
  # call dependencies with the first message and initialize their lifespan
320
328
  try:
321
329
  dependency_args[pname] = await stack.enter_async_context(
322
- depends(message, self.run_context, dependency_args)
330
+ depends(message, self.run_context, self.request_context, dependency_args)
323
331
  )
324
332
  except Exception as e:
325
333
  initialize_deps_exceptions.append(e)
@@ -524,6 +532,8 @@ class Executor(AgentExecutor):
524
532
  match await tapped_queue.dequeue_event():
525
533
  case TaskStatusUpdateEvent(final=True):
526
534
  break
535
+ case _:
536
+ pass
527
537
 
528
538
  except CancelledError:
529
539
  if agent_run:
@@ -571,14 +581,14 @@ class Executor(AgentExecutor):
571
581
  event = await queue.dequeue_event(no_wait=True)
572
582
  if not isinstance(event, TaskStatusUpdateEvent) or event.status.state != TaskState.canceled:
573
583
  raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
574
- _ = await manager.save_task_event(event)
584
+ await manager.save_task_event(event)
575
585
  break
576
586
  await asyncio.sleep(2)
577
587
  except Exception as ex:
578
588
  logger.error("Error when cleaning up task", exc_info=ex)
579
589
  finally:
580
- _ = self._running_tasks.pop(task_id, None)
581
- _ = self._scheduled_cleanups.pop(task_id, None)
590
+ self._running_tasks.pop(task_id, None)
591
+ self._scheduled_cleanups.pop(task_id, None)
582
592
 
583
593
  self._scheduled_cleanups[task_id] = asyncio.create_task(cleanup_fn())
584
594
  self._scheduled_cleanups[task_id].add_done_callback(lambda _: ...)
@@ -18,6 +18,8 @@ from a2a.server.tasks import (
18
18
  from a2a.types import AgentInterface, TransportProtocol
19
19
  from fastapi import APIRouter, Depends, FastAPI
20
20
  from fastapi.applications import AppType
21
+ from starlette.authentication import AuthenticationBackend
22
+ from starlette.middleware.authentication import AuthenticationMiddleware
21
23
  from starlette.types import Lifespan
22
24
 
23
25
  from agentstack_sdk.server.agent import Agent, Executor
@@ -37,6 +39,7 @@ def create_app(
37
39
  dependencies: list[Depends] | None = None, # pyright: ignore [reportGeneralTypeIssues]
38
40
  override_interfaces: bool = True,
39
41
  task_timeout: timedelta = timedelta(minutes=10),
42
+ auth_backend: AuthenticationBackend | None = None,
40
43
  **kwargs,
41
44
  ) -> FastAPI:
42
45
  queue_manager = queue_manager or InMemoryQueueManager()
@@ -75,6 +78,10 @@ def create_app(
75
78
  **kwargs,
76
79
  )
77
80
 
81
+ if auth_backend:
82
+ rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
83
+ jsonrpc_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
84
+
78
85
  rest_app.mount("/jsonrpc", jsonrpc_app)
79
86
  rest_app.include_router(APIRouter(lifespan=lifespan))
80
87
  return rest_app
@@ -1,6 +1,8 @@
1
1
  # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import inspect
5
7
  from collections import Counter
6
8
  from collections.abc import AsyncIterator, Callable
@@ -8,6 +10,7 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager
8
10
  from inspect import isclass
9
11
  from typing import Annotated, Any, TypeAlias, Unpack, get_args, get_origin
10
12
 
13
+ from a2a.server.agent_execution.context import RequestContext
11
14
  from a2a.types import Message
12
15
  from typing_extensions import Doc
13
16
 
@@ -15,7 +18,9 @@ from agentstack_sdk.a2a.extensions import BaseExtensionSpec
15
18
  from agentstack_sdk.a2a.extensions.base import BaseExtensionServer
16
19
  from agentstack_sdk.server.context import RunContext
17
20
 
18
- Dependency: TypeAlias = Callable[[Message, RunContext, dict[str, "Dependency"]], Any] | BaseExtensionServer[Any, Any]
21
+ Dependency: TypeAlias = (
22
+ Callable[[Message, RunContext, RequestContext, dict[str, "Dependency"]], Any] | BaseExtensionServer[Any, Any]
23
+ )
19
24
 
20
25
 
21
26
  # Inspired by fastapi.Depends
@@ -34,17 +39,17 @@ class Depends:
34
39
  ),
35
40
  ],
36
41
  ):
37
- self._dependency_callable = dependency
42
+ self._dependency_callable: Dependency = dependency
38
43
  if isinstance(dependency, BaseExtensionServer):
39
44
  self.extension = dependency
40
45
 
41
46
  def __call__(
42
- self, message: Message, context: RunContext, dependencies: dict[str, Any]
43
- ) -> AbstractAsyncContextManager[Any]:
44
- instance = self._dependency_callable(message, context, dependencies)
47
+ self, message: Message, context: RunContext, request_context: RequestContext, dependencies: dict[str, Any]
48
+ ) -> AbstractAsyncContextManager[Dependency]:
49
+ instance = self._dependency_callable(message, context, request_context, dependencies)
45
50
 
46
51
  @asynccontextmanager
47
- async def lifespan() -> AsyncIterator[Any]:
52
+ async def lifespan() -> AsyncIterator[Dependency]:
48
53
  if self.extension or hasattr(instance, "lifespan"):
49
54
  async with instance.lifespan():
50
55
  yield instance
@@ -80,10 +85,10 @@ def extract_dependencies(sign: inspect.Signature) -> dict[str, Depends]:
80
85
  elif inspect.isclass(param.annotation):
81
86
  # message: Message
82
87
  if param.annotation == Message:
83
- dependencies[name] = Depends(lambda message, _context, _dependencies: message)
88
+ dependencies[name] = Depends(lambda message, _run_context, _request_context, _dependencies: message)
84
89
  # context: Context
85
90
  elif param.annotation == RunContext:
86
- dependencies[name] = Depends(lambda _message, context, _dependencies: context)
91
+ dependencies[name] = Depends(lambda _message, run_context, _request_context, _dependencies: run_context)
87
92
  # extension: BaseExtensionServer = BaseExtensionSpec()
88
93
  # TODO: this does not get past linters, should we enable it or somehow fix the typing?
89
94
  # elif issubclass(param.annotation, BaseExtensionServer) and isinstance(param.default, BaseExtensionSpec):
@@ -0,0 +1,3 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
@@ -0,0 +1,3 @@
1
+ # Copyright 2026 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
@@ -0,0 +1,131 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import logging
5
+ import os
6
+ from datetime import timedelta
7
+ from urllib.parse import urljoin
8
+
9
+ from a2a.auth.user import User
10
+ from async_lru import alru_cache
11
+ from authlib.jose import JsonWebKey, JWTClaims, KeySet, jwt
12
+ from authlib.jose.errors import JoseError
13
+ from fastapi import Request
14
+ from fastapi.security import HTTPBearer
15
+ from pydantic import Secret
16
+ from starlette.authentication import (
17
+ AuthCredentials,
18
+ AuthenticationBackend,
19
+ AuthenticationError,
20
+ BaseUser,
21
+ )
22
+ from starlette.requests import HTTPConnection
23
+ from typing_extensions import override
24
+
25
+ from agentstack_sdk.platform import use_platform_client
26
+ from agentstack_sdk.types import JsonValue
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class PlatformAuthenticatedUser(User, BaseUser):
32
+ def __init__(self, claims: dict[str, JsonValue], auth_token: str):
33
+ self.claims: dict[str, JsonValue] = claims
34
+ self.auth_token: Secret[str] = Secret(auth_token)
35
+
36
+ @property
37
+ @override
38
+ def is_authenticated(self) -> bool:
39
+ return True
40
+
41
+ @property
42
+ @override
43
+ def user_name(self) -> str:
44
+ sub = self.claims.get("sub", None)
45
+ assert sub and isinstance(sub, str)
46
+ return sub
47
+
48
+ @property
49
+ @override
50
+ def display_name(self) -> str:
51
+ name = self.claims.get("name", None)
52
+ assert name and isinstance(name, str)
53
+ return name
54
+
55
+ @property
56
+ @override
57
+ def identity(self) -> str:
58
+ return self.user_name
59
+
60
+
61
+ @alru_cache(ttl=timedelta(minutes=15).seconds)
62
+ async def discover_jwks() -> KeySet:
63
+ try:
64
+ async with use_platform_client() as client:
65
+ response = await client.get("/.well-known/jwks")
66
+ return JsonWebKey.import_key_set(response.raise_for_status().json()) # pyright: ignore[reportAny]
67
+ except Exception as e:
68
+ url = "{platform_url}/.well-known/jwks"
69
+ logger.warning(f"JWKS discovery failed for url {url}: {e}")
70
+ raise RuntimeError(f"JWKS discovery failed for url {url}") from e
71
+
72
+
73
+ class PlatformAuthBackend(AuthenticationBackend):
74
+ def __init__(self, public_url: str | None = None, skip_audience_validation: bool | None = None) -> None:
75
+ self.skip_audience_validation: bool = (
76
+ skip_audience_validation
77
+ if skip_audience_validation is not None
78
+ else os.getenv("PLATFORM_AUTH__SKIP_AUDIENCE_VALIDATION", "false").lower() in ("true", "1")
79
+ )
80
+ self._audience: str | None = public_url or os.getenv("PLATFORM_AUTH__PUBLIC_URL", None)
81
+ if not self.skip_audience_validation and not self._audience:
82
+ logger.warning(
83
+ "Public URL is not provided and audience validation is enabled. Proceeding to check audience from the request target URL. "
84
+ + "This may not work when requests to agents are proxied. (hint: set PLATFORM_AUTH__PUBLIC_URL env variable)"
85
+ )
86
+
87
+ self.security: HTTPBearer = HTTPBearer(auto_error=False)
88
+
89
+ @override
90
+ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
91
+ # We construct a Request object from the scope for compatibility with HTTPBearer and logging
92
+ request = Request(scope=conn.scope)
93
+
94
+ if request.url.path in ["/healthcheck", "/.well-known/agent-card.json"]:
95
+ return None
96
+
97
+ if not (auth := await self.security(request)):
98
+ raise AuthenticationError("Missing Authorization header")
99
+
100
+ audiences: list[str] = []
101
+ if not self.skip_audience_validation:
102
+ if self._audience:
103
+ audiences = [urljoin(self._audience, path) for path in ["/", "/jsonrpc"]]
104
+ else:
105
+ audiences = [str(request.url.replace(path=path)) for path in ["/", "/jsonrpc"]]
106
+
107
+ try:
108
+ # check only hostname urljoin("http://host:port/a/b", "/") -> "http://host:port/"
109
+ jwks = await discover_jwks()
110
+
111
+ # Verify signature
112
+ claims: JWTClaims = jwt.decode(
113
+ auth.credentials,
114
+ jwks,
115
+ claims_options={
116
+ "sub": {"essential": True},
117
+ "exp": {"essential": True},
118
+ # "iss": {"essential": True}, # Issuer validation might be tricky if internal/external URLs differ
119
+ }
120
+ | ({"aud": {"essential": True, "values": audiences}} if not self.skip_audience_validation else {}),
121
+ )
122
+ claims.validate()
123
+
124
+ return AuthCredentials(["authenticated"]), PlatformAuthenticatedUser(claims, auth.credentials)
125
+
126
+ except (ValueError, JoseError) as e:
127
+ logger.warning(f"Authentication failed: {e}")
128
+ raise AuthenticationError("Invalid token") from e
129
+ except Exception as e:
130
+ logger.error(f"Authentication error: {e}")
131
+ raise AuthenticationError(f"Authentication failed: {e}") from e
@@ -21,16 +21,15 @@ from a2a.server.tasks import PushNotificationConfigStore, PushNotificationSender
21
21
  from a2a.types import AgentExtension
22
22
  from fastapi import FastAPI
23
23
  from fastapi.applications import AppType
24
+ from fastapi.responses import PlainTextResponse
24
25
  from httpx import HTTPError, HTTPStatusError
25
26
  from pydantic import AnyUrl
27
+ from starlette.authentication import AuthenticationBackend, AuthenticationError
28
+ from starlette.middleware.authentication import AuthenticationMiddleware
29
+ from starlette.requests import HTTPConnection
26
30
  from starlette.types import Lifespan
27
31
  from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential
28
32
 
29
- from agentstack_sdk.a2a.extensions import AgentDetail, AgentDetailExtensionSpec
30
- from agentstack_sdk.a2a.extensions.services.platform import (
31
- _PlatformSelfRegistrationExtensionParams,
32
- _PlatformSelfRegistrationExtensionSpec,
33
- )
34
33
  from agentstack_sdk.platform import get_platform_client
35
34
  from agentstack_sdk.platform.client import PlatformClient
36
35
  from agentstack_sdk.platform.provider import Provider
@@ -132,6 +131,7 @@ class Server:
132
131
  factory: bool = False,
133
132
  h11_max_incomplete_event_size: int | None = None,
134
133
  self_registration_client_factory: Callable[[], PlatformClient] | None = None,
134
+ auth_backend: AuthenticationBackend | None = None,
135
135
  ) -> None:
136
136
  if self.server:
137
137
  raise RuntimeError("The server is already running")
@@ -179,6 +179,11 @@ class Server:
179
179
  self._agent.card.url = f"http://{host}:{port}"
180
180
 
181
181
  if self_registration:
182
+ from agentstack_sdk.a2a.extensions.services.platform import (
183
+ _PlatformSelfRegistrationExtensionParams,
184
+ _PlatformSelfRegistrationExtensionSpec,
185
+ )
186
+
182
187
  self._agent.card.capabilities.extensions = [
183
188
  *(self._agent.card.capabilities.extensions or []),
184
189
  *_PlatformSelfRegistrationExtensionSpec(
@@ -198,6 +203,13 @@ class Server:
198
203
  request_context_builder=request_context_builder,
199
204
  )
200
205
 
206
+ if auth_backend:
207
+
208
+ def on_error(connection: HTTPConnection, error: AuthenticationError) -> PlainTextResponse:
209
+ return PlainTextResponse("Unauthorized", status_code=401)
210
+
211
+ app.add_middleware(AuthenticationMiddleware, backend=auth_backend, on_error=on_error)
212
+
201
213
  if configure_logger:
202
214
  configure_logger_func(log_level)
203
215
 
@@ -286,6 +298,8 @@ class Server:
286
298
  await self._load_variables()
287
299
 
288
300
  async def _load_variables(self, first_run: bool = False) -> None:
301
+ from agentstack_sdk.a2a.extensions import AgentDetail, AgentDetailExtensionSpec
302
+
289
303
  assert self.server and self._agent
290
304
  if not self._provider_id:
291
305
  return
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING, TypeAlias
5
+
6
+ if TYPE_CHECKING:
7
+ JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
8
+ JsonDict: TypeAlias = dict[str, JsonValue]
9
+ else:
10
+ from typing import Union
11
+
12
+ from typing_extensions import TypeAliasType
13
+
14
+ JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
15
+ JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: agentstack-sdk
3
- Version: 0.5.0rc5
3
+ Version: 0.5.1rc3
4
4
  Summary: Agent Stack SDK
5
5
  Author: IBM Corp.
6
6
  Requires-Dist: a2a-sdk==0.3.21
@@ -19,6 +19,8 @@ Requires-Dist: janus>=2.0.0
19
19
  Requires-Dist: httpx
20
20
  Requires-Dist: mcp>=1.12.3
21
21
  Requires-Dist: fastapi>=0.116.1
22
+ Requires-Dist: authlib>=1.3.0
23
+ Requires-Dist: async-lru>=2.0.4
22
24
  Requires-Python: >=3.11, <3.14
23
25
  Description-Content-Type: text/markdown
24
26