agentstack-sdk 0.5.2rc2 (agentstack_sdk-0.5.2rc2-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/__init__.py +6 -0
- agentstack_sdk/a2a/__init__.py +2 -0
- agentstack_sdk/a2a/extensions/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
- agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
- agentstack_sdk/a2a/extensions/base.py +205 -0
- agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/common/form.py +149 -0
- agentstack_sdk/a2a/extensions/exceptions.py +11 -0
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
- agentstack_sdk/a2a/extensions/services/form.py +54 -0
- agentstack_sdk/a2a/extensions/services/llm.py +100 -0
- agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
- agentstack_sdk/a2a/extensions/services/platform.py +141 -0
- agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/tools/call.py +114 -0
- agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
- agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
- agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
- agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
- agentstack_sdk/a2a/extensions/ui/error.py +223 -0
- agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
- agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
- agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
- agentstack_sdk/a2a/types.py +104 -0
- agentstack_sdk/platform/__init__.py +12 -0
- agentstack_sdk/platform/client.py +123 -0
- agentstack_sdk/platform/common.py +37 -0
- agentstack_sdk/platform/configuration.py +47 -0
- agentstack_sdk/platform/context.py +291 -0
- agentstack_sdk/platform/file.py +295 -0
- agentstack_sdk/platform/model_provider.py +131 -0
- agentstack_sdk/platform/provider.py +219 -0
- agentstack_sdk/platform/provider_build.py +190 -0
- agentstack_sdk/platform/types.py +45 -0
- agentstack_sdk/platform/user.py +70 -0
- agentstack_sdk/platform/user_feedback.py +42 -0
- agentstack_sdk/platform/variables.py +44 -0
- agentstack_sdk/platform/vector_store.py +217 -0
- agentstack_sdk/py.typed +0 -0
- agentstack_sdk/server/__init__.py +4 -0
- agentstack_sdk/server/agent.py +594 -0
- agentstack_sdk/server/app.py +87 -0
- agentstack_sdk/server/constants.py +9 -0
- agentstack_sdk/server/context.py +68 -0
- agentstack_sdk/server/dependencies.py +117 -0
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +376 -0
- agentstack_sdk/server/store/__init__.py +3 -0
- agentstack_sdk/server/store/context_store.py +35 -0
- agentstack_sdk/server/store/memory_context_store.py +59 -0
- agentstack_sdk/server/store/platform_context_store.py +58 -0
- agentstack_sdk/server/telemetry.py +53 -0
- agentstack_sdk/server/utils.py +26 -0
- agentstack_sdk/types.py +15 -0
- agentstack_sdk/util/__init__.py +4 -0
- agentstack_sdk/util/file.py +260 -0
- agentstack_sdk/util/httpx.py +18 -0
- agentstack_sdk/util/logging.py +63 -0
- agentstack_sdk/util/resource_context.py +44 -0
- agentstack_sdk/util/utils.py +47 -0
- agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
- agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
- agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
agentstack_sdk/a2a/extensions/exceptions.py
@@ -0,0 +1,11 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from agentstack_sdk.a2a.extensions.base import BaseExtensionSpec
+
+
+class ExtensionError(Exception):
+    extension: BaseExtensionSpec
+
+    def __init__(self, spec: BaseExtensionSpec, message: str):
+        super().__init__(f"Exception in extension '{spec.URI}': \n{message}")
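As a brief illustration (not part of the package), ExtensionError is meant to be raised with the spec of the failing extension so that the extension URI shows up in the error message; the helper function and message below are hypothetical.

from agentstack_sdk.a2a.extensions.base import BaseExtensionSpec
from agentstack_sdk.a2a.extensions.exceptions import ExtensionError


def require_fulfillment(spec: BaseExtensionSpec, fulfilled: bool) -> None:
    # Hypothetical helper: fail with the extension URI embedded in the message.
    if not fulfilled:
        raise ExtensionError(spec, "client did not provide a required fulfillment")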
agentstack_sdk/a2a/extensions/interactions/approval.py
@@ -0,0 +1,125 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import uuid
+from types import NoneType
+from typing import TYPE_CHECKING, Annotated, Any, Literal
+
+import a2a.types
+from mcp import Implementation, Tool
+from pydantic import BaseModel, Discriminator, Field, TypeAdapter
+
+from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
+from agentstack_sdk.a2a.types import AgentMessage, InputRequired
+
+if TYPE_CHECKING:
+    from agentstack_sdk.server.context import RunContext
+
+
+class ApprovalRejectionError(RuntimeError):
+    pass
+
+
+class GenericApprovalRequest(BaseModel):
+    action: Literal["generic"] = "generic"
+
+    title: str | None = Field(None, description="A human-readable title for the action being approved.")
+    description: str | None = Field(None, description="A human-readable description of the action being approved.")
+
+
+class ToolCallServer(BaseModel):
+    name: str = Field(description="The programmatic name of the server.")
+    title: str | None = Field(description="A human-readable title for the server.")
+    version: str = Field(description="The version of the server.")
+
+
+class ToolCallApprovalRequest(BaseModel):
+    action: Literal["tool-call"] = "tool-call"
+
+    title: str | None = Field(None, description="A human-readable title for the tool call being approved.")
+    description: str | None = Field(None, description="A human-readable description of the tool call being approved.")
+    name: str = Field(description="The programmatic name of the tool.")
+    input: dict[str, Any] | None = Field(description="The input for the tool.")
+    server: ToolCallServer | None = Field(None, description="The server executing the tool.")
+
+    @staticmethod
+    def from_mcp_tool(
+        tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None
+    ) -> ToolCallApprovalRequest:
+        return ToolCallApprovalRequest(
+            name=tool.name,
+            title=tool.annotations.title if tool.annotations else None,
+            description=tool.description,
+            input=input,
+            server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None,
+        )
+
+
+ApprovalRequest = Annotated[GenericApprovalRequest | ToolCallApprovalRequest, Discriminator("action")]
+
+
+class ApprovalResponse(BaseModel):
+    decision: Literal["approve", "reject"]
+
+    @property
+    def approved(self) -> bool:
+        return self.decision == "approve"
+
+    def raise_on_rejection(self) -> None:
+        if self.decision == "reject":
+            raise ApprovalRejectionError("Approval request has been rejected")
+
+
+class ApprovalExtensionParams(BaseModel):
+    pass
+
+
+class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]):
+    URI: str = "https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1"
+
+
+class ApprovalExtensionMetadata(BaseModel):
+    pass
+
+
+class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]):
+    def create_request_message(self, *, request: ApprovalRequest):
+        return AgentMessage(text="Approval requested", metadata={self.spec.URI: request.model_dump(mode="json")})
+
+    def parse_response(self, *, message: a2a.types.Message):
+        if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
+            raise ValueError("Approval response data is missing")
+        return ApprovalResponse.model_validate(data)
+
+    async def request_approval(
+        self,
+        request: ApprovalRequest,
+        *,
+        context: RunContext,
+    ) -> ApprovalResponse:
+        message = self.create_request_message(request=request)
+        message = await context.yield_async(InputRequired(message=message))
+        if not message:
+            raise RuntimeError("Yield did not return a message")
+        return self.parse_response(message=message)
+
+
+class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]):
+    def create_response_message(self, *, response: ApprovalResponse, task_id: str | None):
+        return a2a.types.Message(
+            message_id=str(uuid.uuid4()),
+            role=a2a.types.Role.user,
+            parts=[],
+            task_id=task_id,
+            metadata={self.spec.URI: response.model_dump(mode="json")},
+        )
+
+    def parse_request(self, *, message: a2a.types.Message):
+        if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
+            raise ValueError("Approval request data is missing")
+        return TypeAdapter(ApprovalRequest).validate_python(data)
+
+    def metadata(self) -> dict[str, Any]:
+        return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json")}
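A possible server-side flow, sketched here for orientation rather than taken from the package: it assumes the runtime has already handed the agent an ApprovalExtensionServer instance and the current RunContext, and the tool name and input are placeholders.

from agentstack_sdk.a2a.extensions.interactions.approval import (
    ApprovalExtensionServer,
    ToolCallApprovalRequest,
)
from agentstack_sdk.server.context import RunContext


async def run_tool_with_approval(approval: ApprovalExtensionServer, context: RunContext) -> None:
    # Sketch: ask the client to approve a tool call before executing it.
    request = ToolCallApprovalRequest(
        name="delete_file",                     # hypothetical tool name
        title="Delete file",
        description="Removes a file from the workspace.",
        input={"path": "/tmp/report.txt"},      # hypothetical input
        server=None,
    )
    response = await approval.request_approval(request, context=context)
    response.raise_on_rejection()  # raises ApprovalRejectionError on "reject"
    # ... run the tool only after approval ...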
agentstack_sdk/a2a/extensions/services/embedding.py
@@ -0,0 +1,106 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import re
+from types import NoneType
+from typing import TYPE_CHECKING, Any, Self
+
+import pydantic
+from a2a.server.agent_execution.context import RequestContext
+from a2a.types import Message as A2AMessage
+from typing_extensions import override
+
+from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
+
+if TYPE_CHECKING:
+    from agentstack_sdk.server.context import RunContext
+
+
+class EmbeddingFulfillment(pydantic.BaseModel):
+    identifier: str | None = None
+    """
+    Name of the model for identification and optimization purposes. Usually corresponds to LiteLLM identifiers.
+    Should be the name of the provider slash name of the model as it appears in the API.
+    Examples: openai/text-embedding-3-small, vertex_ai/textembedding-gecko, ollama/nomic-embed-text:latest
+    """
+
+    api_base: str
+    """
+    Base URL for an OpenAI-compatible API. It should provide at least /v1/chat/completions
+    """
+
+    api_key: str
+    """
+    API key to attach as a `Authorization: Bearer $api_key` header.
+    """
+
+    api_model: str
+    """
+    Model name to use with the /v1/chat/completions API.
+    """
+
+
+class EmbeddingDemand(pydantic.BaseModel):
+    description: str | None = None
+    """
+    Short description of how the model will be used, if multiple are requested.
+    Intended to be shown in the UI alongside a model picker dropdown.
+    """
+
+    suggested: tuple[str, ...] = ()
+    """
+    Identifiers of models recommended to be used. Usually corresponds to LiteLLM identifiers.
+    Should be the name of the provider slash name of the model as it appears in the API.
+    Examples: openai/text-embedding-3-small, vertex_ai/textembedding-gecko, ollama/nomic-embed-text:latest
+    """
+
+
+class EmbeddingServiceExtensionParams(pydantic.BaseModel):
+    embedding_demands: dict[str, EmbeddingDemand]
+    """Model requests that the agent requires to be provided by the client."""
+
+
+class EmbeddingServiceExtensionSpec(BaseExtensionSpec[EmbeddingServiceExtensionParams]):
+    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/embedding/v1"
+
+    @classmethod
+    def single_demand(
+        cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = ()
+    ) -> Self:
+        return cls(
+            params=EmbeddingServiceExtensionParams(
+                embedding_demands={name or "default": EmbeddingDemand(description=description, suggested=suggested)}
+            )
+        )
+
+
+class EmbeddingServiceExtensionMetadata(pydantic.BaseModel):
+    embedding_fulfillments: dict[str, EmbeddingFulfillment] = {}
+    """Provided models corresponding to the model requests."""
+
+
+class EmbeddingServiceExtensionServer(
+    BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata]
+):
+    @override
+    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
+        from agentstack_sdk.platform import get_platform_client
+
+        super().handle_incoming_message(message, run_context, request_context)
+        if not self.data:
+            return
+
+        for fullfilment in self.data.embedding_fulfillments.values():
+            platform_url = str(get_platform_client().base_url)
+            fullfilment.api_base = re.sub("{platform_url}", platform_url, fullfilment.api_base)
+
+
+class EmbeddingServiceExtensionClient(BaseExtensionClient[EmbeddingServiceExtensionSpec, NoneType]):
+    def fulfillment_metadata(self, *, embedding_fulfillments: dict[str, EmbeddingFulfillment]) -> dict[str, Any]:
+        return {
+            self.spec.URI: EmbeddingServiceExtensionMetadata(embedding_fulfillments=embedding_fulfillments).model_dump(
+                mode="json"
+            )
+        }
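For orientation, a sketch of how a client might answer the demand declared above; the endpoint, key, and model name are placeholders, and the resulting dict mirrors what EmbeddingServiceExtensionClient.fulfillment_metadata produces.

from agentstack_sdk.a2a.extensions.services.embedding import (
    EmbeddingFulfillment,
    EmbeddingServiceExtensionMetadata,
    EmbeddingServiceExtensionSpec,
)

# The agent declares a single demand; the client answers it with a fulfillment.
spec = EmbeddingServiceExtensionSpec.single_demand(
    suggested=("openai/text-embedding-3-small",)
)

# Placeholder values, not real credentials or endpoints.
metadata = {
    spec.URI: EmbeddingServiceExtensionMetadata(
        embedding_fulfillments={
            "default": EmbeddingFulfillment(
                identifier="openai/text-embedding-3-small",
                api_base="https://api.example.com/v1",
                api_key="sk-example",
                api_model="text-embedding-3-small",
            )
        }
    ).model_dump(mode="json")
}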
agentstack_sdk/a2a/extensions/services/form.py
@@ -0,0 +1,54 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+
+from __future__ import annotations
+
+from typing import Self, TypeVar, cast
+
+from pydantic import BaseModel, TypeAdapter
+from typing_extensions import TypedDict
+
+from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
+from agentstack_sdk.a2a.extensions.common.form import FormRender, FormResponse
+
+
+class FormDemands(TypedDict):
+    initial_form: FormRender | None
+    # TODO: We can put settings here too
+
+
+class FormServiceExtensionMetadata(BaseModel):
+    form_fulfillments: dict[str, FormResponse] = {}
+
+
+class FormServiceExtensionParams(BaseModel):
+    form_demands: FormDemands
+
+
+class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams]):
+    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/form/v1"
+
+    @classmethod
+    def demand(cls, initial_form: FormRender | None) -> Self:
+        return cls(params=FormServiceExtensionParams(form_demands={"initial_form": initial_form}))
+
+
+T = TypeVar("T")
+
+
+class FormServiceExtensionServer(BaseExtensionServer[FormServiceExtensionSpec, FormServiceExtensionMetadata]):
+    def parse_initial_form(self, *, model: type[T] = FormResponse) -> T | None:
+        if self.data is None:
+            return None
+
+        initial_form = self.data.form_fulfillments.get("initial_form")
+
+        if initial_form is None:
+            return None
+        if model is FormResponse:
+            return cast(T, initial_form)
+        return TypeAdapter(model).validate_python(dict(initial_form))
+
+
+class FormServiceExtensionClient(BaseExtensionClient[FormServiceExtensionSpec, FormRender]): ...
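A sketch of reading the submitted initial form into a caller-defined model; the TicketForm fields below are hypothetical and would have to line up with the fields of the FormRender the agent demanded, since parse_initial_form validates the fulfilled FormResponse against whatever model is passed in.

import pydantic

from agentstack_sdk.a2a.extensions.services.form import FormServiceExtensionServer


class TicketForm(pydantic.BaseModel):
    # Hypothetical fields; names must match the rendered form's field names.
    subject: str
    priority: str = "normal"


def read_submitted_form(form: FormServiceExtensionServer) -> TicketForm | None:
    # Returns None when the client did not fulfil the "initial_form" demand.
    return form.parse_initial_form(model=TicketForm)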
agentstack_sdk/a2a/extensions/services/llm.py
@@ -0,0 +1,100 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import re
+from types import NoneType
+from typing import TYPE_CHECKING, Any, Self
+
+import pydantic
+from a2a.server.agent_execution.context import RequestContext
+from a2a.types import Message as A2AMessage
+from typing_extensions import override
+
+from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
+
+if TYPE_CHECKING:
+    from agentstack_sdk.server.context import RunContext
+
+
+class LLMFulfillment(pydantic.BaseModel):
+    identifier: str | None = None
+    """
+    Name of the model for identification and optimization purposes. Usually corresponds to LiteLLM identifiers.
+    Should be the name of the provider slash name of the model as it appears in the API.
+    Examples: openai/gpt-4o, watsonx/ibm/granite-13b-chat-v2, ollama/mistral-small:22b
+    """
+
+    api_base: str
+    """
+    Base URL for an OpenAI-compatible API. It should provide at least /v1/chat/completions
+    """
+
+    api_key: str
+    """
+    API key to attach as a `Authorization: Bearer $api_key` header.
+    """
+
+    api_model: str
+    """
+    Model name to use with the /v1/chat/completions API.
+    """
+
+
+class LLMDemand(pydantic.BaseModel):
+    description: str | None = None
+    """
+    Short description of how the model will be used, if multiple are requested.
+    Intended to be shown in the UI alongside a model picker dropdown.
+    """
+
+    suggested: tuple[str, ...] = ()
+    """
+    Identifiers of models recommended to be used. Usually corresponds to LiteLLM identifiers.
+    Should be the name of the provider slash name of the model as it appears in the API.
+    Examples: openai/gpt-4o, watsonx/ibm/granite-13b-chat-v2, ollama/mistral-small:22b
+    """
+
+
+class LLMServiceExtensionParams(pydantic.BaseModel):
+    llm_demands: dict[str, LLMDemand]
+    """Model requests that the agent requires to be provided by the client."""
+
+
+class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams]):
+    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/llm/v1"
+
+    @classmethod
+    def single_demand(
+        cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = ()
+    ) -> Self:
+        return cls(
+            params=LLMServiceExtensionParams(
+                llm_demands={name or "default": LLMDemand(description=description, suggested=suggested)}
+            )
+        )
+
+
+class LLMServiceExtensionMetadata(pydantic.BaseModel):
+    llm_fulfillments: dict[str, LLMFulfillment] = {}
+    """Provided models corresponding to the model requests."""
+
+
+class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]):
+    @override
+    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
+        from agentstack_sdk.platform import get_platform_client
+
+        super().handle_incoming_message(message, run_context, request_context)
+        if not self.data:
+            return
+
+        for fullfilment in self.data.llm_fulfillments.values():
+            platform_url = str(get_platform_client().base_url)
+            fullfilment.api_base = re.sub("{platform_url}", platform_url, fullfilment.api_base)
+
+
+class LLMServiceExtensionClient(BaseExtensionClient[LLMServiceExtensionSpec, NoneType]):
+    def fulfillment_metadata(self, *, llm_fulfillments: dict[str, LLMFulfillment]) -> dict[str, Any]:
+        return {self.spec.URI: LLMServiceExtensionMetadata(llm_fulfillments=llm_fulfillments).model_dump(mode="json")}
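On the agent side, a fulfillment is just an OpenAI-compatible endpoint plus credentials. Below is a sketch of calling it with httpx, assuming (per the api_base docstring above) that the chat endpoint is reachable at {api_base}/v1/chat/completions and that the API key goes into an Authorization: Bearer header; the prompt and response handling are illustrative only.

import httpx

from agentstack_sdk.a2a.extensions.services.llm import LLMFulfillment


async def chat(fulfillment: LLMFulfillment, prompt: str) -> str:
    # Assumption: the chat completions endpoint lives under the fulfilled api_base.
    url = fulfillment.api_base.rstrip("/") + "/v1/chat/completions"
    async with httpx.AsyncClient() as client:
        response = await client.post(
            url,
            headers={"Authorization": f"Bearer {fulfillment.api_key}"},
            json={
                "model": fulfillment.api_model,
                "messages": [{"role": "user", "content": prompt}],
            },
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]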
agentstack_sdk/a2a/extensions/services/mcp.py
@@ -0,0 +1,193 @@
+# Copyright 2025 © BeeAI a Series of LF Projects, LLC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import re
+from contextlib import asynccontextmanager
+from types import NoneType
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Self
+
+import pydantic
+from a2a.server.agent_execution.context import RequestContext
+from a2a.types import Message as A2AMessage
+from mcp.client.stdio import StdioServerParameters, stdio_client
+from mcp.client.streamable_http import streamablehttp_client
+from typing_extensions import override
+
+from agentstack_sdk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer
+from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
+from agentstack_sdk.a2a.extensions.services.platform import PlatformApiExtensionServer
+from agentstack_sdk.platform.client import get_platform_client
+from agentstack_sdk.util.logging import logger
+
+if TYPE_CHECKING:
+    from agentstack_sdk.server.context import RunContext
+
+_TRANSPORT_TYPES = Literal["streamable_http", "stdio"]
+
+_DEFAULT_DEMAND_NAME = "default"
+_DEFAULT_ALLOWED_TRANSPORTS: list[_TRANSPORT_TYPES] = ["streamable_http"]
+
+
+class StdioTransport(pydantic.BaseModel):
+    type: Literal["stdio"] = "stdio"
+
+    command: str
+    args: list[str]
+    env: dict[str, str] | None = None
+
+
+class StreamableHTTPTransport(pydantic.BaseModel):
+    type: Literal["streamable_http"] = "streamable_http"
+
+    url: str
+    headers: dict[str, str] | None = None
+
+
+MCPTransport = Annotated[StdioTransport | StreamableHTTPTransport, pydantic.Field(discriminator="type")]
+
+
+class MCPFulfillment(pydantic.BaseModel):
+    transport: MCPTransport
+
+
+class MCPDemand(pydantic.BaseModel):
+    description: str | None = None
+    """
+    Short description of how the server will be used, what tools should it contain, etc.
+    """
+
+    suggested: tuple[str, ...] = ()
+    """
+    Identifiers of servers recommended to be used. Usually corresponds to MCP StreamableHTTP URIs.
+    """
+
+    allowed_transports: list[_TRANSPORT_TYPES] = pydantic.Field(default_factory=lambda: _DEFAULT_ALLOWED_TRANSPORTS)
+    """
+    Transports allowed for the server. Specifying other transports will result in rejection.
+    """
+
+
+class MCPServiceExtensionParams(pydantic.BaseModel):
+    mcp_demands: dict[str, MCPDemand]
+    """Server requests that the agent requires to be provided by the client."""
+
+
+class MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams]):
+    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/mcp/v1"
+
+    @classmethod
+    def single_demand(
+        cls,
+        name: str = _DEFAULT_DEMAND_NAME,
+        description: str | None = None,
+        suggested: tuple[str, ...] = (),
+        allowed_transports: list[_TRANSPORT_TYPES] | None = None,
+    ) -> Self:
+        return cls(
+            params=MCPServiceExtensionParams(
+                mcp_demands={
+                    name: MCPDemand(
+                        description=description,
+                        suggested=suggested,
+                        allowed_transports=allowed_transports or _DEFAULT_ALLOWED_TRANSPORTS,
+                    )
+                }
+            )
+        )
+
+
+class MCPServiceExtensionMetadata(pydantic.BaseModel):
+    mcp_fulfillments: dict[str, MCPFulfillment] = {}
+    """Provided servers corresponding to the server requests."""
+
+
+class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]):
+    @override
+    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
+        super().handle_incoming_message(message, run_context, request_context)
+        if not self.data:
+            return
+
+        platform_url = str(get_platform_client().base_url)
+        for fullfilment in self.data.mcp_fulfillments.values():
+            if fullfilment.transport.type == "streamable_http":
+                try:
+                    fullfilment.transport.url = re.sub("^{platform_url}", platform_url, str(fullfilment.transport.url))
+                except Exception:
+                    logger.warning("Platform URL substitution failed", exc_info=True)
+
+    @override
+    def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetadata | None:
+        metadata = super().parse_client_metadata(message)
+        if metadata:
+            for name, demand in self.spec.params.mcp_demands.items():
+                if not (fulfillment := metadata.mcp_fulfillments.get(name)):
+                    continue
+                if fulfillment.transport.type not in demand.allowed_transports:
+                    raise ValueError(f'Transport "{fulfillment.transport.type}" not allowed for demand "{name}"')
+        return metadata
+
+    def _get_oauth_server(self):
+        for dependency in self._dependencies.values():
+            if isinstance(dependency, OAuthExtensionServer):
+                return dependency
+        return None
+
+    def _get_platform_server(self):
+        for dependency in self._dependencies.values():
+            if isinstance(dependency, PlatformApiExtensionServer):
+                return dependency
+        return None
+
+    @asynccontextmanager
+    async def create_client(self, demand: str = _DEFAULT_DEMAND_NAME):
+        fulfillment = self.data.mcp_fulfillments.get(demand) if self.data else None
+
+        if not fulfillment:
+            yield None
+            return
+
+        transport = fulfillment.transport
+
+        if isinstance(transport, StdioTransport):
+            async with stdio_client(
+                server=StdioServerParameters(command=transport.command, args=transport.args, env=transport.env)
+            ) as (
+                read,
+                write,
+            ):
+                yield (read, write)
+        elif isinstance(transport, StreamableHTTPTransport):
+            async with streamablehttp_client(
+                url=transport.url,
+                headers=transport.headers,
+                auth=await self._create_auth(transport),
+            ) as (
+                read,
+                write,
+                _,
+            ):
+                yield (read, write)
+        else:
+            raise NotImplementedError("Unsupported transport")

+    async def _create_auth(self, transport: StreamableHTTPTransport):
+        platform = self._get_platform_server()
+        if (
+            platform
+            and platform.data
+            and platform.data.base_url
+            and transport.url.startswith(str(platform.data.base_url))
+        ):
+            return await platform.create_httpx_auth()
+        oauth = self._get_oauth_server()
+        if oauth:
+            return await oauth.create_httpx_auth(resource_url=pydantic.AnyUrl(transport.url))
+        return None
+
+
+class MCPServiceExtensionClient(BaseExtensionClient[MCPServiceExtensionSpec, NoneType]):
+    def fulfillment_metadata(self, *, mcp_fulfillments: dict[str, MCPFulfillment]) -> dict[str, Any]:
+        return {self.spec.URI: MCPServiceExtensionMetadata(mcp_fulfillments=mcp_fulfillments).model_dump(mode="json")}
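The create_client context manager yields a (read, write) stream pair, or None when the demand was not fulfilled. A sketch of plugging those streams into the MCP ClientSession follows, assuming an already-initialized MCPServiceExtensionServer; the demand name defaults to "default" as in the code above.

from mcp import ClientSession

from agentstack_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer


async def list_remote_tools(mcp: MCPServiceExtensionServer) -> list[str]:
    # Sketch: open the transport negotiated for the default demand and list its tools.
    async with mcp.create_client() as streams:
        if streams is None:
            return []  # the client did not fulfil the demand
        read, write = streams
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            return [tool.name for tool in tools.tools]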