agentstack-sdk 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/__init__.py +6 -0
- agentstack_sdk/a2a/__init__.py +2 -0
- agentstack_sdk/a2a/extensions/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
- agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
- agentstack_sdk/a2a/extensions/base.py +205 -0
- agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/common/form.py +149 -0
- agentstack_sdk/a2a/extensions/exceptions.py +11 -0
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
- agentstack_sdk/a2a/extensions/services/form.py +54 -0
- agentstack_sdk/a2a/extensions/services/llm.py +100 -0
- agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
- agentstack_sdk/a2a/extensions/services/platform.py +141 -0
- agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/tools/call.py +114 -0
- agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
- agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
- agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
- agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
- agentstack_sdk/a2a/extensions/ui/error.py +223 -0
- agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
- agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
- agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
- agentstack_sdk/a2a/types.py +104 -0
- agentstack_sdk/platform/__init__.py +12 -0
- agentstack_sdk/platform/client.py +123 -0
- agentstack_sdk/platform/common.py +37 -0
- agentstack_sdk/platform/configuration.py +47 -0
- agentstack_sdk/platform/context.py +291 -0
- agentstack_sdk/platform/file.py +295 -0
- agentstack_sdk/platform/model_provider.py +131 -0
- agentstack_sdk/platform/provider.py +219 -0
- agentstack_sdk/platform/provider_build.py +190 -0
- agentstack_sdk/platform/types.py +45 -0
- agentstack_sdk/platform/user.py +70 -0
- agentstack_sdk/platform/user_feedback.py +42 -0
- agentstack_sdk/platform/variables.py +44 -0
- agentstack_sdk/platform/vector_store.py +217 -0
- agentstack_sdk/py.typed +0 -0
- agentstack_sdk/server/__init__.py +4 -0
- agentstack_sdk/server/agent.py +594 -0
- agentstack_sdk/server/app.py +87 -0
- agentstack_sdk/server/constants.py +9 -0
- agentstack_sdk/server/context.py +68 -0
- agentstack_sdk/server/dependencies.py +117 -0
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +376 -0
- agentstack_sdk/server/store/__init__.py +3 -0
- agentstack_sdk/server/store/context_store.py +35 -0
- agentstack_sdk/server/store/memory_context_store.py +59 -0
- agentstack_sdk/server/store/platform_context_store.py +58 -0
- agentstack_sdk/server/telemetry.py +53 -0
- agentstack_sdk/server/utils.py +26 -0
- agentstack_sdk/types.py +15 -0
- agentstack_sdk/util/__init__.py +4 -0
- agentstack_sdk/util/file.py +260 -0
- agentstack_sdk/util/httpx.py +18 -0
- agentstack_sdk/util/logging.py +63 -0
- agentstack_sdk/util/resource_context.py +44 -0
- agentstack_sdk/util/utils.py +47 -0
- agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
- agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
- agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import uuid
|
|
7
|
+
from types import NoneType
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
9
|
+
from urllib.parse import parse_qs
|
|
10
|
+
|
|
11
|
+
import pydantic
|
|
12
|
+
from a2a.server.agent_execution import RequestContext
|
|
13
|
+
from a2a.types import Message as A2AMessage
|
|
14
|
+
from a2a.types import Role, TextPart
|
|
15
|
+
from mcp.client.auth import OAuthClientProvider
|
|
16
|
+
from mcp.shared.auth import OAuthClientMetadata
|
|
17
|
+
from typing_extensions import override
|
|
18
|
+
|
|
19
|
+
from agentstack_sdk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory
|
|
20
|
+
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
21
|
+
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired, RunYieldResume
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from agentstack_sdk.server.context import RunContext
|
|
25
|
+
|
|
26
|
+
# Key of the fallback demand/fulfillment, used when no entry exists for a specific resource URL.
_DEFAULT_DEMAND_NAME = "default"


# Agent -> client payload: the authorization endpoint URL the user has to visit.
class AuthRequest(pydantic.BaseModel):
    authorization_endpoint_url: pydantic.AnyUrl


# Client -> agent payload: the redirect URI reached after authorization; its query string is parsed for
# ``code`` and ``state``.
class AuthResponse(pydantic.BaseModel):
    redirect_uri: pydantic.AnyUrl


# Client-provided fulfillment of one OAuth demand: the redirect URI used for the OAuth client's metadata.
class OAuthFulfillment(pydantic.BaseModel):
    redirect_uri: pydantic.AnyUrl


# One OAuth requirement declared by the agent; ``redirect_uri`` presumably flags that the client must supply a
# redirect URI for it (TODO confirm).
class OAuthDemand(pydantic.BaseModel):
    redirect_uri: bool = True


# Extension params advertised on the agent card: OAuth demands keyed by demand name.
class OAuthExtensionParams(pydantic.BaseModel):
    oauth_demands: dict[str, OAuthDemand]
    """Server requests that the agent requires to be provided by the client."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class OAuthExtensionSpec(BaseExtensionSpec[OAuthExtensionParams]):
    """Agent-card spec of the OAuth extension; its params list the OAuth demands by name."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/auth/oauth/v1"

    @classmethod
    def single_demand(cls, name: str = _DEFAULT_DEMAND_NAME) -> Self:
        """Build a spec declaring exactly one OAuth demand, registered under ``name``."""
        demands = {name: OAuthDemand()}
        return cls(params=OAuthExtensionParams(oauth_demands=demands))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Metadata the client attaches to messages: OAuth fulfillments keyed either by resource URL or by the default
# demand name (both keys are looked up by the server when resolving a resource).
class OAuthExtensionMetadata(pydantic.BaseModel):
    oauth_fulfillments: dict[str, OAuthFulfillment] = {}
    """Provided servers corresponding to the server requests."""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensionMetadata]):
    """Agent-side handler for the OAuth extension.

    Reads the client's OAuth fulfillments from incoming message metadata and builds an ``OAuthClientProvider``
    whose authorization redirect and callback are relayed to the A2A client through ``AuthRequired`` yields.
    """

    context: RunContext
    token_storage_factory: TokenStorageFactory

    def __init__(self, spec: OAuthExtensionSpec, token_storage_factory: TokenStorageFactory | None = None) -> None:
        """
        Args:
            spec: The OAuth extension spec advertised on the agent card.
            token_storage_factory: Factory producing the token storage of each created OAuth provider; an
                in-memory factory is used when omitted.
        """
        # Pass the factory through to the base class so it is recorded in ``self._kwargs``. The per-request
        # instance is rebuilt by ``BaseExtensionServer._fork`` as ``type(self)(self.spec, *self._args,
        # **self._kwargs)``; without this, a custom factory was silently replaced by the in-memory default on
        # every request.
        super().__init__(spec, token_storage_factory=token_storage_factory)
        self.token_storage_factory = token_storage_factory or MemoryTokenStorageFactory()

    @override
    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        super().handle_incoming_message(message, run_context, request_context)
        # Keep the run context: auth redirects are later yielded back to the client through it.
        self.context = run_context

    def _get_fulfillment_for_resource(self, resource_url: pydantic.AnyUrl):
        """Return the fulfillment registered for ``resource_url``, falling back to the default demand.

        Raises:
            RuntimeError: If the client sent no OAuth metadata, or no fulfillment matches.
        """
        if not self.data:
            raise RuntimeError("No fulfillments found")

        fulfillments = self.data.oauth_fulfillments
        fulfillment = fulfillments.get(str(resource_url)) or fulfillments.get(_DEFAULT_DEMAND_NAME)
        if fulfillment:
            return fulfillment

        raise RuntimeError("Fulfillment not found")

    async def create_httpx_auth(self, *, resource_url: pydantic.AnyUrl):
        """Create an ``OAuthClientProvider`` (httpx auth flow) for ``resource_url``.

        The authorization redirect is forwarded to the A2A client as an ``AuthRequired`` yield. The message the
        client resumes with must carry an ``AuthResponse`` whose redirect URI holds the authorization ``code``
        (and optionally ``state``) in its query string.

        Raises:
            RuntimeError: If no fulfillment is available for ``resource_url``.
        """
        fulfillment = self._get_fulfillment_for_resource(resource_url=resource_url)

        # Resume payload from the client: set by handle_redirect, consumed (and cleared) by handle_callback.
        resume: RunYieldResume = None

        async def handle_redirect(auth_url: str) -> None:
            nonlocal resume
            if resume:
                raise RuntimeError("Another redirect is already pending")
            message = self.create_auth_request(authorization_endpoint_url=pydantic.AnyUrl(auth_url))
            resume = await self.context.yield_async(AuthRequired(message=message))

        async def handle_callback() -> tuple[str, str | None]:
            nonlocal resume
            try:
                if not resume:
                    raise ValueError("Missing resume data")
                response = self.parse_auth_response(message=resume)
                params = parse_qs(response.redirect_uri.query or "")
                codes = params.get("code")
                if not codes:
                    # OAuth error redirects (RFC 6749 §4.1.2.1) carry ``error`` instead of ``code``; report it
                    # instead of failing with a bare KeyError.
                    error = params.get("error", ["missing authorization code"])[0]
                    description = params.get("error_description", [None])[0]
                    detail = f" ({description})" if description else ""
                    raise ValueError(f"Authorization failed: {error}{detail}")
                return codes[0], params.get("state", [None])[0]
            finally:
                # Always clear, so a later redirect is not mistaken for a pending one.
                resume = None

        # A2A Client is responsible for catching the redirect and forwarding it over the A2A connection
        oauth_auth = OAuthClientProvider(
            server_url=str(resource_url),
            client_metadata=OAuthClientMetadata(
                redirect_uris=[fulfillment.redirect_uri],
            ),
            storage=await self.token_storage_factory.create_storage(),
            redirect_handler=handle_redirect,
            callback_handler=handle_callback,
        )
        return oauth_auth

    def create_auth_request(self, *, authorization_endpoint_url: pydantic.AnyUrl):
        """Build the agent message asking the client to authorize at ``authorization_endpoint_url``."""
        data = AuthRequest(authorization_endpoint_url=authorization_endpoint_url)
        return AgentMessage(text="Authorization required", metadata={self.spec.URI: data.model_dump(mode="json")})

    def parse_auth_response(self, *, message: A2AMessage):
        """Decode the ``AuthResponse`` from ``message``'s metadata.

        Raises:
            RuntimeError: If the message carries no (or empty) metadata for this extension.
        """
        if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
            raise RuntimeError("Invalid auth response")
        return AuthResponse.model_validate(data)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class OAuthExtensionClient(BaseExtensionClient[OAuthExtensionSpec, NoneType]):
    """Client-side handler for the OAuth extension.

    Supplies redirect-URI fulfillments to the agent, decodes the agent's ``AuthRequest`` and answers it with an
    ``AuthResponse`` carrying the redirect URI reached after authorization.
    """

    def fulfillment_metadata(self, *, oauth_fulfillments: dict[str, Any]) -> dict[str, Any]:
        """Build message metadata announcing ``oauth_fulfillments`` under this extension's URI."""
        payload = OAuthExtensionMetadata(oauth_fulfillments=oauth_fulfillments)
        return {self.spec.URI: payload.model_dump(mode="json")}

    def parse_auth_request(self, *, message: A2AMessage):
        """Decode the ``AuthRequest`` stored in ``message``'s metadata.

        Raises:
            ValueError: If the message has no (or empty) metadata for this extension.
        """
        metadata = message.metadata if message else None
        data = metadata.get(self.spec.URI) if metadata else None
        if not data:
            raise ValueError("Invalid auth request")
        return AuthRequest.model_validate(data)

    def create_auth_response(self, *, task_id: str, redirect_uri: pydantic.AnyUrl):
        """Build the user message that resumes task ``task_id`` with the post-authorization ``redirect_uri``."""
        payload = AuthResponse(redirect_uri=redirect_uri).model_dump(mode="json")
        return A2AMessage(
            message_id=str(uuid.uuid4()),
            role=Role.user,
            parts=[TextPart(text="Authorization completed")],  # type: ignore
            task_id=task_id,
            metadata={self.spec.URI: payload},
        )
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import abc
|
|
5
|
+
|
|
6
|
+
from mcp.client.auth import TokenStorage
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TokenStorageFactory(abc.ABC):
    """Abstract factory that hands out ``TokenStorage`` instances for OAuth client providers."""

    @abc.abstractmethod
    async def create_storage(self) -> TokenStorage:
        """Create and return a token storage; concrete factories must implement this."""
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from mcp.client.auth import TokenStorage
|
|
6
|
+
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
7
|
+
|
|
8
|
+
from .base import TokenStorageFactory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MemoryTokenStorage(TokenStorage):
    """``TokenStorage`` that simply keeps the tokens and client info as instance attributes."""

    def __init__(self):
        # Both start empty; they are filled in through the async setters below.
        self.client_info: OAuthClientInformationFull | None = None
        self.tokens: OAuthToken | None = None

    async def get_tokens(self) -> OAuthToken | None:
        """Return the stored tokens, or ``None`` if none were set."""
        return self.tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Replace the stored tokens."""
        self.tokens = tokens

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client information, or ``None`` if none was set."""
        return self.client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Replace the stored client information."""
        self.client_info = client_info
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MemoryTokenStorageFactory(TokenStorageFactory):
    """Factory of ``MemoryTokenStorage`` objects, optionally pre-seeded with fixed client information."""

    def __init__(self, *, client_info: OAuthClientInformationFull | None = None):
        super().__init__()
        self._client_info = client_info

    async def create_storage(self) -> TokenStorage:
        """Return a new, empty in-memory storage, seeded with the configured client info when one was given."""
        seed = self._client_info
        storage = MemoryTokenStorage()
        if seed:
            await storage.set_client_info(seed)
        return storage
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING, Self
|
|
7
|
+
|
|
8
|
+
import pydantic
|
|
9
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
10
|
+
from a2a.types import Message as A2AMessage
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
|
|
13
|
+
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
14
|
+
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from agentstack_sdk.server.context import RunContext
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# A secret the agent asks the client to supply: a human-readable name plus an optional description.
class SecretDemand(pydantic.BaseModel):
    name: str
    description: str | None = None


# The client's answer to one secret demand: the secret value itself.
class SecretFulfillment(pydantic.BaseModel):
    secret: str


# Params advertised on the agent card and sent in secret requests: demands keyed by demand key.
class SecretsServiceExtensionParams(pydantic.BaseModel):
    secret_demands: dict[str, SecretDemand]


# Metadata the client sends back; fulfillments are presumably keyed like the matching demands (TODO confirm).
class SecretsServiceExtensionMetadata(pydantic.BaseModel):
    # pydantic copies mutable defaults per instance, so sharing this ``{}`` literal is safe.
    secret_fulfillments: dict[str, SecretFulfillment] = {}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | None]):
    """Agent-card spec of the secrets extension; its params may be ``None`` when nothing is demanded up front."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/auth/secrets/v1"

    @classmethod
    def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> Self:
        """Build a spec declaring exactly one secret demand.

        Args:
            name: Human-readable name of the secret.
            key: Key the demand is registered under; ``"default"`` when falsy.
            description: Optional explanation shown for the secret.
        """
        demand_key = key or "default"
        demand = SecretDemand(description=description, name=name)
        return cls(params=SecretsServiceExtensionParams(secret_demands={demand_key: demand}))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SecretsExtensionServer(BaseExtensionServer[SecretsExtensionSpec, SecretsServiceExtensionMetadata]):
    """Agent-side handler for the secrets extension: asks the client for secrets at run time."""

    context: RunContext

    @override
    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        super().handle_incoming_message(message, run_context, request_context)
        # The run context is needed later to yield secret requests to the client.
        self.context = run_context

    def parse_secret_response(self, message: A2AMessage) -> SecretsServiceExtensionMetadata:
        """Decode the secret fulfillments from ``message``'s metadata.

        Raises:
            ValueError: If the message has no (or empty) metadata for this extension.
        """
        payload = message.metadata.get(self.spec.URI) if message and message.metadata else None
        if not payload:
            raise ValueError("Secrets has not been provided in response.")
        return SecretsServiceExtensionMetadata.model_validate(payload)

    async def request_secrets(self, params: SecretsServiceExtensionParams) -> SecretsServiceExtensionMetadata:
        """Yield an ``AuthRequired`` asking for ``params`` and return the client's fulfillments.

        Raises:
            ValueError: If the client resumes with something other than a message carrying the secrets.
        """
        request = AgentMessage(metadata={self.spec.URI: params.model_dump(mode="json")})
        resume = await self.context.yield_async(AuthRequired(message=request))
        if not isinstance(resume, A2AMessage):
            raise ValueError("Secrets has not been provided in response.")
        return self.parse_secret_response(message=resume)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SecretsExtensionClient(BaseExtensionClient[SecretsExtensionSpec, SecretsServiceExtensionParams]):
    """Client-side handler for the secrets extension; all behavior comes from ``BaseExtensionClient``."""
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import abc
|
|
7
|
+
import typing
|
|
8
|
+
from collections.abc import AsyncIterator
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from types import NoneType
|
|
11
|
+
|
|
12
|
+
import pydantic
|
|
13
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
14
|
+
from a2a.types import AgentCard, AgentExtension
|
|
15
|
+
from a2a.types import Message as A2AMessage
|
|
16
|
+
from typing_extensions import override
|
|
17
|
+
|
|
18
|
+
# Type of the params an extension spec advertises on the agent card.
ParamsT = typing.TypeVar("ParamsT")
# Type of the extension metadata a client attaches to messages (parsed by BaseExtensionServer).
MetadataFromClientT = typing.TypeVar("MetadataFromClientT")
# Type of the extension metadata a server attaches to messages (parsed by BaseExtensionClient).
MetadataFromServerT = typing.TypeVar("MetadataFromServerT")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
if typing.TYPE_CHECKING:
|
|
24
|
+
from agentstack_sdk.server.context import RunContext
|
|
25
|
+
from agentstack_sdk.server.dependencies import Dependency
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]:
|
|
29
|
+
for base in getattr(cls, "__orig_bases__", ()):
|
|
30
|
+
if typing.get_origin(base) is base_class and (args := typing.get_args(base)):
|
|
31
|
+
return args
|
|
32
|
+
raise TypeError(f"Missing Params type for {cls.__name__}")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT]):
    """
    Base class for an A2A extension spec.

    Subclasses parametrize it with their params type (``BaseExtensionSpec[MyParams]``); that type is captured into
    ``Params`` when the subclass is created. The default implementation assumes a single URI. Handlers serving
    several URIs (e.g. multiple versions of one spec) may override the relevant methods.
    """

    # URI of the extension spec, or the preferred one if several are supported.
    URI: str

    # Optional description attached to the extension entry on the agent card.
    DESCRIPTION: str | None = None

    # Type of the extension params attached to the agent card; filled in by __init_subclass__.
    Params: type[ParamsT]

    # Params value taken from (client side) or published on (agent side) the agent card.
    params: ParamsT

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        type_args = _get_generic_args(cls, BaseExtensionSpec)
        cls.Params = type_args[0]

    def __init__(self, params: ParamsT) -> None:
        """Agent-side constructor: wrap the params the agent wants to advertise."""
        self.params = params

    @classmethod
    def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
        """Client-side constructor: build the spec from the extension entry matching ``cls.URI`` on the card.

        Returns ``None`` when the agent card does not advertise this extension.
        """
        adapter = pydantic.TypeAdapter(cls.Params)
        declared = agent.capabilities.extensions or []
        entry = next((ext for ext in declared if ext.uri == cls.URI), None)
        if entry is None:
            return None
        return cls(params=adapter.validate_python(entry.params))

    def to_agent_card_extensions(self, *, required: bool = False) -> list[AgentExtension]:
        """Return the extension definitions the agent should advertise on its card.

        A list is returned because one class may support several A2A extensions, usually different versions of
        the same spec.
        """
        serialized_params = pydantic.TypeAdapter(self.Params).dump_python(self.params, mode="json")
        extension = AgentExtension(
            uri=self.URI,
            description=self.DESCRIPTION,
            params=typing.cast(dict[str, typing.Any] | None, serialized_params),
            required=required,
        )
        return [extension]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class NoParamsBaseExtensionSpec(BaseExtensionSpec[NoneType]):
    """Base for extension specs without params: only the presence of the URI on the agent card matters."""

    def __init__(self):
        super().__init__(None)

    @classmethod
    @override
    def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
        """Return an instance when ``agent`` advertises ``cls.URI``, otherwise ``None``."""
        for extension in agent.capabilities.extensions or []:
            if extension.uri == cls.URI:
                return cls()
        return None
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# Concrete extension spec type handled by a server/client; any BaseExtensionSpec regardless of its params type.
ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any])
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromClientT]):
    """Agent-side handler for one A2A extension.

    The instance registered with the agent acts as a template. For each request, ``__call__`` forks a private copy,
    hands it the resolved dependencies and lets it parse the client's metadata from the incoming message.
    """

    # Type of the extension metadata the client attaches to messages; captured by __init_subclass__.
    MetadataFromClient: type[MetadataFromClientT]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        type_args = _get_generic_args(cls, BaseExtensionServer)
        cls.MetadataFromClient = type_args[1]

    _metadata_from_client: MetadataFromClientT | None = None
    _dependencies: dict[str, Dependency] = {}  # noqa: RUF012

    def __init__(self, spec: ExtensionSpecT, *args, **kwargs) -> None:
        # Extra constructor arguments are remembered so that _fork can rebuild an equivalent instance.
        self.spec = spec
        self._args = args
        self._kwargs = kwargs

    @property
    def data(self):
        """Client metadata parsed from the incoming message, or ``None`` when the client sent none."""
        return self._metadata_from_client

    def __bool__(self):
        # Truthiness mirrors ``data``: no metadata (``None``) makes the extension falsy.
        return bool(self.data)

    def parse_client_metadata(self, message: A2AMessage) -> MetadataFromClientT | None:
        """Extract and validate this extension's metadata from ``message``; ``None`` when it is absent."""
        metadata = message.metadata
        if not metadata or self.spec.URI not in metadata:
            return None
        adapter = pydantic.TypeAdapter(self.MetadataFromClient)
        return adapter.validate_python(metadata[self.spec.URI])

    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        """Parse the client metadata from ``message`` unless metadata was already parsed earlier."""
        if self._metadata_from_client is not None:
            return
        self._metadata_from_client = self.parse_client_metadata(message)

    def _fork(self) -> typing.Self:
        """Return a fresh instance built from the same spec and constructor arguments as this one."""
        cls = type(self)
        return cls(self.spec, *self._args, **self._kwargs)

    def __call__(
        self,
        message: A2AMessage,
        run_context: RunContext,
        request_context: RequestContext,
        dependencies: dict[str, Dependency],
    ) -> typing.Self:
        """Dependency constructor: create a per-request instance and feed it the incoming message."""
        forked = self._fork()
        forked._dependencies = dependencies
        forked.handle_incoming_message(message, run_context, request_context)
        return forked

    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Hook entered with the agent context, after the first message was parsed (``__call__`` already ran).

        The default implementation does nothing.
        """
        yield
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class BaseExtensionClient(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromServerT]):
    """Client-side handler for one A2A extension."""

    # Type of the extension metadata the server attaches to messages; captured by __init_subclass__.
    MetadataFromServer: type[MetadataFromServerT]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        type_args = _get_generic_args(cls, BaseExtensionClient)
        cls.MetadataFromServer = type_args[1]

    def __init__(self, spec: ExtensionSpecT) -> None:
        self.spec = spec

    def parse_server_metadata(self, message: A2AMessage) -> MetadataFromServerT | None:
        """Extract and validate this extension's metadata from ``message``; ``None`` when it is absent."""
        metadata = message.metadata
        if not metadata or self.spec.URI not in metadata:
            return None
        return pydantic.TypeAdapter(self.MetadataFromServer).validate_python(metadata[self.spec.URI])
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, model_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Attributes shared by every form field type.
class BaseField(BaseModel):
    id: str
    label: str
    required: bool = False
    # Number of grid columns the field spans; validated to lie within 1..4 when provided.
    col_span: int | None = Field(default=None, ge=1, le=4)


# Free-form text input; ``type`` is the literal tag identifying this field variant.
class TextField(BaseField):
    type: Literal["text"] = "text"
    placeholder: str | None = None
    default_value: str | None = None
    auto_resize: bool | None = True


# Date input; values are plain strings (no date format is enforced here).
class DateField(BaseField):
    type: Literal["date"] = "date"
    placeholder: str | None = None
    default_value: str | None = None


# Reference to a file by URI, with optional display name and MIME type.
class FileItem(BaseModel):
    uri: str
    name: str | None = None
    mime_type: str | None = None


# File upload input; ``accept`` lists accepted file types (presumably MIME types/extensions — confirm with the UI).
class FileField(BaseField):
    type: Literal["file"] = "file"
    accept: list[str]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# One selectable choice; ``id`` is what select defaults refer to, ``label`` is the display text.
class OptionItem(BaseModel):
    id: str
    label: str


# Single-choice select; a non-empty ``default_value`` must equal the id of one of ``options``.
class SingleSelectField(BaseField):
    type: Literal["singleselect"] = "singleselect"
    options: list[OptionItem]
    default_value: str | None = None

    @model_validator(mode="after")
    def default_value_validator(self):
        # Empty or missing defaults mean "nothing preselected" and are not checked.
        if not self.default_value:
            return self
        option_ids = {opt.id for opt in self.options}
        if self.default_value in option_ids:
            return self
        raise ValueError(f"Invalid default_value: {self.default_value}. Must be one of {option_ids}")


# Multi-choice select; every entry of a non-empty ``default_value`` must be the id of one of ``options``.
class MultiSelectField(BaseField):
    type: Literal["multiselect"] = "multiselect"
    options: list[OptionItem]
    default_value: list[str] | None = None

    @model_validator(mode="after")
    def default_values_validator(self):
        # Empty or missing defaults mean "nothing preselected" and are not checked.
        if not self.default_value:
            return self
        option_ids = {opt.id for opt in self.options}
        unknown = [v for v in self.default_value if v not in option_ids]
        if not unknown:
            return self
        raise ValueError(f"Invalid default_value(s): {unknown}. Must be one of {option_ids}")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Single checkbox; ``content`` is presumably the text shown next to the box (TODO confirm with the UI).
class CheckboxField(BaseField):
    type: Literal["checkbox"] = "checkbox"
    content: str
    default_value: bool = False


# Union of all supported form field models; each variant carries a distinct literal ``type`` tag.
FormField = TextField | DateField | FileField | SingleSelectField | MultiSelectField | CheckboxField


# Full form definition: optional title/description, grid column count (validated to 1..4), optional submit
# button label, and the list of fields.
class FormRender(BaseModel):
    title: str | None = None
    description: str | None = None
    columns: int | None = Field(default=None, ge=1, le=4)
    submit_label: str | None = None
    fields: list[FormField]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Submitted value for a field tagged ``text``.
class TextFieldValue(BaseModel):
    type: Literal["text"] = "text"
    value: str | None = None


# Submitted value for a field tagged ``date``; the date is kept as a plain string.
class DateFieldValue(BaseModel):
    type: Literal["date"] = "date"
    value: str | None = None


# Descriptor of one submitted file: URI plus optional name and MIME type.
class FileInfo(BaseModel):
    uri: str
    name: str | None = None
    mime_type: str | None = None


# Submitted value for a field tagged ``file``: the list of provided files, or ``None``.
class FileFieldValue(BaseModel):
    type: Literal["file"] = "file"
    value: list[FileInfo] | None = None


# Submitted value for a field tagged ``singleselect``: the chosen option id, or ``None``.
class SingleSelectFieldValue(BaseModel):
    type: Literal["singleselect"] = "singleselect"
    value: str | None = None


# Submitted value for a field tagged ``multiselect``: the chosen option ids, or ``None``.
class MultiSelectFieldValue(BaseModel):
    type: Literal["multiselect"] = "multiselect"
    value: list[str] | None = None


# Submitted value for a field tagged ``checkbox``.
class CheckboxFieldValue(BaseModel):
    type: Literal["checkbox"] = "checkbox"
    value: bool | None = None


# Union of all field value models, distinguished by their literal ``type`` tag.
FormFieldValue = (
    TextFieldValue
    | DateFieldValue
    | FileFieldValue
    | SingleSelectFieldValue
    | MultiSelectFieldValue
    | CheckboxFieldValue
)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# Submitted form values keyed by field id (presumably the ``id`` of the rendered FormField — TODO confirm).
class FormResponse(BaseModel):
    values: dict[str, FormFieldValue]

    def __iter__(self):
        """Yield ``(field_id, plain_value)`` pairs, so that ``dict(response)`` gives a simple mapping.

        File values become lists of plain dicts (``None`` when there are no files); every other value is
        yielded unchanged.
        """
        for field_id, field_value in self.values.items():
            if not isinstance(field_value, FileFieldValue):
                yield field_id, field_value.value
                continue
            files = field_value.value
            yield field_id, ([info.model_dump() for info in files] if files else None)
|