agentstack-sdk 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/__init__.py +6 -0
- agentstack_sdk/a2a/__init__.py +2 -0
- agentstack_sdk/a2a/extensions/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
- agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
- agentstack_sdk/a2a/extensions/base.py +205 -0
- agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/common/form.py +149 -0
- agentstack_sdk/a2a/extensions/exceptions.py +11 -0
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
- agentstack_sdk/a2a/extensions/services/form.py +54 -0
- agentstack_sdk/a2a/extensions/services/llm.py +100 -0
- agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
- agentstack_sdk/a2a/extensions/services/platform.py +141 -0
- agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/tools/call.py +114 -0
- agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
- agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
- agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
- agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
- agentstack_sdk/a2a/extensions/ui/error.py +223 -0
- agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
- agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
- agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
- agentstack_sdk/a2a/types.py +104 -0
- agentstack_sdk/platform/__init__.py +12 -0
- agentstack_sdk/platform/client.py +123 -0
- agentstack_sdk/platform/common.py +37 -0
- agentstack_sdk/platform/configuration.py +47 -0
- agentstack_sdk/platform/context.py +291 -0
- agentstack_sdk/platform/file.py +295 -0
- agentstack_sdk/platform/model_provider.py +131 -0
- agentstack_sdk/platform/provider.py +219 -0
- agentstack_sdk/platform/provider_build.py +190 -0
- agentstack_sdk/platform/types.py +45 -0
- agentstack_sdk/platform/user.py +70 -0
- agentstack_sdk/platform/user_feedback.py +42 -0
- agentstack_sdk/platform/variables.py +44 -0
- agentstack_sdk/platform/vector_store.py +217 -0
- agentstack_sdk/py.typed +0 -0
- agentstack_sdk/server/__init__.py +4 -0
- agentstack_sdk/server/agent.py +594 -0
- agentstack_sdk/server/app.py +87 -0
- agentstack_sdk/server/constants.py +9 -0
- agentstack_sdk/server/context.py +68 -0
- agentstack_sdk/server/dependencies.py +117 -0
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +376 -0
- agentstack_sdk/server/store/__init__.py +3 -0
- agentstack_sdk/server/store/context_store.py +35 -0
- agentstack_sdk/server/store/memory_context_store.py +59 -0
- agentstack_sdk/server/store/platform_context_store.py +58 -0
- agentstack_sdk/server/telemetry.py +53 -0
- agentstack_sdk/server/utils.py +26 -0
- agentstack_sdk/types.py +15 -0
- agentstack_sdk/util/__init__.py +4 -0
- agentstack_sdk/util/file.py +260 -0
- agentstack_sdk/util/httpx.py +18 -0
- agentstack_sdk/util/logging.py +63 -0
- agentstack_sdk/util/resource_context.py +44 -0
- agentstack_sdk/util/utils.py +47 -0
- agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
- agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
- agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Literal, overload
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
import janus
|
|
9
|
+
from a2a.types import Artifact, Message, MessageSendConfiguration, Task
|
|
10
|
+
from pydantic import BaseModel, PrivateAttr
|
|
11
|
+
|
|
12
|
+
from agentstack_sdk.a2a.types import RunYield, RunYieldResume
|
|
13
|
+
from agentstack_sdk.platform.context import ContextHistoryItem
|
|
14
|
+
from agentstack_sdk.server.store.context_store import ContextStoreInstance
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RunContext(BaseModel, arbitrary_types_allowed=True):
    """State and services available to a single agent run.

    Carries the A2A identifiers and tasks of the run and wraps an optional context
    store used for history operations. It also exposes a yield/resume channel built
    from two ``janus`` queues, so that both synchronous and asynchronous code can hand
    a value out of the run and wait for a resume value to come back.
    """

    configuration: MessageSendConfiguration | None = None
    task_id: str
    context_id: str
    current_task: Task | None = None
    related_tasks: list[Task] | None = None

    # Context store backing store()/load_history()/delete_history_from_id(). Defaults to
    # None and is never assigned in this class; presumably attached by the server -- TODO confirm.
    _store: ContextStoreInstance | None = PrivateAttr(None)
    # Values handed out of the run by yield_sync()/yield_async().
    _yield_queue: janus.Queue[RunYield] = PrivateAttr(default_factory=janus.Queue)
    # Resume values returned to the pending yield_sync()/yield_async() call.
    _yield_resume_queue: janus.Queue[RunYieldResume] = PrivateAttr(default_factory=janus.Queue)

    async def store(self, data: Message | Artifact):
        """Persist a message or artifact to the context store.

        Messages are deep-copied with ``context_id`` and ``task_id`` overwritten by this
        run's identifiers before being stored; artifacts are stored as given.

        Raises:
            RuntimeError: if no context store is attached.
        """
        if self._store is None:
            raise RuntimeError("Context store is not initialized")
        if isinstance(data, Message):
            data = data.model_copy(deep=True, update={"context_id": self.context_id, "task_id": self.task_id})
        await self._store.store(data)

    @overload
    async def load_history(
        self, load_history_items: Literal[False] = False
    ) -> AsyncGenerator[Message | Artifact, None]:
        yield ...  # type: ignore

    @overload
    async def load_history(self, load_history_items: Literal[True]) -> AsyncGenerator[ContextHistoryItem, None]:
        yield ...  # type: ignore

    async def load_history(
        self, load_history_items: bool = False
    ) -> AsyncGenerator[ContextHistoryItem | Message | Artifact, None]:
        """Stream the stored history of this context.

        Args:
            load_history_items: When True, yield ``ContextHistoryItem`` records; otherwise
                yield the stored messages/artifacts. Forwarded unchanged to the store.

        Raises:
            RuntimeError: if no context store is attached. Because this is an async
                generator, the error surfaces on first iteration, not at call time.
        """
        if self._store is None:
            raise RuntimeError("Context store is not initialized")
        async for item in self._store.load_history(load_history_items=load_history_items):
            yield item

    async def delete_history_from_id(self, from_id: UUID) -> None:
        """Forward to the store's ``delete_history_from_id``.

        Presumably removes history starting at the item ``from_id`` (inclusivity is
        decided by the store -- confirm against the store implementation).

        Raises:
            RuntimeError: if no context store is attached.
        """
        if self._store is None:
            raise RuntimeError("Context store is not initialized")
        await self._store.delete_history_from_id(from_id)

    def yield_sync(self, value: RunYield) -> RunYieldResume:
        """Hand ``value`` out of the run and block until a resume value arrives.

        Uses the synchronous side of the janus queues, so it blocks the calling
        thread; intended for synchronous code, not the event-loop thread.
        """
        self._yield_queue.sync_q.put(value)
        return self._yield_resume_queue.sync_q.get()

    async def yield_async(self, value: RunYield) -> RunYieldResume:
        """Hand ``value`` out of the run and await the matching resume value."""
        await self._yield_queue.async_q.put(value)
        return await self._yield_resume_queue.async_q.get()

    def shutdown(self) -> None:
        """Shut down both the yield queue and the resume queue."""
        self._yield_queue.shutdown()
        self._yield_resume_queue.shutdown()
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import inspect
|
|
7
|
+
from collections import Counter
|
|
8
|
+
from collections.abc import AsyncIterator, Callable
|
|
9
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
10
|
+
from inspect import isclass
|
|
11
|
+
from typing import Annotated, Any, TypeAlias, Unpack, get_args, get_origin
|
|
12
|
+
|
|
13
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
14
|
+
from a2a.types import Message
|
|
15
|
+
from typing_extensions import Doc
|
|
16
|
+
|
|
17
|
+
from agentstack_sdk.a2a.extensions import BaseExtensionSpec
|
|
18
|
+
from agentstack_sdk.a2a.extensions.base import BaseExtensionServer
|
|
19
|
+
from agentstack_sdk.server.context import RunContext
|
|
20
|
+
|
|
21
|
+
# Anything ``Depends`` can wrap: either a callable invoked as
# ``(message, run_context, request_context, dependencies)`` returning the resolved value,
# or an extension server instance. "Dependency" is a string self-reference so the alias
# can describe the already-resolved dependencies mapping passed to the callable.
Dependency: TypeAlias = (
    Callable[[Message, RunContext, RequestContext, dict[str, "Dependency"]], Any] | BaseExtensionServer[Any, Any]
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Inspired by fastapi.Depends
|
|
27
|
+
class Depends:
    """Marker that tells the SDK how to resolve a parameter of an agent function.

    Modeled after ``fastapi.Depends``. The wrapped dependency is invoked as
    ``dependency(message, run_context, request_context, dependencies)``. Calling the
    ``Depends`` instance returns an async context manager that yields the resolved
    value and, when applicable, keeps that value's ``lifespan()`` context open for as
    long as the manager is entered.
    """

    # Only assigned on the instance when the wrapped dependency is an extension server;
    # otherwise this class-level None is what instances see.
    extension: BaseExtensionServer[Any, Any] | None = None

    def __init__(
        self,
        dependency: Annotated[
            Dependency,
            Doc(
                """
                A "dependable" callable (like a function).
                Don't call it directly, Agent Stack SDK will call it for you, just pass the object directly.
                """
            ),
        ],
    ):
        if isinstance(dependency, BaseExtensionServer):
            self.extension = dependency
        self._dependency_callable: Dependency = dependency

    def __call__(
        self, message: Message, context: RunContext, request_context: RequestContext, dependencies: dict[str, Any]
    ) -> AbstractAsyncContextManager[Dependency]:
        """Resolve the dependency and wrap the result in an async context manager.

        The wrapped callable runs immediately. Whether to enter ``lifespan()`` is decided
        only when the returned manager is entered: it is entered when this wraps an
        extension, or when the resolved value has a ``lifespan`` attribute.
        """
        resolved = self._dependency_callable(message, context, request_context, dependencies)

        @asynccontextmanager
        async def lifespan() -> AsyncIterator[Dependency]:
            wants_lifespan = bool(self.extension) or hasattr(resolved, "lifespan")
            if not wants_lifespan:
                yield resolved
                return
            async with resolved.lifespan():
                yield resolved

        return lifespan()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def extract_dependencies(sign: inspect.Signature) -> dict[str, Depends]:
|
|
63
|
+
dependencies = {}
|
|
64
|
+
seen_keys = set()
|
|
65
|
+
|
|
66
|
+
def process_args(name: str, args: tuple[Any, ...]) -> None:
|
|
67
|
+
if len(args) > 1:
|
|
68
|
+
dep_type, spec, *rest = args
|
|
69
|
+
# extension_param: Annotated[some_type, Depends(some_callable)]
|
|
70
|
+
if isinstance(spec, Depends):
|
|
71
|
+
dependencies[name] = spec
|
|
72
|
+
# extension_param: Annotated[BaseExtensionServer, BaseExtensionSpec()]
|
|
73
|
+
elif (
|
|
74
|
+
isclass(dep_type) and issubclass(dep_type, BaseExtensionServer) and isinstance(spec, BaseExtensionSpec)
|
|
75
|
+
):
|
|
76
|
+
dependencies[name] = Depends(dep_type(spec, *rest))
|
|
77
|
+
|
|
78
|
+
for name, param in sign.parameters.items():
|
|
79
|
+
seen_keys.add(name)
|
|
80
|
+
|
|
81
|
+
if get_origin(param.annotation) is Annotated:
|
|
82
|
+
args = get_args(param.annotation)
|
|
83
|
+
process_args(name, args)
|
|
84
|
+
|
|
85
|
+
elif inspect.isclass(param.annotation):
|
|
86
|
+
# message: Message
|
|
87
|
+
if param.annotation == Message:
|
|
88
|
+
dependencies[name] = Depends(lambda message, _run_context, _request_context, _dependencies: message)
|
|
89
|
+
# context: Context
|
|
90
|
+
elif param.annotation == RunContext:
|
|
91
|
+
dependencies[name] = Depends(lambda _message, run_context, _request_context, _dependencies: run_context)
|
|
92
|
+
# extension: BaseExtensionServer = BaseExtensionSpec()
|
|
93
|
+
# TODO: this does not get past linters, should we enable it or somehow fix the typing?
|
|
94
|
+
# elif issubclass(param.annotation, BaseExtensionServer) and isinstance(param.default, BaseExtensionSpec):
|
|
95
|
+
# dependencies[name] = Depends(param.annotation(param.default))
|
|
96
|
+
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
|
97
|
+
origin = get_origin(param.annotation)
|
|
98
|
+
if origin is Unpack:
|
|
99
|
+
seen_keys.discard(name)
|
|
100
|
+
(typed_dict,) = get_args(param.annotation)
|
|
101
|
+
for field_name, field_type in typed_dict.__annotations__.items():
|
|
102
|
+
seen_keys.add(field_name)
|
|
103
|
+
if get_origin(field_type) is Annotated:
|
|
104
|
+
args = get_args(field_type)
|
|
105
|
+
process_args(field_name, args)
|
|
106
|
+
|
|
107
|
+
missing_keys = seen_keys.difference(dependencies.keys())
|
|
108
|
+
if missing_keys:
|
|
109
|
+
raise TypeError(f"The agent function contains extra parameters with unknown type annotation: {missing_keys}")
|
|
110
|
+
if reserved_names := {param for param in dependencies if param.startswith("__")}:
|
|
111
|
+
raise TypeError(f"User-defined dependencies cannot start with double underscore: {reserved_names}")
|
|
112
|
+
|
|
113
|
+
extension_deps = Counter(dep.extension.spec.URI for dep in dependencies.values() if dep.extension)
|
|
114
|
+
if duplicate_uris := {k for k, v in extension_deps.items() if v > 1}:
|
|
115
|
+
raise TypeError(f"Duplicate extension URIs found in the agent function: {duplicate_uris}")
|
|
116
|
+
|
|
117
|
+
return dependencies
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from urllib.parse import urljoin
|
|
8
|
+
|
|
9
|
+
from a2a.auth.user import User
|
|
10
|
+
from async_lru import alru_cache
|
|
11
|
+
from authlib.jose import JsonWebKey, JWTClaims, KeySet, jwt
|
|
12
|
+
from authlib.jose.errors import JoseError
|
|
13
|
+
from fastapi import Request
|
|
14
|
+
from fastapi.security import HTTPBearer
|
|
15
|
+
from pydantic import Secret
|
|
16
|
+
from starlette.authentication import (
|
|
17
|
+
AuthCredentials,
|
|
18
|
+
AuthenticationBackend,
|
|
19
|
+
AuthenticationError,
|
|
20
|
+
BaseUser,
|
|
21
|
+
)
|
|
22
|
+
from starlette.requests import HTTPConnection
|
|
23
|
+
from typing_extensions import override
|
|
24
|
+
|
|
25
|
+
from agentstack_sdk.platform import use_platform_client
|
|
26
|
+
from agentstack_sdk.types import JsonValue
|
|
27
|
+
|
|
28
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PlatformAuthenticatedUser(User, BaseUser):
    """User built from the verified claims of a platform-issued JWT.

    Implements both the A2A ``User`` interface and Starlette's ``BaseUser`` so the same
    object can be used wherever either interface is expected.
    """

    def __init__(self, claims: dict[str, JsonValue], auth_token: str):
        # Decoded JWT claims.
        self.claims: dict[str, JsonValue] = claims
        # Raw bearer token, wrapped in pydantic's Secret so it is masked in repr/str output.
        self.auth_token: Secret[str] = Secret(auth_token)

    @property
    @override
    def is_authenticated(self) -> bool:
        return True

    @property
    @override
    def user_name(self) -> str:
        """Return the ``sub`` claim.

        Raises:
            ValueError: if ``sub`` is missing, empty, or not a string. This is an explicit
                check rather than ``assert``, so it still runs under ``python -O``.
        """
        sub = self.claims.get("sub")
        if not isinstance(sub, str) or not sub:
            raise ValueError("JWT 'sub' claim is missing or is not a non-empty string")
        return sub

    @property
    @override
    def display_name(self) -> str:
        """Return the ``name`` claim, falling back to ``user_name``.

        ``name`` is optional in platform tokens (only ``sub``/``exp`` are required), so
        a missing or non-string ``name`` falls back to the subject.
        """
        name = self.claims.get("name")
        if isinstance(name, str) and name:
            return name
        return self.user_name

    @property
    @override
    def identity(self) -> str:
        return self.user_name
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@alru_cache(ttl=timedelta(minutes=15).total_seconds())
async def discover_jwks() -> KeySet:
    """Fetch the platform's JSON Web Key Set used to verify bearer tokens.

    Results are cached by ``alru_cache`` with a 15-minute TTL.

    Raises:
        RuntimeError: wrapping any error raised while fetching or parsing the key set.
            The message names the JWKS URL.
    """
    path = "/.well-known/jwks"
    # Remembered only so the failure message can show the full URL; stays None if the
    # client could not be created.
    base_url: object = None
    try:
        async with use_platform_client() as client:
            base_url = getattr(client, "base_url", None)
            response = await client.get(path)
            return JsonWebKey.import_key_set(response.raise_for_status().json())  # pyright: ignore[reportAny]
    except Exception as e:
        url = f"{str(base_url).rstrip('/')}{path}" if base_url is not None else path
        logger.warning("JWKS discovery failed for url %s: %s", url, e)
        raise RuntimeError(f"JWKS discovery failed for url {url}") from e
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PlatformAuthBackend(AuthenticationBackend):
    """Starlette authentication backend that accepts platform-issued JWT bearer tokens.

    Tokens are verified against the key set returned by ``discover_jwks()``. The ``sub``
    and ``exp`` claims are required, and the ``aud`` claim is checked unless audience
    validation is disabled.
    """

    def __init__(self, public_url: str | None = None, skip_audience_validation: bool | None = None) -> None:
        """Configure audience validation.

        Args:
            public_url: Public base URL of this agent, used to build the expected
                audiences. Falls back to the ``PLATFORM_AUTH__PUBLIC_URL`` env variable.
            skip_audience_validation: Disable the ``aud`` check. When ``None``, it is read
                from ``PLATFORM_AUTH__SKIP_AUDIENCE_VALIDATION`` ("true"/"1", case-insensitive).
        """
        self.skip_audience_validation: bool = (
            skip_audience_validation
            if skip_audience_validation is not None
            else os.getenv("PLATFORM_AUTH__SKIP_AUDIENCE_VALIDATION", "false").lower() in ("true", "1")
        )
        self._audience: str | None = public_url or os.getenv("PLATFORM_AUTH__PUBLIC_URL", None)
        if not self.skip_audience_validation and not self._audience:
            logger.warning(
                "Public URL is not provided and audience validation is enabled. Proceeding to check audience from the request target URL. "
                + "This may not work when requests to agents are proxied. (hint: set PLATFORM_AUTH__PUBLIC_URL env variable)"
            )

        # auto_error=False: a missing header yields None instead of raising, so
        # authenticate() raises its own AuthenticationError.
        self.security: HTTPBearer = HTTPBearer(auto_error=False)

    @override
    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        """Authenticate a connection from its ``Authorization: Bearer`` header.

        Returns:
            ``None`` for ``/healthcheck`` and the agent card (left unauthenticated);
            otherwise the credentials and the authenticated user.

        Raises:
            AuthenticationError: if the header is missing, the token is invalid, or
                verification fails for any other reason.
        """
        # We construct a Request object from the scope for compatibility with HTTPBearer and logging
        request = Request(scope=conn.scope)

        if request.url.path in ("/healthcheck", "/.well-known/agent-card.json"):
            return None

        if not (auth := await self.security(request)):
            raise AuthenticationError("Missing Authorization header")

        audiences: list[str] = []
        if not self.skip_audience_validation:
            if self._audience:
                # check only hostname urljoin("http://host:port/a/b", "/") -> "http://host:port/"
                audiences = [urljoin(self._audience, path) for path in ("/", "/jsonrpc")]
            else:
                # URL.replace(path=...) keeps the request's query string and fragment, which
                # would produce audiences like "http://host/?x=1" that never match; clear them.
                audiences = [str(request.url.replace(path=path, query="", fragment="")) for path in ("/", "/jsonrpc")]

        claims_options: dict[str, dict[str, bool | list[str]]] = {
            "sub": {"essential": True},
            "exp": {"essential": True},
            # "iss": {"essential": True}, # Issuer validation might be tricky if internal/external URLs differ
        }
        if not self.skip_audience_validation:
            claims_options["aud"] = {"essential": True, "values": audiences}

        try:
            jwks = await discover_jwks()

            # Verify signature
            claims: JWTClaims = jwt.decode(auth.credentials, jwks, claims_options=claims_options)
            # Enforce the claim options above (required claims, expiry, audience).
            claims.validate()

            return AuthCredentials(["authenticated"]), PlatformAuthenticatedUser(claims, auth.credentials)

        except (ValueError, JoseError) as e:
            logger.warning("Authentication failed: %s", e)
            raise AuthenticationError("Invalid token") from e
        except Exception as e:
            # Unexpected failure (e.g. JWKS discovery): keep the traceback in the log.
            logger.exception("Authentication error: %s", e)
            raise AuthenticationError(f"Authentication failed: {e}") from e
|