agentstack-sdk 0.5.0rc5__py3-none-any.whl → 0.5.1rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/a2a/extensions/__init__.py +1 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +16 -9
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +12 -6
- agentstack_sdk/a2a/extensions/base.py +20 -11
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +10 -3
- agentstack_sdk/a2a/extensions/services/llm.py +6 -4
- agentstack_sdk/a2a/extensions/services/mcp.py +8 -4
- agentstack_sdk/a2a/extensions/services/platform.py +34 -16
- agentstack_sdk/a2a/extensions/ui/__init__.py +1 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +6 -3
- agentstack_sdk/a2a/extensions/ui/error.py +5 -4
- agentstack_sdk/a2a/extensions/ui/form_request.py +6 -3
- agentstack_sdk/a2a/types.py +2 -11
- agentstack_sdk/platform/client.py +10 -8
- agentstack_sdk/platform/context.py +15 -1
- agentstack_sdk/server/agent.py +19 -9
- agentstack_sdk/server/app.py +7 -0
- agentstack_sdk/server/dependencies.py +13 -8
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +19 -5
- agentstack_sdk/types.py +15 -0
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/METADATA +3 -1
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/RECORD +28 -22
- {agentstack_sdk-0.5.0rc5.dist-info → agentstack_sdk-0.5.1rc2.dist-info}/WHEEL +0 -0
|
@@ -8,10 +8,13 @@ from types import NoneType
|
|
|
8
8
|
from typing import TYPE_CHECKING, Any, Self
|
|
9
9
|
from urllib.parse import parse_qs
|
|
10
10
|
|
|
11
|
-
import a2a.types
|
|
12
11
|
import pydantic
|
|
12
|
+
from a2a.server.agent_execution import RequestContext
|
|
13
|
+
from a2a.types import Message as A2AMessage
|
|
14
|
+
from a2a.types import Role, TextPart
|
|
13
15
|
from mcp.client.auth import OAuthClientProvider
|
|
14
16
|
from mcp.shared.auth import OAuthClientMetadata
|
|
17
|
+
from typing_extensions import override
|
|
15
18
|
|
|
16
19
|
from agentstack_sdk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory
|
|
17
20
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
@@ -58,13 +61,17 @@ class OAuthExtensionMetadata(pydantic.BaseModel):
|
|
|
58
61
|
|
|
59
62
|
|
|
60
63
|
class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensionMetadata]):
|
|
64
|
+
context: RunContext
|
|
65
|
+
token_storage_factory: TokenStorageFactory
|
|
66
|
+
|
|
61
67
|
def __init__(self, spec: OAuthExtensionSpec, token_storage_factory: TokenStorageFactory | None = None) -> None:
|
|
62
68
|
super().__init__(spec)
|
|
63
69
|
self.token_storage_factory = token_storage_factory or MemoryTokenStorageFactory()
|
|
64
70
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
71
|
+
@override
|
|
72
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
73
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
74
|
+
self.context = run_context
|
|
68
75
|
|
|
69
76
|
def _get_fulfillment_for_resource(self, resource_url: pydantic.AnyUrl):
|
|
70
77
|
if not self.data:
|
|
@@ -117,7 +124,7 @@ class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensio
|
|
|
117
124
|
data = AuthRequest(authorization_endpoint_url=authorization_endpoint_url)
|
|
118
125
|
return AgentMessage(text="Authorization required", metadata={self.spec.URI: data.model_dump(mode="json")})
|
|
119
126
|
|
|
120
|
-
def parse_auth_response(self, *, message:
|
|
127
|
+
def parse_auth_response(self, *, message: A2AMessage):
|
|
121
128
|
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
122
129
|
raise RuntimeError("Invalid auth response")
|
|
123
130
|
return AuthResponse.model_validate(data)
|
|
@@ -127,7 +134,7 @@ class OAuthExtensionClient(BaseExtensionClient[OAuthExtensionSpec, NoneType]):
|
|
|
127
134
|
def fulfillment_metadata(self, *, oauth_fulfillments: dict[str, Any]) -> dict[str, Any]:
|
|
128
135
|
return {self.spec.URI: OAuthExtensionMetadata(oauth_fulfillments=oauth_fulfillments).model_dump(mode="json")}
|
|
129
136
|
|
|
130
|
-
def parse_auth_request(self, *, message:
|
|
137
|
+
def parse_auth_request(self, *, message: A2AMessage):
|
|
131
138
|
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
132
139
|
raise ValueError("Invalid auth request")
|
|
133
140
|
return AuthRequest.model_validate(data)
|
|
@@ -135,10 +142,10 @@ class OAuthExtensionClient(BaseExtensionClient[OAuthExtensionSpec, NoneType]):
|
|
|
135
142
|
def create_auth_response(self, *, task_id: str, redirect_uri: pydantic.AnyUrl):
|
|
136
143
|
data = AuthResponse(redirect_uri=redirect_uri)
|
|
137
144
|
|
|
138
|
-
return
|
|
145
|
+
return A2AMessage(
|
|
139
146
|
message_id=str(uuid.uuid4()),
|
|
140
|
-
role=
|
|
141
|
-
parts=[
|
|
147
|
+
role=Role.user,
|
|
148
|
+
parts=[TextPart(text="Authorization completed")], # type: ignore
|
|
142
149
|
task_id=task_id,
|
|
143
150
|
metadata={self.spec.URI: data.model_dump(mode="json")},
|
|
144
151
|
)
|
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
import
|
|
5
|
-
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING, Self
|
|
6
7
|
|
|
7
8
|
import pydantic
|
|
9
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
8
10
|
from a2a.types import Message as A2AMessage
|
|
11
|
+
from typing_extensions import override
|
|
9
12
|
|
|
10
13
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
11
14
|
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired
|
|
@@ -35,7 +38,7 @@ class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | Non
|
|
|
35
38
|
URI: str = "https://a2a-extensions.agentstack.beeai.dev/auth/secrets/v1"
|
|
36
39
|
|
|
37
40
|
@classmethod
|
|
38
|
-
def single_demand(cls, name: str, key: str | None = None, description: str | None = None) ->
|
|
41
|
+
def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> Self:
|
|
39
42
|
return cls(
|
|
40
43
|
params=SecretsServiceExtensionParams(
|
|
41
44
|
secret_demands={key or "default": SecretDemand(description=description, name=name)}
|
|
@@ -44,9 +47,12 @@ class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | Non
|
|
|
44
47
|
|
|
45
48
|
|
|
46
49
|
class SecretsExtensionServer(BaseExtensionServer[SecretsExtensionSpec, SecretsServiceExtensionMetadata]):
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
+
context: RunContext
|
|
51
|
+
|
|
52
|
+
@override
|
|
53
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
54
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
55
|
+
self.context = run_context
|
|
50
56
|
|
|
51
57
|
def parse_secret_response(self, message: A2AMessage) -> SecretsServiceExtensionMetadata:
|
|
52
58
|
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
@@ -9,8 +9,11 @@ from collections.abc import AsyncIterator
|
|
|
9
9
|
from contextlib import asynccontextmanager
|
|
10
10
|
from types import NoneType
|
|
11
11
|
|
|
12
|
-
import a2a.types
|
|
13
12
|
import pydantic
|
|
13
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
14
|
+
from a2a.types import AgentCard, AgentExtension
|
|
15
|
+
from a2a.types import Message as A2AMessage
|
|
16
|
+
from typing_extensions import override
|
|
14
17
|
|
|
15
18
|
ParamsT = typing.TypeVar("ParamsT")
|
|
16
19
|
MetadataFromClientT = typing.TypeVar("MetadataFromClientT")
|
|
@@ -19,6 +22,7 @@ MetadataFromServerT = typing.TypeVar("MetadataFromServerT")
|
|
|
19
22
|
|
|
20
23
|
if typing.TYPE_CHECKING:
|
|
21
24
|
from agentstack_sdk.server.context import RunContext
|
|
25
|
+
from agentstack_sdk.server.dependencies import Dependency
|
|
22
26
|
|
|
23
27
|
|
|
24
28
|
def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]:
|
|
@@ -68,7 +72,7 @@ class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT]):
|
|
|
68
72
|
self.params = params
|
|
69
73
|
|
|
70
74
|
@classmethod
|
|
71
|
-
def from_agent_card(cls, agent:
|
|
75
|
+
def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
|
|
72
76
|
"""
|
|
73
77
|
Client should construct an extension instance using this classmethod.
|
|
74
78
|
"""
|
|
@@ -81,14 +85,14 @@ class BaseExtensionSpec(abc.ABC, typing.Generic[ParamsT]):
|
|
|
81
85
|
except StopIteration:
|
|
82
86
|
return None
|
|
83
87
|
|
|
84
|
-
def to_agent_card_extensions(self, *, required: bool = False) -> list[
|
|
88
|
+
def to_agent_card_extensions(self, *, required: bool = False) -> list[AgentExtension]:
|
|
85
89
|
"""
|
|
86
90
|
Agent should use this method to obtain extension definitions to advertise on the agent card.
|
|
87
91
|
This returns a list, as it's possible to support multiple A2A extensions within a single class.
|
|
88
92
|
(Usually, that would be different versions of the extension spec.)
|
|
89
93
|
"""
|
|
90
94
|
return [
|
|
91
|
-
|
|
95
|
+
AgentExtension(
|
|
92
96
|
uri=self.URI,
|
|
93
97
|
description=self.DESCRIPTION,
|
|
94
98
|
params=typing.cast(
|
|
@@ -105,7 +109,8 @@ class NoParamsBaseExtensionSpec(BaseExtensionSpec[NoneType]):
|
|
|
105
109
|
super().__init__(None)
|
|
106
110
|
|
|
107
111
|
@classmethod
|
|
108
|
-
|
|
112
|
+
@override
|
|
113
|
+
def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
|
|
109
114
|
if any(e.uri == cls.URI for e in agent.capabilities.extensions or []):
|
|
110
115
|
return cls()
|
|
111
116
|
return None
|
|
@@ -125,7 +130,7 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl
|
|
|
125
130
|
cls.MetadataFromClient = _get_generic_args(cls, BaseExtensionServer)[1]
|
|
126
131
|
|
|
127
132
|
_metadata_from_client: MetadataFromClientT | None = None
|
|
128
|
-
_dependencies: dict
|
|
133
|
+
_dependencies: dict[str, Dependency] = {} # noqa: RUF012
|
|
129
134
|
|
|
130
135
|
@property
|
|
131
136
|
def data(self):
|
|
@@ -139,7 +144,7 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl
|
|
|
139
144
|
self._args = args
|
|
140
145
|
self._kwargs = kwargs
|
|
141
146
|
|
|
142
|
-
def parse_client_metadata(self, message:
|
|
147
|
+
def parse_client_metadata(self, message: A2AMessage) -> MetadataFromClientT | None:
|
|
143
148
|
"""
|
|
144
149
|
Server should use this method to retrieve extension-associated metadata from a message.
|
|
145
150
|
"""
|
|
@@ -149,7 +154,7 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl
|
|
|
149
154
|
else pydantic.TypeAdapter(self.MetadataFromClient).validate_python(message.metadata[self.spec.URI])
|
|
150
155
|
)
|
|
151
156
|
|
|
152
|
-
def handle_incoming_message(self, message:
|
|
157
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
153
158
|
if self._metadata_from_client is None:
|
|
154
159
|
self._metadata_from_client = self.parse_client_metadata(message)
|
|
155
160
|
|
|
@@ -158,12 +163,16 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl
|
|
|
158
163
|
return type(self)(self.spec, *self._args, **self._kwargs)
|
|
159
164
|
|
|
160
165
|
def __call__(
|
|
161
|
-
self,
|
|
166
|
+
self,
|
|
167
|
+
message: A2AMessage,
|
|
168
|
+
run_context: RunContext,
|
|
169
|
+
request_context: RequestContext,
|
|
170
|
+
dependencies: dict[str, Dependency],
|
|
162
171
|
) -> typing.Self:
|
|
163
172
|
"""Works as a dependency constructor - create a private instance for the request"""
|
|
164
173
|
instance = self._fork()
|
|
165
174
|
instance._dependencies = dependencies
|
|
166
|
-
instance.handle_incoming_message(message,
|
|
175
|
+
instance.handle_incoming_message(message, run_context, request_context)
|
|
167
176
|
return instance
|
|
168
177
|
|
|
169
178
|
@asynccontextmanager
|
|
@@ -185,7 +194,7 @@ class BaseExtensionClient(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromSe
|
|
|
185
194
|
def __init__(self, spec: ExtensionSpecT) -> None:
|
|
186
195
|
self.spec = spec
|
|
187
196
|
|
|
188
|
-
def parse_server_metadata(self, message:
|
|
197
|
+
def parse_server_metadata(self, message: A2AMessage) -> MetadataFromServerT | None:
|
|
189
198
|
"""
|
|
190
199
|
Client should use this method to retrieve extension-associated metadata from a message.
|
|
191
200
|
"""
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import uuid
|
|
7
|
+
from types import NoneType
|
|
8
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
9
|
+
|
|
10
|
+
import a2a.types
|
|
11
|
+
from mcp import Implementation, Tool
|
|
12
|
+
from pydantic import BaseModel, Discriminator, Field, TypeAdapter
|
|
13
|
+
|
|
14
|
+
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
15
|
+
from agentstack_sdk.a2a.types import AgentMessage, InputRequired
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from agentstack_sdk.server.context import RunContext
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ApprovalRejectionError(RuntimeError):
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GenericApprovalRequest(BaseModel):
|
|
26
|
+
action: Literal["generic"] = "generic"
|
|
27
|
+
|
|
28
|
+
title: str | None = Field(None, description="A human-readable title for the action being approved.")
|
|
29
|
+
description: str | None = Field(None, description="A human-readable description of the action being approved.")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ToolCallServer(BaseModel):
|
|
33
|
+
name: str = Field(description="The programmatic name of the server.")
|
|
34
|
+
title: str | None = Field(description="A human-readable title for the server.")
|
|
35
|
+
version: str = Field(description="The version of the server.")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ToolCallApprovalRequest(BaseModel):
|
|
39
|
+
action: Literal["tool-call"] = "tool-call"
|
|
40
|
+
|
|
41
|
+
title: str | None = Field(None, description="A human-readable title for the tool call being approved.")
|
|
42
|
+
description: str | None = Field(None, description="A human-readable description of the tool call being approved.")
|
|
43
|
+
name: str = Field(description="The programmatic name of the tool.")
|
|
44
|
+
input: dict[str, Any] | None = Field(description="The input for the tool.")
|
|
45
|
+
server: ToolCallServer | None = Field(None, description="The server executing the tool.")
|
|
46
|
+
|
|
47
|
+
@staticmethod
|
|
48
|
+
def from_mcp_tool(
|
|
49
|
+
tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None
|
|
50
|
+
) -> ToolCallApprovalRequest:
|
|
51
|
+
return ToolCallApprovalRequest(
|
|
52
|
+
name=tool.name,
|
|
53
|
+
title=tool.annotations.title if tool.annotations else None,
|
|
54
|
+
description=tool.description,
|
|
55
|
+
input=input,
|
|
56
|
+
server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
ApprovalRequest = Annotated[GenericApprovalRequest | ToolCallApprovalRequest, Discriminator("action")]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ApprovalResponse(BaseModel):
|
|
64
|
+
decision: Literal["approve", "reject"]
|
|
65
|
+
|
|
66
|
+
@property
|
|
67
|
+
def approved(self) -> bool:
|
|
68
|
+
return self.decision == "approve"
|
|
69
|
+
|
|
70
|
+
def raise_on_rejection(self) -> None:
|
|
71
|
+
if self.decision == "reject":
|
|
72
|
+
raise ApprovalRejectionError("Approval request has been rejected")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ApprovalExtensionParams(BaseModel):
|
|
76
|
+
pass
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]):
|
|
80
|
+
URI: str = "https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ApprovalExtensionMetadata(BaseModel):
|
|
84
|
+
pass
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]):
|
|
88
|
+
def create_request_message(self, *, request: ApprovalRequest):
|
|
89
|
+
return AgentMessage(text="Approval requested", metadata={self.spec.URI: request.model_dump(mode="json")})
|
|
90
|
+
|
|
91
|
+
def parse_response(self, *, message: a2a.types.Message):
|
|
92
|
+
if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
93
|
+
raise ValueError("Approval response data is missing")
|
|
94
|
+
return ApprovalResponse.model_validate(data)
|
|
95
|
+
|
|
96
|
+
async def request_approval(
|
|
97
|
+
self,
|
|
98
|
+
request: ApprovalRequest,
|
|
99
|
+
*,
|
|
100
|
+
context: RunContext,
|
|
101
|
+
) -> ApprovalResponse:
|
|
102
|
+
message = self.create_request_message(request=request)
|
|
103
|
+
message = await context.yield_async(InputRequired(message=message))
|
|
104
|
+
if not message:
|
|
105
|
+
raise RuntimeError("Yield did not return a message")
|
|
106
|
+
return self.parse_response(message=message)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]):
|
|
110
|
+
def create_response_message(self, *, response: ApprovalResponse, task_id: str | None):
|
|
111
|
+
return a2a.types.Message(
|
|
112
|
+
message_id=str(uuid.uuid4()),
|
|
113
|
+
role=a2a.types.Role.user,
|
|
114
|
+
parts=[],
|
|
115
|
+
task_id=task_id,
|
|
116
|
+
metadata={self.spec.URI: response.model_dump(mode="json")},
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def parse_request(self, *, message: a2a.types.Message):
|
|
120
|
+
if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
121
|
+
raise ValueError("Approval request data is missing")
|
|
122
|
+
return TypeAdapter(ApprovalRequest).validate_python(data)
|
|
123
|
+
|
|
124
|
+
def metadata(self) -> dict[str, Any]:
|
|
125
|
+
return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json")}
|
|
@@ -5,12 +5,18 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import re
|
|
7
7
|
from types import NoneType
|
|
8
|
-
from typing import Any, Self
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
9
9
|
|
|
10
10
|
import pydantic
|
|
11
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
12
|
+
from a2a.types import Message as A2AMessage
|
|
13
|
+
from typing_extensions import override
|
|
11
14
|
|
|
12
15
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
13
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from agentstack_sdk.server.context import RunContext
|
|
19
|
+
|
|
14
20
|
|
|
15
21
|
class EmbeddingFulfillment(pydantic.BaseModel):
|
|
16
22
|
identifier: str | None = None
|
|
@@ -78,10 +84,11 @@ class EmbeddingServiceExtensionMetadata(pydantic.BaseModel):
|
|
|
78
84
|
class EmbeddingServiceExtensionServer(
|
|
79
85
|
BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata]
|
|
80
86
|
):
|
|
81
|
-
|
|
87
|
+
@override
|
|
88
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
82
89
|
from agentstack_sdk.platform import get_platform_client
|
|
83
90
|
|
|
84
|
-
super().handle_incoming_message(message,
|
|
91
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
85
92
|
if not self.data:
|
|
86
93
|
return
|
|
87
94
|
|
|
@@ -8,12 +8,13 @@ from types import NoneType
|
|
|
8
8
|
from typing import TYPE_CHECKING, Any, Self
|
|
9
9
|
|
|
10
10
|
import pydantic
|
|
11
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
12
|
+
from a2a.types import Message as A2AMessage
|
|
13
|
+
from typing_extensions import override
|
|
11
14
|
|
|
12
15
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
13
16
|
|
|
14
17
|
if TYPE_CHECKING:
|
|
15
|
-
from a2a.types import Message
|
|
16
|
-
|
|
17
18
|
from agentstack_sdk.server.context import RunContext
|
|
18
19
|
|
|
19
20
|
|
|
@@ -81,10 +82,11 @@ class LLMServiceExtensionMetadata(pydantic.BaseModel):
|
|
|
81
82
|
|
|
82
83
|
|
|
83
84
|
class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]):
|
|
84
|
-
|
|
85
|
+
@override
|
|
86
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
85
87
|
from agentstack_sdk.platform import get_platform_client
|
|
86
88
|
|
|
87
|
-
super().handle_incoming_message(message,
|
|
89
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
88
90
|
if not self.data:
|
|
89
91
|
return
|
|
90
92
|
|
|
@@ -8,10 +8,12 @@ from contextlib import asynccontextmanager
|
|
|
8
8
|
from types import NoneType
|
|
9
9
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Self
|
|
10
10
|
|
|
11
|
-
import a2a.types
|
|
12
11
|
import pydantic
|
|
12
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
13
|
+
from a2a.types import Message as A2AMessage
|
|
13
14
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
14
15
|
from mcp.client.streamable_http import streamablehttp_client
|
|
16
|
+
from typing_extensions import override
|
|
15
17
|
|
|
16
18
|
from agentstack_sdk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer
|
|
17
19
|
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
|
|
@@ -102,8 +104,9 @@ class MCPServiceExtensionMetadata(pydantic.BaseModel):
|
|
|
102
104
|
|
|
103
105
|
|
|
104
106
|
class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]):
|
|
105
|
-
|
|
106
|
-
|
|
107
|
+
@override
|
|
108
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
109
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
107
110
|
if not self.data:
|
|
108
111
|
return
|
|
109
112
|
|
|
@@ -115,7 +118,8 @@ class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCP
|
|
|
115
118
|
except Exception:
|
|
116
119
|
logger.warning("Platform URL substitution failed", exc_info=True)
|
|
117
120
|
|
|
118
|
-
|
|
121
|
+
@override
|
|
122
|
+
def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetadata | None:
|
|
119
123
|
metadata = super().parse_client_metadata(message)
|
|
120
124
|
if metadata:
|
|
121
125
|
for name, demand in self.spec.params.mcp_demands.items():
|
|
@@ -9,9 +9,12 @@ from contextlib import asynccontextmanager
|
|
|
9
9
|
from types import NoneType
|
|
10
10
|
from typing import TYPE_CHECKING
|
|
11
11
|
|
|
12
|
-
import a2a.types
|
|
13
12
|
import pydantic
|
|
13
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
14
|
+
from a2a.types import Message as A2AMessage
|
|
15
|
+
from fastapi.security.utils import get_authorization_scheme_param
|
|
14
16
|
from pydantic.networks import HttpUrl
|
|
17
|
+
from typing_extensions import override
|
|
15
18
|
|
|
16
19
|
from agentstack_sdk.a2a.extensions.base import (
|
|
17
20
|
BaseExtensionClient,
|
|
@@ -21,6 +24,7 @@ from agentstack_sdk.a2a.extensions.base import (
|
|
|
21
24
|
from agentstack_sdk.a2a.extensions.exceptions import ExtensionError
|
|
22
25
|
from agentstack_sdk.platform import use_platform_client
|
|
23
26
|
from agentstack_sdk.platform.client import PlatformClient
|
|
27
|
+
from agentstack_sdk.server.middleware.platform_auth_backend import PlatformAuthenticatedUser
|
|
24
28
|
from agentstack_sdk.util.httpx import BearerAuth
|
|
25
29
|
|
|
26
30
|
if TYPE_CHECKING:
|
|
@@ -29,7 +33,7 @@ if TYPE_CHECKING:
|
|
|
29
33
|
|
|
30
34
|
class PlatformApiExtensionMetadata(pydantic.BaseModel):
|
|
31
35
|
base_url: HttpUrl | None = None
|
|
32
|
-
auth_token: pydantic.Secret[str]
|
|
36
|
+
auth_token: pydantic.Secret[str] | None = None
|
|
33
37
|
expires_at: pydantic.AwareDatetime | None = None
|
|
34
38
|
|
|
35
39
|
|
|
@@ -53,13 +57,8 @@ class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams]):
|
|
|
53
57
|
class PlatformApiExtensionServer(BaseExtensionServer[PlatformApiExtensionSpec, PlatformApiExtensionMetadata]):
|
|
54
58
|
context_id: str | None = None
|
|
55
59
|
|
|
56
|
-
def parse_client_metadata(self, message: a2a.types.Message) -> PlatformApiExtensionMetadata | None:
|
|
57
|
-
self.context_id = message.context_id
|
|
58
|
-
# we assume that the context id is the same ID as the platform context id
|
|
59
|
-
# if different IDs are passed, api requests to platform using this token will fail
|
|
60
|
-
return super().parse_client_metadata(message)
|
|
61
|
-
|
|
62
60
|
@asynccontextmanager
|
|
61
|
+
@override
|
|
63
62
|
async def lifespan(self) -> AsyncIterator[None]:
|
|
64
63
|
"""Called when entering the agent context after the first message was parsed (__call__ was already called)"""
|
|
65
64
|
if self.data and self.spec.params.auto_use:
|
|
@@ -68,25 +67,44 @@ class PlatformApiExtensionServer(BaseExtensionServer[PlatformApiExtensionSpec, P
|
|
|
68
67
|
else:
|
|
69
68
|
yield
|
|
70
69
|
|
|
71
|
-
def
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
70
|
+
def _get_header_token(self, request_context: RequestContext) -> pydantic.Secret[str] | None:
|
|
71
|
+
header_token = None
|
|
72
|
+
call_context = request_context.call_context
|
|
73
|
+
assert call_context
|
|
74
|
+
if isinstance(call_context.user, PlatformAuthenticatedUser):
|
|
75
|
+
header_token = call_context.user.auth_token.get_secret_value()
|
|
76
|
+
elif auth_header := call_context.state.get("headers", {}).get("authorization", None):
|
|
77
|
+
_scheme, header_token = get_authorization_scheme_param(auth_header)
|
|
78
|
+
return pydantic.Secret(header_token) if header_token else None
|
|
79
|
+
|
|
80
|
+
@override
|
|
81
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
82
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
83
|
+
# we assume that request context id is the same ID as the platform context id
|
|
84
|
+
# if different IDs are passed, api requests to platform using this token will fail
|
|
85
|
+
self.context_id = request_context.context_id
|
|
86
|
+
|
|
87
|
+
self._metadata_from_client = self._metadata_from_client or PlatformApiExtensionMetadata()
|
|
88
|
+
data = self._metadata_from_client
|
|
89
|
+
data.base_url = data.base_url or HttpUrl(os.getenv("PLATFORM_URL", "http://127.0.0.1:8333"))
|
|
90
|
+
data.auth_token = data.auth_token or self._get_header_token(request_context)
|
|
91
|
+
|
|
92
|
+
if not data.auth_token:
|
|
93
|
+
raise ExtensionError(self.spec, "Platform extension metadata was not provided")
|
|
75
94
|
|
|
76
95
|
@asynccontextmanager
|
|
77
96
|
async def use_client(self) -> AsyncIterator[PlatformClient]:
|
|
78
|
-
if not self.data:
|
|
97
|
+
if not self.data or not self.data.auth_token:
|
|
79
98
|
raise ExtensionError(self.spec, "Platform extension metadata was not provided")
|
|
80
|
-
auth_token = self.data.auth_token.get_secret_value()
|
|
81
99
|
async with use_platform_client(
|
|
82
100
|
context_id=self.context_id,
|
|
83
101
|
base_url=str(self.data.base_url),
|
|
84
|
-
auth_token=auth_token,
|
|
102
|
+
auth_token=self.data.auth_token.get_secret_value(),
|
|
85
103
|
) as client:
|
|
86
104
|
yield client
|
|
87
105
|
|
|
88
106
|
async def create_httpx_auth(self) -> BearerAuth:
|
|
89
|
-
if not self.data:
|
|
107
|
+
if not self.data or not self.data.auth_token:
|
|
90
108
|
raise ExtensionError(self.spec, "Platform extension metadata was not provided")
|
|
91
109
|
return BearerAuth(token=self.data.auth_token.get_secret_value())
|
|
92
110
|
|
|
@@ -6,8 +6,10 @@ from __future__ import annotations
|
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
7
|
|
|
8
8
|
import pydantic
|
|
9
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
9
10
|
from a2a.types import Artifact, TextPart
|
|
10
11
|
from a2a.types import Message as A2AMessage
|
|
12
|
+
from typing_extensions import override
|
|
11
13
|
|
|
12
14
|
if TYPE_CHECKING:
|
|
13
15
|
from agentstack_sdk.server.context import RunContext
|
|
@@ -37,12 +39,13 @@ class CanvasExtensionSpec(NoParamsBaseExtensionSpec):
|
|
|
37
39
|
|
|
38
40
|
|
|
39
41
|
class CanvasExtensionServer(BaseExtensionServer[CanvasExtensionSpec, CanvasEditRequestMetadata]):
|
|
40
|
-
|
|
42
|
+
@override
|
|
43
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
41
44
|
if message.metadata and self.spec.URI in message.metadata and message.parts:
|
|
42
45
|
message.parts = [part for part in message.parts if not isinstance(part.root, TextPart)]
|
|
43
46
|
|
|
44
|
-
super().handle_incoming_message(message,
|
|
45
|
-
self.context =
|
|
47
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
48
|
+
self.context = run_context
|
|
46
49
|
|
|
47
50
|
async def parse_canvas_edit_request(self, *, message: A2AMessage) -> CanvasEditRequest | None:
|
|
48
51
|
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
|
|
@@ -18,7 +18,8 @@ from agentstack_sdk.a2a.extensions.base import (
|
|
|
18
18
|
BaseExtensionServer,
|
|
19
19
|
BaseExtensionSpec,
|
|
20
20
|
)
|
|
21
|
-
from agentstack_sdk.a2a.types import AgentMessage,
|
|
21
|
+
from agentstack_sdk.a2a.types import AgentMessage, Metadata
|
|
22
|
+
from agentstack_sdk.types import JsonValue
|
|
22
23
|
from agentstack_sdk.util import resource_context
|
|
23
24
|
|
|
24
25
|
logger = logging.getLogger(__name__)
|
|
@@ -68,7 +69,7 @@ class ErrorMetadata(pydantic.BaseModel):
|
|
|
68
69
|
|
|
69
70
|
error: Error | ErrorGroup
|
|
70
71
|
stack_trace: str | None = None
|
|
71
|
-
context:
|
|
72
|
+
context: JsonValue | None = None
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
class ErrorExtensionParams(pydantic.BaseModel):
|
|
@@ -133,7 +134,7 @@ class ErrorExtensionServer(BaseExtensionServer[ErrorExtensionSpec, NoneType]):
|
|
|
133
134
|
yield
|
|
134
135
|
|
|
135
136
|
@property
|
|
136
|
-
def context(self) ->
|
|
137
|
+
def context(self) -> JsonValue:
|
|
137
138
|
"""Get the current request's error context."""
|
|
138
139
|
try:
|
|
139
140
|
return get_error_extension_context().context
|
|
@@ -214,7 +215,7 @@ DEFAULT_ERROR_EXTENSION: Final = ErrorExtensionServer(ErrorExtensionSpec(ErrorEx
|
|
|
214
215
|
|
|
215
216
|
class ErrorContext(pydantic.BaseModel, arbitrary_types_allowed=True):
|
|
216
217
|
server: ErrorExtensionServer = pydantic.Field(default=DEFAULT_ERROR_EXTENSION)
|
|
217
|
-
context:
|
|
218
|
+
context: JsonValue = pydantic.Field(default_factory=dict)
|
|
218
219
|
|
|
219
220
|
|
|
220
221
|
get_error_extension_context, use_error_extension_context = resource_context(
|
|
@@ -5,8 +5,10 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from typing import TYPE_CHECKING, TypeVar, cast
|
|
7
7
|
|
|
8
|
+
from a2a.server.agent_execution.context import RequestContext
|
|
8
9
|
from a2a.types import Message as A2AMessage
|
|
9
10
|
from pydantic import TypeAdapter
|
|
11
|
+
from typing_extensions import override
|
|
10
12
|
|
|
11
13
|
from agentstack_sdk.a2a.extensions.base import (
|
|
12
14
|
BaseExtensionClient,
|
|
@@ -27,9 +29,10 @@ class FormRequestExtensionSpec(NoParamsBaseExtensionSpec):
|
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
class FormRequestExtensionServer(BaseExtensionServer[FormRequestExtensionSpec, FormResponse]):
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
32
|
+
@override
|
|
33
|
+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
|
|
34
|
+
super().handle_incoming_message(message, run_context, request_context)
|
|
35
|
+
self.context = run_context
|
|
33
36
|
|
|
34
37
|
async def request_form(self, *, form: FormRender, model: type[T] = FormResponse) -> T | None:
|
|
35
38
|
message = await self.context.yield_async(
|