agentstack-sdk 0.5.0rc5__py3-none-any.whl → 0.5.1rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/a2a/extensions/__init__.py +1 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +16 -9
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +12 -6
- agentstack_sdk/a2a/extensions/base.py +20 -11
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +10 -3
- agentstack_sdk/a2a/extensions/services/llm.py +6 -4
- agentstack_sdk/a2a/extensions/services/mcp.py +8 -4
- agentstack_sdk/a2a/extensions/services/platform.py +34 -16
- agentstack_sdk/a2a/extensions/ui/__init__.py +1 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +6 -3
- agentstack_sdk/a2a/extensions/ui/error.py +5 -4
- agentstack_sdk/a2a/extensions/ui/form_request.py +6 -3
- agentstack_sdk/a2a/types.py +2 -11
- agentstack_sdk/platform/client.py +10 -8
- agentstack_sdk/platform/context.py +15 -1
- agentstack_sdk/server/agent.py +19 -9
- agentstack_sdk/server/app.py +7 -0
- agentstack_sdk/server/dependencies.py +13 -8
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +19 -5
- agentstack_sdk/types.py +15 -0
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/METADATA +3 -1
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/RECORD +28 -22
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/WHEEL +0 -0
agentstack_sdk/a2a/types.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
import uuid
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Literal, TypeAlias
|
|
5
5
|
|
|
6
6
|
from a2a.types import (
|
|
7
7
|
Artifact,
|
|
@@ -20,16 +20,7 @@ from a2a.types import (
|
|
|
20
20
|
)
|
|
21
21
|
from pydantic import Field, model_validator
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
|
|
25
|
-
JsonDict: TypeAlias = dict[str, JsonValue]
|
|
26
|
-
else:
|
|
27
|
-
from typing import Union
|
|
28
|
-
|
|
29
|
-
from typing_extensions import TypeAliasType
|
|
30
|
-
|
|
31
|
-
JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
|
|
32
|
-
JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")
|
|
23
|
+
from agentstack_sdk.types import JsonDict, JsonValue
|
|
33
24
|
|
|
34
25
|
|
|
35
26
|
class Metadata(dict[str, JsonValue]): ...
|
|
@@ -5,7 +5,7 @@ import contextlib
|
|
|
5
5
|
import os
|
|
6
6
|
import ssl
|
|
7
7
|
import typing
|
|
8
|
-
from collections.abc import AsyncIterator
|
|
8
|
+
from collections.abc import AsyncIterator, Mapping
|
|
9
9
|
from types import TracebackType
|
|
10
10
|
|
|
11
11
|
import httpx
|
|
@@ -14,6 +14,7 @@ from httpx._client import EventHook
|
|
|
14
14
|
from httpx._config import DEFAULT_LIMITS, DEFAULT_MAX_REDIRECTS, Limits
|
|
15
15
|
from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes, ProxyTypes, QueryParamTypes, TimeoutTypes
|
|
16
16
|
from pydantic import Secret
|
|
17
|
+
from typing_extensions import override
|
|
17
18
|
|
|
18
19
|
from agentstack_sdk.util import resource_context
|
|
19
20
|
|
|
@@ -26,7 +27,7 @@ class PlatformClient(httpx.AsyncClient):
|
|
|
26
27
|
def __init__(
|
|
27
28
|
self,
|
|
28
29
|
context_id: str | None = None, # Enter context scope
|
|
29
|
-
auth_token: str | Secret | None = None,
|
|
30
|
+
auth_token: str | Secret[str] | None = None,
|
|
30
31
|
*,
|
|
31
32
|
auth: AuthTypes | None = None,
|
|
32
33
|
params: QueryParamTypes | None = None,
|
|
@@ -37,12 +38,12 @@ class PlatformClient(httpx.AsyncClient):
|
|
|
37
38
|
http1: bool = True,
|
|
38
39
|
http2: bool = False,
|
|
39
40
|
proxy: ProxyTypes | None = None,
|
|
40
|
-
mounts: None | (
|
|
41
|
+
mounts: None | (Mapping[str, AsyncBaseTransport | None]) = None,
|
|
41
42
|
timeout: TimeoutTypes = DEFAULT_SDK_TIMEOUT,
|
|
42
43
|
follow_redirects: bool = False,
|
|
43
44
|
limits: Limits = DEFAULT_LIMITS,
|
|
44
45
|
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
|
45
|
-
event_hooks: None | (
|
|
46
|
+
event_hooks: None | (Mapping[str, list[EventHook]]) = None,
|
|
46
47
|
base_url: URL | str = "",
|
|
47
48
|
transport: AsyncBaseTransport | None = None,
|
|
48
49
|
trust_env: bool = True,
|
|
@@ -74,16 +75,18 @@ class PlatformClient(httpx.AsyncClient):
|
|
|
74
75
|
self.context_id = context_id
|
|
75
76
|
if auth_token:
|
|
76
77
|
self.headers["Authorization"] = f"Bearer {auth_token}"
|
|
77
|
-
self._ref_count = 0
|
|
78
|
-
self._context_manager_lock = asyncio.Lock()
|
|
78
|
+
self._ref_count: int = 0
|
|
79
|
+
self._context_manager_lock: asyncio.Lock = asyncio.Lock()
|
|
79
80
|
|
|
81
|
+
@override
|
|
80
82
|
async def __aenter__(self) -> typing.Self:
|
|
81
83
|
async with self._context_manager_lock:
|
|
82
84
|
self._ref_count += 1
|
|
83
85
|
if self._ref_count == 1:
|
|
84
|
-
await super().__aenter__()
|
|
86
|
+
_ = await super().__aenter__()
|
|
85
87
|
return self
|
|
86
88
|
|
|
89
|
+
@override
|
|
87
90
|
async def __aexit__(
|
|
88
91
|
self,
|
|
89
92
|
exc_type: type[BaseException] | None = None,
|
|
@@ -94,7 +97,6 @@ class PlatformClient(httpx.AsyncClient):
|
|
|
94
97
|
self._ref_count -= 1
|
|
95
98
|
if self._ref_count == 0:
|
|
96
99
|
await super().__aexit__(exc_type, exc_value, traceback)
|
|
97
|
-
self._resource = None
|
|
98
100
|
|
|
99
101
|
|
|
100
102
|
get_platform_client, set_platform_client = resource_context(factory=PlatformClient, default_factory=PlatformClient)
|
|
@@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, SerializeAsAny
|
|
|
13
13
|
|
|
14
14
|
from agentstack_sdk.platform.client import PlatformClient, get_platform_client
|
|
15
15
|
from agentstack_sdk.platform.common import PaginatedResult
|
|
16
|
+
from agentstack_sdk.platform.provider import Provider
|
|
16
17
|
from agentstack_sdk.platform.types import Metadata, MetadataPatch
|
|
17
18
|
from agentstack_sdk.util.utils import filter_dict
|
|
18
19
|
|
|
@@ -40,7 +41,7 @@ class ContextPermissions(pydantic.BaseModel):
|
|
|
40
41
|
class Permissions(ContextPermissions):
|
|
41
42
|
llm: set[Literal["*"] | str] = set()
|
|
42
43
|
embeddings: set[Literal["*"] | str] = set()
|
|
43
|
-
a2a_proxy: set[Literal["*"]] = set()
|
|
44
|
+
a2a_proxy: set[Literal["*"] | str] = set()
|
|
44
45
|
model_providers: set[Literal["read", "write", "*"]] = set()
|
|
45
46
|
variables: SerializeAsAny[set[Literal["read", "write", "*"]]] = set()
|
|
46
47
|
|
|
@@ -179,6 +180,7 @@ class Context(pydantic.BaseModel):
|
|
|
179
180
|
async def generate_token(
|
|
180
181
|
self: Context | str,
|
|
181
182
|
*,
|
|
183
|
+
providers: list[str] | list[Provider] | None = None,
|
|
182
184
|
client: PlatformClient | None = None,
|
|
183
185
|
grant_global_permissions: Permissions | None = None,
|
|
184
186
|
grant_context_permissions: ContextPermissions | None = None,
|
|
@@ -193,6 +195,18 @@ class Context(pydantic.BaseModel):
|
|
|
193
195
|
context_id = self if isinstance(self, str) else self.id
|
|
194
196
|
grant_global_permissions = grant_global_permissions or Permissions()
|
|
195
197
|
grant_context_permissions = grant_context_permissions or Permissions()
|
|
198
|
+
|
|
199
|
+
if isinstance(self, Context) and self.metadata and (provider_id := self.metadata.get("provider_id", None)):
|
|
200
|
+
providers = providers or [provider_id]
|
|
201
|
+
|
|
202
|
+
if "*" not in grant_global_permissions.a2a_proxy and not grant_global_permissions.a2a_proxy:
|
|
203
|
+
if not providers:
|
|
204
|
+
raise ValueError(
|
|
205
|
+
"Invalid audience: You must specify providers or use '*' in grant_global_permissions.a2a_proxy."
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
grant_global_permissions.a2a_proxy |= {p.id if isinstance(p, Provider) else p for p in providers}
|
|
209
|
+
|
|
196
210
|
async with client or get_platform_client() as client:
|
|
197
211
|
token_response = (
|
|
198
212
|
(
|
agentstack_sdk/server/agent.py
CHANGED
|
@@ -46,7 +46,7 @@ from agentstack_sdk.a2a.extensions.ui.error import (
|
|
|
46
46
|
from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
|
|
47
47
|
from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
|
|
48
48
|
from agentstack_sdk.server.context import RunContext
|
|
49
|
-
from agentstack_sdk.server.dependencies import Depends, extract_dependencies
|
|
49
|
+
from agentstack_sdk.server.dependencies import Dependency, Depends, extract_dependencies
|
|
50
50
|
from agentstack_sdk.server.store.context_store import ContextStore
|
|
51
51
|
from agentstack_sdk.server.utils import cancel_task
|
|
52
52
|
from agentstack_sdk.util.logging import logger
|
|
@@ -234,6 +234,7 @@ class AgentRun:
|
|
|
234
234
|
self.last_invocation: datetime = datetime.now()
|
|
235
235
|
self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
|
|
236
236
|
self._run_context: RunContext | None = None
|
|
237
|
+
self._request_context: RequestContext | None = None
|
|
237
238
|
self._task_updater: TaskUpdater | None = None
|
|
238
239
|
self._context_store: ContextStore = context_store
|
|
239
240
|
self._lock: asyncio.Lock = asyncio.Lock()
|
|
@@ -246,6 +247,12 @@ class AgentRun:
|
|
|
246
247
|
raise RuntimeError("Accessing run context for run that has not been started")
|
|
247
248
|
return self._run_context
|
|
248
249
|
|
|
250
|
+
@property
|
|
251
|
+
def request_context(self) -> RequestContext:
|
|
252
|
+
if not self._request_context:
|
|
253
|
+
raise RuntimeError("Accessing request context for run that has not been started")
|
|
254
|
+
return self._request_context
|
|
255
|
+
|
|
249
256
|
@property
|
|
250
257
|
def task_updater(self) -> TaskUpdater:
|
|
251
258
|
if not self._task_updater:
|
|
@@ -261,7 +268,6 @@ class AgentRun:
|
|
|
261
268
|
self._on_finish()
|
|
262
269
|
|
|
263
270
|
async def start(self, request_context: RequestContext, event_queue: EventQueue):
|
|
264
|
-
# These are incorrectly typed in a2a
|
|
265
271
|
async with self._lock:
|
|
266
272
|
if self._working or self.done:
|
|
267
273
|
raise RuntimeError("Attempting to start a run that is already executing or done")
|
|
@@ -274,6 +280,7 @@ class AgentRun:
|
|
|
274
280
|
current_task=request_context.current_task,
|
|
275
281
|
related_tasks=request_context.related_tasks,
|
|
276
282
|
)
|
|
283
|
+
self._request_context = request_context
|
|
277
284
|
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
278
285
|
if not request_context.current_task:
|
|
279
286
|
await self._task_updater.submit()
|
|
@@ -288,11 +295,12 @@ class AgentRun:
|
|
|
288
295
|
raise RuntimeError("Attempting to resume a run that is already executing or done")
|
|
289
296
|
task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
|
|
290
297
|
assert task_id and context_id and message
|
|
298
|
+
self._request_context = request_context
|
|
291
299
|
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
292
300
|
|
|
293
301
|
for dependency in self._agent.dependencies.values():
|
|
294
302
|
if dependency.extension:
|
|
295
|
-
dependency.extension.handle_incoming_message(message, self.run_context)
|
|
303
|
+
dependency.extension.handle_incoming_message(message, self.run_context, request_context)
|
|
296
304
|
|
|
297
305
|
self._working = True
|
|
298
306
|
await self.resume_queue.put(message)
|
|
@@ -311,15 +319,15 @@ class AgentRun:
|
|
|
311
319
|
await cancel_task(self._task)
|
|
312
320
|
|
|
313
321
|
@asynccontextmanager
|
|
314
|
-
async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str,
|
|
322
|
+
async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Dependency]]:
|
|
315
323
|
async with AsyncExitStack() as stack:
|
|
316
|
-
dependency_args: dict[str,
|
|
324
|
+
dependency_args: dict[str, Dependency] = {}
|
|
317
325
|
initialize_deps_exceptions: list[Exception] = []
|
|
318
326
|
for pname, depends in self._agent.dependencies.items():
|
|
319
327
|
# call dependencies with the first message and initialize their lifespan
|
|
320
328
|
try:
|
|
321
329
|
dependency_args[pname] = await stack.enter_async_context(
|
|
322
|
-
depends(message, self.run_context, dependency_args)
|
|
330
|
+
depends(message, self.run_context, self.request_context, dependency_args)
|
|
323
331
|
)
|
|
324
332
|
except Exception as e:
|
|
325
333
|
initialize_deps_exceptions.append(e)
|
|
@@ -524,6 +532,8 @@ class Executor(AgentExecutor):
|
|
|
524
532
|
match await tapped_queue.dequeue_event():
|
|
525
533
|
case TaskStatusUpdateEvent(final=True):
|
|
526
534
|
break
|
|
535
|
+
case _:
|
|
536
|
+
pass
|
|
527
537
|
|
|
528
538
|
except CancelledError:
|
|
529
539
|
if agent_run:
|
|
@@ -571,14 +581,14 @@ class Executor(AgentExecutor):
|
|
|
571
581
|
event = await queue.dequeue_event(no_wait=True)
|
|
572
582
|
if not isinstance(event, TaskStatusUpdateEvent) or event.status.state != TaskState.canceled:
|
|
573
583
|
raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
|
|
574
|
-
|
|
584
|
+
await manager.save_task_event(event)
|
|
575
585
|
break
|
|
576
586
|
await asyncio.sleep(2)
|
|
577
587
|
except Exception as ex:
|
|
578
588
|
logger.error("Error when cleaning up task", exc_info=ex)
|
|
579
589
|
finally:
|
|
580
|
-
|
|
581
|
-
|
|
590
|
+
self._running_tasks.pop(task_id, None)
|
|
591
|
+
self._scheduled_cleanups.pop(task_id, None)
|
|
582
592
|
|
|
583
593
|
self._scheduled_cleanups[task_id] = asyncio.create_task(cleanup_fn())
|
|
584
594
|
self._scheduled_cleanups[task_id].add_done_callback(lambda _: ...)
|
agentstack_sdk/server/app.py
CHANGED
|
@@ -18,6 +18,8 @@ from a2a.server.tasks import (
|
|
|
18
18
|
from a2a.types import AgentInterface, TransportProtocol
|
|
19
19
|
from fastapi import APIRouter, Depends, FastAPI
|
|
20
20
|
from fastapi.applications import AppType
|
|
21
|
+
from starlette.authentication import AuthenticationBackend
|
|
22
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
21
23
|
from starlette.types import Lifespan
|
|
22
24
|
|
|
23
25
|
from agentstack_sdk.server.agent import Agent, Executor
|
|
@@ -37,6 +39,7 @@ def create_app(
|
|
|
37
39
|
dependencies: list[Depends] | None = None, # pyright: ignore [reportGeneralTypeIssues]
|
|
38
40
|
override_interfaces: bool = True,
|
|
39
41
|
task_timeout: timedelta = timedelta(minutes=10),
|
|
42
|
+
auth_backend: AuthenticationBackend | None = None,
|
|
40
43
|
**kwargs,
|
|
41
44
|
) -> FastAPI:
|
|
42
45
|
queue_manager = queue_manager or InMemoryQueueManager()
|
|
@@ -75,6 +78,10 @@ def create_app(
|
|
|
75
78
|
**kwargs,
|
|
76
79
|
)
|
|
77
80
|
|
|
81
|
+
if auth_backend:
|
|
82
|
+
rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
|
|
83
|
+
jsonrpc_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
|
|
84
|
+
|
|
78
85
|
rest_app.mount("/jsonrpc", jsonrpc_app)
|
|
79
86
|
rest_app.include_router(APIRouter(lifespan=lifespan))
|
|
80
87
|
return rest_app
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import inspect
|
|
5
7
|
from collections import Counter
|
|
6
8
|
from collections.abc import AsyncIterator, Callable
|
|
@@ -8,6 +10,7 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
|
8
10
|
from inspect import isclass
|
|
9
11
|
from typing import Annotated, Any, TypeAlias, Unpack, get_args, get_origin
|
|
10
12
|
|
|
13
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
11
14
|
from a2a.types import Message
|
|
12
15
|
from typing_extensions import Doc
|
|
13
16
|
|
|
@@ -15,7 +18,9 @@ from agentstack_sdk.a2a.extensions import BaseExtensionSpec
|
|
|
15
18
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionServer
|
|
16
19
|
from agentstack_sdk.server.context import RunContext
|
|
17
20
|
|
|
18
|
-
Dependency: TypeAlias =
|
|
21
|
+
Dependency: TypeAlias = (
|
|
22
|
+
Callable[[Message, RunContext, RequestContext, dict[str, "Dependency"]], Any] | BaseExtensionServer[Any, Any]
|
|
23
|
+
)
|
|
19
24
|
|
|
20
25
|
|
|
21
26
|
# Inspired by fastapi.Depends
|
|
@@ -34,17 +39,17 @@ class Depends:
|
|
|
34
39
|
),
|
|
35
40
|
],
|
|
36
41
|
):
|
|
37
|
-
self._dependency_callable = dependency
|
|
42
|
+
self._dependency_callable: Dependency = dependency
|
|
38
43
|
if isinstance(dependency, BaseExtensionServer):
|
|
39
44
|
self.extension = dependency
|
|
40
45
|
|
|
41
46
|
def __call__(
|
|
42
|
-
self, message: Message, context: RunContext, dependencies: dict[str, Any]
|
|
43
|
-
) -> AbstractAsyncContextManager[
|
|
44
|
-
instance = self._dependency_callable(message, context, dependencies)
|
|
47
|
+
self, message: Message, context: RunContext, request_context: RequestContext, dependencies: dict[str, Any]
|
|
48
|
+
) -> AbstractAsyncContextManager[Dependency]:
|
|
49
|
+
instance = self._dependency_callable(message, context, request_context, dependencies)
|
|
45
50
|
|
|
46
51
|
@asynccontextmanager
|
|
47
|
-
async def lifespan() -> AsyncIterator[
|
|
52
|
+
async def lifespan() -> AsyncIterator[Dependency]:
|
|
48
53
|
if self.extension or hasattr(instance, "lifespan"):
|
|
49
54
|
async with instance.lifespan():
|
|
50
55
|
yield instance
|
|
@@ -80,10 +85,10 @@ def extract_dependencies(sign: inspect.Signature) -> dict[str, Depends]:
|
|
|
80
85
|
elif inspect.isclass(param.annotation):
|
|
81
86
|
# message: Message
|
|
82
87
|
if param.annotation == Message:
|
|
83
|
-
dependencies[name] = Depends(lambda message,
|
|
88
|
+
dependencies[name] = Depends(lambda message, _run_context, _request_context, _dependencies: message)
|
|
84
89
|
# context: Context
|
|
85
90
|
elif param.annotation == RunContext:
|
|
86
|
-
dependencies[name] = Depends(lambda _message,
|
|
91
|
+
dependencies[name] = Depends(lambda _message, run_context, _request_context, _dependencies: run_context)
|
|
87
92
|
# extension: BaseExtensionServer = BaseExtensionSpec()
|
|
88
93
|
# TODO: this does not get past linters, should we enable it or somehow fix the typing?
|
|
89
94
|
# elif issubclass(param.annotation, BaseExtensionServer) and isinstance(param.default, BaseExtensionSpec):
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from urllib.parse import urljoin
|
|
8
|
+
|
|
9
|
+
from a2a.auth.user import User
|
|
10
|
+
from async_lru import alru_cache
|
|
11
|
+
from authlib.jose import JsonWebKey, JWTClaims, KeySet, jwt
|
|
12
|
+
from authlib.jose.errors import JoseError
|
|
13
|
+
from fastapi import Request
|
|
14
|
+
from fastapi.security import HTTPBearer
|
|
15
|
+
from pydantic import Secret
|
|
16
|
+
from starlette.authentication import (
|
|
17
|
+
AuthCredentials,
|
|
18
|
+
AuthenticationBackend,
|
|
19
|
+
AuthenticationError,
|
|
20
|
+
BaseUser,
|
|
21
|
+
)
|
|
22
|
+
from starlette.requests import HTTPConnection
|
|
23
|
+
from typing_extensions import override
|
|
24
|
+
|
|
25
|
+
from agentstack_sdk.platform import use_platform_client
|
|
26
|
+
from agentstack_sdk.types import JsonValue
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PlatformAuthenticatedUser(User, BaseUser):
|
|
32
|
+
def __init__(self, claims: dict[str, JsonValue], auth_token: str):
|
|
33
|
+
self.claims: dict[str, JsonValue] = claims
|
|
34
|
+
self.auth_token: Secret[str] = Secret(auth_token)
|
|
35
|
+
|
|
36
|
+
@property
|
|
37
|
+
@override
|
|
38
|
+
def is_authenticated(self) -> bool:
|
|
39
|
+
return True
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
@override
|
|
43
|
+
def user_name(self) -> str:
|
|
44
|
+
sub = self.claims.get("sub", None)
|
|
45
|
+
assert sub and isinstance(sub, str)
|
|
46
|
+
return sub
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
@override
|
|
50
|
+
def display_name(self) -> str:
|
|
51
|
+
name = self.claims.get("name", None)
|
|
52
|
+
assert name and isinstance(name, str)
|
|
53
|
+
return name
|
|
54
|
+
|
|
55
|
+
@property
|
|
56
|
+
@override
|
|
57
|
+
def identity(self) -> str:
|
|
58
|
+
return self.user_name
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@alru_cache(ttl=timedelta(minutes=15).seconds)
|
|
62
|
+
async def discover_jwks() -> KeySet:
|
|
63
|
+
try:
|
|
64
|
+
async with use_platform_client() as client:
|
|
65
|
+
response = await client.get("/.well-known/jwks")
|
|
66
|
+
return JsonWebKey.import_key_set(response.raise_for_status().json()) # pyright: ignore[reportAny]
|
|
67
|
+
except Exception as e:
|
|
68
|
+
url = "{platform_url}/.well-known/jwks"
|
|
69
|
+
logger.warning(f"JWKS discovery failed for url {url}: {e}")
|
|
70
|
+
raise RuntimeError(f"JWKS discovery failed for url {url}") from e
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PlatformAuthBackend(AuthenticationBackend):
|
|
74
|
+
def __init__(self, public_url: str | None = None, skip_audience_validation: bool | None = None) -> None:
|
|
75
|
+
self.skip_audience_validation: bool = (
|
|
76
|
+
skip_audience_validation
|
|
77
|
+
if skip_audience_validation is not None
|
|
78
|
+
else os.getenv("PLATFORM_AUTH__SKIP_AUDIENCE_VALIDATION", "false").lower() in ("true", "1")
|
|
79
|
+
)
|
|
80
|
+
self._audience: str | None = public_url or os.getenv("PLATFORM_AUTH__PUBLIC_URL", None)
|
|
81
|
+
if not self.skip_audience_validation and not self._audience:
|
|
82
|
+
logger.warning(
|
|
83
|
+
"Public URL is not provided and audience validation is enabled. Proceeding to check audience from the request target URL. "
|
|
84
|
+
+ "This may not work when requests to agents are proxied. (hint: set PLATFORM_AUTH__PUBLIC_URL env variable)"
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
self.security: HTTPBearer = HTTPBearer(auto_error=False)
|
|
88
|
+
|
|
89
|
+
@override
|
|
90
|
+
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
|
91
|
+
# We construct a Request object from the scope for compatibility with HTTPBearer and logging
|
|
92
|
+
request = Request(scope=conn.scope)
|
|
93
|
+
|
|
94
|
+
if request.url.path in ["/healthcheck", "/.well-known/agent-card.json"]:
|
|
95
|
+
return None
|
|
96
|
+
|
|
97
|
+
if not (auth := await self.security(request)):
|
|
98
|
+
raise AuthenticationError("Missing Authorization header")
|
|
99
|
+
|
|
100
|
+
audiences: list[str] = []
|
|
101
|
+
if not self.skip_audience_validation:
|
|
102
|
+
if self._audience:
|
|
103
|
+
audiences = [urljoin(self._audience, path) for path in ["/", "/jsonrpc"]]
|
|
104
|
+
else:
|
|
105
|
+
audiences = [str(request.url.replace(path=path)) for path in ["/", "/jsonrpc"]]
|
|
106
|
+
|
|
107
|
+
try:
|
|
108
|
+
# check only hostname urljoin("http://host:port/a/b", "/") -> "http://host:port/"
|
|
109
|
+
jwks = await discover_jwks()
|
|
110
|
+
|
|
111
|
+
# Verify signature
|
|
112
|
+
claims: JWTClaims = jwt.decode(
|
|
113
|
+
auth.credentials,
|
|
114
|
+
jwks,
|
|
115
|
+
claims_options={
|
|
116
|
+
"sub": {"essential": True},
|
|
117
|
+
"exp": {"essential": True},
|
|
118
|
+
# "iss": {"essential": True}, # Issuer validation might be tricky if internal/external URLs differ
|
|
119
|
+
}
|
|
120
|
+
| ({"aud": {"essential": True, "values": audiences}} if not self.skip_audience_validation else {}),
|
|
121
|
+
)
|
|
122
|
+
claims.validate()
|
|
123
|
+
|
|
124
|
+
return AuthCredentials(["authenticated"]), PlatformAuthenticatedUser(claims, auth.credentials)
|
|
125
|
+
|
|
126
|
+
except (ValueError, JoseError) as e:
|
|
127
|
+
logger.warning(f"Authentication failed: {e}")
|
|
128
|
+
raise AuthenticationError("Invalid token") from e
|
|
129
|
+
except Exception as e:
|
|
130
|
+
logger.error(f"Authentication error: {e}")
|
|
131
|
+
raise AuthenticationError(f"Authentication failed: {e}") from e
|
agentstack_sdk/server/server.py
CHANGED
|
@@ -21,16 +21,15 @@ from a2a.server.tasks import PushNotificationConfigStore, PushNotificationSender
|
|
|
21
21
|
from a2a.types import AgentExtension
|
|
22
22
|
from fastapi import FastAPI
|
|
23
23
|
from fastapi.applications import AppType
|
|
24
|
+
from fastapi.responses import PlainTextResponse
|
|
24
25
|
from httpx import HTTPError, HTTPStatusError
|
|
25
26
|
from pydantic import AnyUrl
|
|
27
|
+
from starlette.authentication import AuthenticationBackend, AuthenticationError
|
|
28
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
29
|
+
from starlette.requests import HTTPConnection
|
|
26
30
|
from starlette.types import Lifespan
|
|
27
31
|
from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
28
32
|
|
|
29
|
-
from agentstack_sdk.a2a.extensions import AgentDetail, AgentDetailExtensionSpec
|
|
30
|
-
from agentstack_sdk.a2a.extensions.services.platform import (
|
|
31
|
-
_PlatformSelfRegistrationExtensionParams,
|
|
32
|
-
_PlatformSelfRegistrationExtensionSpec,
|
|
33
|
-
)
|
|
34
33
|
from agentstack_sdk.platform import get_platform_client
|
|
35
34
|
from agentstack_sdk.platform.client import PlatformClient
|
|
36
35
|
from agentstack_sdk.platform.provider import Provider
|
|
@@ -132,6 +131,7 @@ class Server:
|
|
|
132
131
|
factory: bool = False,
|
|
133
132
|
h11_max_incomplete_event_size: int | None = None,
|
|
134
133
|
self_registration_client_factory: Callable[[], PlatformClient] | None = None,
|
|
134
|
+
auth_backend: AuthenticationBackend | None = None,
|
|
135
135
|
) -> None:
|
|
136
136
|
if self.server:
|
|
137
137
|
raise RuntimeError("The server is already running")
|
|
@@ -179,6 +179,11 @@ class Server:
|
|
|
179
179
|
self._agent.card.url = f"http://{host}:{port}"
|
|
180
180
|
|
|
181
181
|
if self_registration:
|
|
182
|
+
from agentstack_sdk.a2a.extensions.services.platform import (
|
|
183
|
+
_PlatformSelfRegistrationExtensionParams,
|
|
184
|
+
_PlatformSelfRegistrationExtensionSpec,
|
|
185
|
+
)
|
|
186
|
+
|
|
182
187
|
self._agent.card.capabilities.extensions = [
|
|
183
188
|
*(self._agent.card.capabilities.extensions or []),
|
|
184
189
|
*_PlatformSelfRegistrationExtensionSpec(
|
|
@@ -198,6 +203,13 @@ class Server:
|
|
|
198
203
|
request_context_builder=request_context_builder,
|
|
199
204
|
)
|
|
200
205
|
|
|
206
|
+
if auth_backend:
|
|
207
|
+
|
|
208
|
+
def on_error(connection: HTTPConnection, error: AuthenticationError) -> PlainTextResponse:
|
|
209
|
+
return PlainTextResponse("Unauthorized", status_code=401)
|
|
210
|
+
|
|
211
|
+
app.add_middleware(AuthenticationMiddleware, backend=auth_backend, on_error=on_error)
|
|
212
|
+
|
|
201
213
|
if configure_logger:
|
|
202
214
|
configure_logger_func(log_level)
|
|
203
215
|
|
|
@@ -286,6 +298,8 @@ class Server:
|
|
|
286
298
|
await self._load_variables()
|
|
287
299
|
|
|
288
300
|
async def _load_variables(self, first_run: bool = False) -> None:
|
|
301
|
+
from agentstack_sdk.a2a.extensions import AgentDetail, AgentDetailExtensionSpec
|
|
302
|
+
|
|
289
303
|
assert self.server and self._agent
|
|
290
304
|
if not self._provider_id:
|
|
291
305
|
return
|
agentstack_sdk/types.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING, TypeAlias
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
|
|
8
|
+
JsonDict: TypeAlias = dict[str, JsonValue]
|
|
9
|
+
else:
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
from typing_extensions import TypeAliasType
|
|
13
|
+
|
|
14
|
+
JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
|
|
15
|
+
JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: agentstack-sdk
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.1rc2
|
|
4
4
|
Summary: Agent Stack SDK
|
|
5
5
|
Author: IBM Corp.
|
|
6
6
|
Requires-Dist: a2a-sdk==0.3.21
|
|
@@ -19,6 +19,8 @@ Requires-Dist: janus>=2.0.0
|
|
|
19
19
|
Requires-Dist: httpx
|
|
20
20
|
Requires-Dist: mcp>=1.12.3
|
|
21
21
|
Requires-Dist: fastapi>=0.116.1
|
|
22
|
+
Requires-Dist: authlib>=1.3.0
|
|
23
|
+
Requires-Dist: async-lru>=2.0.4
|
|
22
24
|
Requires-Python: >=3.11, <3.14
|
|
23
25
|
Description-Content-Type: text/markdown
|
|
24
26
|
|