agentstack-sdk 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. agentstack_sdk/__init__.py +6 -0
  2. agentstack_sdk/a2a/__init__.py +2 -0
  3. agentstack_sdk/a2a/extensions/__init__.py +8 -0
  4. agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
  5. agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
  6. agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
  7. agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
  8. agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
  9. agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
  10. agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
  11. agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
  12. agentstack_sdk/a2a/extensions/base.py +205 -0
  13. agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
  14. agentstack_sdk/a2a/extensions/common/form.py +149 -0
  15. agentstack_sdk/a2a/extensions/exceptions.py +11 -0
  16. agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
  17. agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
  18. agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
  19. agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
  20. agentstack_sdk/a2a/extensions/services/form.py +54 -0
  21. agentstack_sdk/a2a/extensions/services/llm.py +100 -0
  22. agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
  23. agentstack_sdk/a2a/extensions/services/platform.py +141 -0
  24. agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
  25. agentstack_sdk/a2a/extensions/tools/call.py +114 -0
  26. agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
  27. agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
  28. agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
  29. agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
  30. agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
  31. agentstack_sdk/a2a/extensions/ui/error.py +223 -0
  32. agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
  33. agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
  34. agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
  35. agentstack_sdk/a2a/types.py +104 -0
  36. agentstack_sdk/platform/__init__.py +12 -0
  37. agentstack_sdk/platform/client.py +123 -0
  38. agentstack_sdk/platform/common.py +37 -0
  39. agentstack_sdk/platform/configuration.py +47 -0
  40. agentstack_sdk/platform/context.py +291 -0
  41. agentstack_sdk/platform/file.py +295 -0
  42. agentstack_sdk/platform/model_provider.py +131 -0
  43. agentstack_sdk/platform/provider.py +219 -0
  44. agentstack_sdk/platform/provider_build.py +190 -0
  45. agentstack_sdk/platform/types.py +45 -0
  46. agentstack_sdk/platform/user.py +70 -0
  47. agentstack_sdk/platform/user_feedback.py +42 -0
  48. agentstack_sdk/platform/variables.py +44 -0
  49. agentstack_sdk/platform/vector_store.py +217 -0
  50. agentstack_sdk/py.typed +0 -0
  51. agentstack_sdk/server/__init__.py +4 -0
  52. agentstack_sdk/server/agent.py +594 -0
  53. agentstack_sdk/server/app.py +87 -0
  54. agentstack_sdk/server/constants.py +9 -0
  55. agentstack_sdk/server/context.py +68 -0
  56. agentstack_sdk/server/dependencies.py +117 -0
  57. agentstack_sdk/server/exceptions.py +3 -0
  58. agentstack_sdk/server/middleware/__init__.py +3 -0
  59. agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
  60. agentstack_sdk/server/server.py +376 -0
  61. agentstack_sdk/server/store/__init__.py +3 -0
  62. agentstack_sdk/server/store/context_store.py +35 -0
  63. agentstack_sdk/server/store/memory_context_store.py +59 -0
  64. agentstack_sdk/server/store/platform_context_store.py +58 -0
  65. agentstack_sdk/server/telemetry.py +53 -0
  66. agentstack_sdk/server/utils.py +26 -0
  67. agentstack_sdk/types.py +15 -0
  68. agentstack_sdk/util/__init__.py +4 -0
  69. agentstack_sdk/util/file.py +260 -0
  70. agentstack_sdk/util/httpx.py +18 -0
  71. agentstack_sdk/util/logging.py +63 -0
  72. agentstack_sdk/util/resource_context.py +44 -0
  73. agentstack_sdk/util/utils.py +47 -0
  74. agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
  75. agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
  76. agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
@@ -0,0 +1,11 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionSpec
5
+
6
+
7
class ExtensionError(Exception):
    """Error raised on behalf of an A2A extension.

    The message is prefixed with the extension's ``URI``, and the spec itself is kept
    on ``extension`` so handlers can tell which extension failed.
    """

    extension: "BaseExtensionSpec"

    def __init__(self, spec: "BaseExtensionSpec", message: str):
        """
        Args:
            spec: Spec of the extension the error belongs to; its ``URI`` is used in the message.
            message: Human-readable description of what went wrong.
        """
        super().__init__(f"Exception in extension '{spec.URI}': \n{message}")
        # The class declares ``extension``; store it so ``err.extension`` is actually readable.
        self.extension = spec
@@ -0,0 +1,4 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .approval import *
@@ -0,0 +1,125 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import uuid
7
+ from types import NoneType
8
+ from typing import TYPE_CHECKING, Annotated, Any, Literal
9
+
10
+ import a2a.types
11
+ from mcp import Implementation, Tool
12
+ from pydantic import BaseModel, Discriminator, Field, TypeAdapter
13
+
14
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
15
+ from agentstack_sdk.a2a.types import AgentMessage, InputRequired
16
+
17
+ if TYPE_CHECKING:
18
+ from agentstack_sdk.server.context import RunContext
19
+
20
+
21
class ApprovalRejectionError(RuntimeError):
    """Raised when the client answers an approval request with a rejection."""
23
+
24
+
25
class GenericApprovalRequest(BaseModel):
    """Approval request for an arbitrary action (anything that is not a tool call)."""

    # Discriminator value for the ``ApprovalRequest`` union.
    action: Literal["generic"] = "generic"

    title: str | None = Field(
        default=None,
        description="A human-readable title for the action being approved.",
    )
    description: str | None = Field(
        default=None,
        description="A human-readable description of the action being approved.",
    )
30
+
31
+
32
class ToolCallServer(BaseModel):
    """Identity of the server that executes a tool call (mirrors MCP ``Implementation``)."""

    name: str = Field(description="The programmatic name of the server.")
    # Optional with a ``None`` default, like every other optional title in this module.
    # Without the default the field was required, so payloads omitting it failed validation.
    title: str | None = Field(None, description="A human-readable title for the server.")
    version: str = Field(description="The version of the server.")
36
+
37
+
38
class ToolCallApprovalRequest(BaseModel):
    """Approval request for invoking a specific tool."""

    # Discriminator value for the ``ApprovalRequest`` union.
    action: Literal["tool-call"] = "tool-call"

    title: str | None = Field(None, description="A human-readable title for the tool call being approved.")
    description: str | None = Field(None, description="A human-readable description of the tool call being approved.")
    name: str = Field(description="The programmatic name of the tool.")
    input: dict[str, Any] | None = Field(description="The input for the tool.")
    server: ToolCallServer | None = Field(None, description="The server executing the tool.")

    @staticmethod
    def from_mcp_tool(
        tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None
    ) -> ToolCallApprovalRequest:
        """Build a request from an MCP tool definition.

        Args:
            tool: The MCP tool about to be called.
            input: Arguments the tool will be called with.
            server: The MCP server implementation exposing the tool, if known.

        Returns:
            A ``ToolCallApprovalRequest`` describing the pending call.
        """
        # MCP display-name precedence: the tool's own ``title`` first, then ``annotations.title``.
        # ``getattr`` keeps this working with MCP SDKs whose ``Tool`` predates the ``title`` field.
        title = getattr(tool, "title", None) or (tool.annotations.title if tool.annotations else None)
        return ToolCallApprovalRequest(
            name=tool.name,
            title=title,
            description=tool.description,
            input=input,
            server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None,
        )
58
+
59
+
60
# Every concrete approval request model; pydantic selects one by its ``action`` literal.
_ApprovalRequestUnion = GenericApprovalRequest | ToolCallApprovalRequest
ApprovalRequest = Annotated[_ApprovalRequestUnion, Discriminator("action")]
61
+
62
+
63
class ApprovalResponse(BaseModel):
    """The client's decision on an approval request."""

    decision: Literal["approve", "reject"]

    @property
    def approved(self) -> bool:
        """``True`` only when the decision is ``"approve"``."""
        return self.decision == "approve"

    def raise_on_rejection(self) -> None:
        """Raise if the request was rejected; return silently otherwise.

        Raises:
            ApprovalRejectionError: when ``decision`` is ``"reject"``.
        """
        if self.decision != "reject":
            return
        raise ApprovalRejectionError("Approval request has been rejected")
73
+
74
+
75
class ApprovalExtensionParams(BaseModel):
    """Parameters of the approval extension spec (it defines none)."""


class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]):
    """Spec identifying the approval interaction extension by its URI."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1"


class ApprovalExtensionMetadata(BaseModel):
    """Metadata payload of the approval extension (it carries no fields)."""
85
+
86
+
87
class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]):
    """Agent-side helper: asks the client to approve an action and reads back the decision."""

    def create_request_message(self, *, request: ApprovalRequest):
        """Wrap ``request`` in an agent message, stored under this extension's URI."""
        payload = request.model_dump(mode="json")
        return AgentMessage(text="Approval requested", metadata={self.spec.URI: payload})

    def parse_response(self, *, message: a2a.types.Message):
        """Read the client's ``ApprovalResponse`` out of ``message`` metadata.

        Raises:
            ValueError: if there is no (truthy) entry for this extension's URI.
        """
        data = message.metadata.get(self.spec.URI) if message.metadata else None
        if not data:
            raise ValueError("Approval response data is missing")
        return ApprovalResponse.model_validate(data)

    async def request_approval(
        self,
        request: ApprovalRequest,
        *,
        context: RunContext,
    ) -> ApprovalResponse:
        """Suspend the run with an input-required approval request and parse the reply.

        Raises:
            RuntimeError: if yielding produced no reply message.
            ValueError: if the reply lacks approval data.
        """
        outgoing = InputRequired(message=self.create_request_message(request=request))
        reply = await context.yield_async(outgoing)
        if not reply:
            raise RuntimeError("Yield did not return a message")
        return self.parse_response(message=reply)
107
+
108
+
109
class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]):
    """Client-side helper: reads approval requests and sends back decisions."""

    def create_response_message(self, *, response: ApprovalResponse, task_id: str | None):
        """Build an empty-bodied user message carrying ``response`` for ``task_id``."""
        payload = {self.spec.URI: response.model_dump(mode="json")}
        return a2a.types.Message(
            message_id=str(uuid.uuid4()),
            role=a2a.types.Role.user,
            parts=[],
            task_id=task_id,
            metadata=payload,
        )

    def parse_request(self, *, message: a2a.types.Message):
        """Decode the approval request stored in ``message`` metadata.

        Raises:
            ValueError: if there is no (truthy) entry for this extension's URI.
        """
        data = message.metadata.get(self.spec.URI) if message.metadata else None
        if not data:
            raise ValueError("Approval request data is missing")
        return TypeAdapter(ApprovalRequest).validate_python(data)

    def metadata(self) -> dict[str, Any]:
        """Metadata announcing this extension, keyed by its URI."""
        empty = ApprovalExtensionMetadata()
        return {self.spec.URI: empty.model_dump(mode="json")}
@@ -0,0 +1,8 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .embedding import *
5
+ from .form import *
6
+ from .llm import *
7
+ from .mcp import *
8
+ from .platform import *
@@ -0,0 +1,106 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import re
7
+ from types import NoneType
8
+ from typing import TYPE_CHECKING, Any, Self
9
+
10
+ import pydantic
11
+ from a2a.server.agent_execution.context import RequestContext
12
+ from a2a.types import Message as A2AMessage
13
+ from typing_extensions import override
14
+
15
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
16
+
17
+ if TYPE_CHECKING:
18
+ from agentstack_sdk.server.context import RunContext
19
+
20
+
21
class EmbeddingFulfillment(pydantic.BaseModel):
    """Connection details for an embedding model supplied by the client."""

    identifier: str | None = None
    """
    Name of the model for identification and optimization purposes. Usually corresponds to LiteLLM identifiers.
    Should be the name of the provider slash name of the model as it appears in the API.
    Examples: openai/text-embedding-3-small, vertex_ai/textembedding-gecko, ollama/nomic-embed-text:latest
    """

    api_base: str
    """
    Base URL for an OpenAI-compatible API. It should provide at least /v1/embeddings
    """

    api_key: str
    """
    API key to attach as a `Authorization: Bearer $api_key` header.
    """

    api_model: str
    """
    Model name to use with the /v1/embeddings API.
    """
43
+
44
+
45
class EmbeddingDemand(pydantic.BaseModel):
    """A request from the agent for the client to provide an embedding model."""

    description: str | None = None
    """
    What this model will be used for; useful when several models are requested.
    Meant to be displayed in the UI next to a model picker dropdown.
    """

    suggested: tuple[str, ...] = ()
    """
    Recommended model identifiers, usually LiteLLM-style: provider name, a slash, then the
    model name as it appears in the API.
    Examples: openai/text-embedding-3-small, vertex_ai/textembedding-gecko, ollama/nomic-embed-text:latest
    """
58
+
59
+
60
class EmbeddingServiceExtensionParams(pydantic.BaseModel):
    embedding_demands: dict[str, EmbeddingDemand]
    """Embedding models the agent needs the client to provide, keyed by demand name."""


class EmbeddingServiceExtensionSpec(BaseExtensionSpec[EmbeddingServiceExtensionParams]):
    """Spec of the embedding service extension."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/embedding/v1"

    @classmethod
    def single_demand(
        cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = ()
    ) -> Self:
        """Create a spec with exactly one demand, named ``"default"`` when ``name`` is falsy."""
        demand_name = name or "default"
        demand = EmbeddingDemand(description=description, suggested=suggested)
        return cls(params=EmbeddingServiceExtensionParams(embedding_demands={demand_name: demand}))


class EmbeddingServiceExtensionMetadata(pydantic.BaseModel):
    embedding_fulfillments: dict[str, EmbeddingFulfillment] = {}
    """Models supplied by the client, keyed by the demand name they fulfill."""
82
+
83
+
84
class EmbeddingServiceExtensionServer(
    BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata]
):
    """Server side of the embedding service extension."""

    @override
    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        """Parse the client's fulfillments and expand ``{platform_url}`` in each ``api_base``.

        The platform client is only looked up, once, if some ``api_base`` contains the placeholder.
        """
        from agentstack_sdk.platform import get_platform_client

        super().handle_incoming_message(message, run_context, request_context)
        if not self.data:
            return

        placeholder = "{platform_url}"
        platform_url: str | None = None
        for fulfillment in self.data.embedding_fulfillments.values():
            if placeholder not in fulfillment.api_base:
                continue
            if platform_url is None:
                platform_url = str(get_platform_client().base_url)
            # Literal replacement: ``re.sub`` would treat ``platform_url`` as a replacement
            # template and process any backslash escapes in it.
            fulfillment.api_base = fulfillment.api_base.replace(placeholder, platform_url)
98
+
99
+
100
class EmbeddingServiceExtensionClient(BaseExtensionClient[EmbeddingServiceExtensionSpec, NoneType]):
    """Client side of the embedding service extension."""

    def fulfillment_metadata(self, *, embedding_fulfillments: dict[str, EmbeddingFulfillment]) -> dict[str, Any]:
        """Serialize ``embedding_fulfillments`` into message metadata keyed by this extension's URI."""
        metadata = EmbeddingServiceExtensionMetadata(embedding_fulfillments=embedding_fulfillments)
        return {self.spec.URI: metadata.model_dump(mode="json")}
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Self, TypeVar, cast
8
+
9
+ from pydantic import BaseModel, TypeAdapter
10
+ from typing_extensions import TypedDict
11
+
12
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
13
+ from agentstack_sdk.a2a.extensions.common.form import FormRender, FormResponse
14
+
15
+
16
class FormDemands(TypedDict):
    """Forms the agent asks the client to render."""

    # Form to show up front; may be ``None``.
    initial_form: FormRender | None
    # TODO: We can put settings here too
19
+
20
+
21
class FormServiceExtensionMetadata(BaseModel):
    """Form responses supplied by the client."""

    # Keyed by demand name, e.g. "initial_form".
    form_fulfillments: dict[str, FormResponse] = {}


class FormServiceExtensionParams(BaseModel):
    """Forms demanded by the agent."""

    form_demands: FormDemands


class FormServiceExtensionSpec(BaseExtensionSpec[FormServiceExtensionParams]):
    """Spec of the form service extension."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/form/v1"

    @classmethod
    def demand(cls, initial_form: FormRender | None) -> Self:
        """Create a spec demanding ``initial_form`` as the initial form."""
        demands: FormDemands = {"initial_form": initial_form}
        return cls(params=FormServiceExtensionParams(form_demands=demands))
35
+
36
+
37
T = TypeVar("T")


class FormServiceExtensionServer(BaseExtensionServer[FormServiceExtensionSpec, FormServiceExtensionMetadata]):
    """Server side of the form service extension."""

    def parse_initial_form(self, *, model: type[T] = FormResponse) -> T | None:
        """Return the client's ``initial_form`` response, optionally validated into ``model``.

        Args:
            model: Target type. ``FormResponse`` (the default) returns the stored response as-is;
                any other type is validated from ``dict(response)`` via a pydantic ``TypeAdapter``.

        Returns:
            The parsed form, or ``None`` when no data or no ``initial_form`` fulfillment was received.
        """
        if self.data is None or (response := self.data.form_fulfillments.get("initial_form")) is None:
            return None
        if model is not FormResponse:
            return TypeAdapter(model).validate_python(dict(response))
        return cast(T, response)
52
+
53
+
54
class FormServiceExtensionClient(BaseExtensionClient[FormServiceExtensionSpec, FormRender]):
    """Client side of the form service extension; adds nothing beyond the base client."""
@@ -0,0 +1,100 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import re
7
+ from types import NoneType
8
+ from typing import TYPE_CHECKING, Any, Self
9
+
10
+ import pydantic
11
+ from a2a.server.agent_execution.context import RequestContext
12
+ from a2a.types import Message as A2AMessage
13
+ from typing_extensions import override
14
+
15
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
16
+
17
+ if TYPE_CHECKING:
18
+ from agentstack_sdk.server.context import RunContext
19
+
20
+
21
class LLMFulfillment(pydantic.BaseModel):
    """Connection details for a chat model supplied by the client."""

    identifier: str | None = None
    """
    Optional model name used for identification and optimization, usually a LiteLLM
    identifier: provider name, a slash, then the model name as it appears in the API.
    Examples: openai/gpt-4o, watsonx/ibm/granite-13b-chat-v2, ollama/mistral-small:22b
    """

    api_base: str
    """
    Base URL of an OpenAI-compatible API, which should serve at least /v1/chat/completions
    """

    api_key: str
    """
    API key sent in an `Authorization: Bearer $api_key` header.
    """

    api_model: str
    """
    Model name passed to the /v1/chat/completions API.
    """
43
+
44
+
45
class LLMDemand(pydantic.BaseModel):
    """A request from the agent for the client to provide a chat model."""

    description: str | None = None
    """
    What this model will be used for; useful when several models are requested.
    Meant to be displayed in the UI next to a model picker dropdown.
    """

    suggested: tuple[str, ...] = ()
    """
    Recommended model identifiers, usually LiteLLM-style: provider name, a slash, then the
    model name as it appears in the API.
    Examples: openai/gpt-4o, watsonx/ibm/granite-13b-chat-v2, ollama/mistral-small:22b
    """
58
+
59
+
60
class LLMServiceExtensionParams(pydantic.BaseModel):
    llm_demands: dict[str, LLMDemand]
    """Chat models the agent needs the client to provide, keyed by demand name."""


class LLMServiceExtensionSpec(BaseExtensionSpec[LLMServiceExtensionParams]):
    """Spec of the LLM service extension."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/llm/v1"

    @classmethod
    def single_demand(
        cls, name: str | None = None, description: str | None = None, suggested: tuple[str, ...] = ()
    ) -> Self:
        """Create a spec with exactly one demand, named ``"default"`` when ``name`` is falsy."""
        demand_name = name or "default"
        demand = LLMDemand(description=description, suggested=suggested)
        return cls(params=LLMServiceExtensionParams(llm_demands={demand_name: demand}))


class LLMServiceExtensionMetadata(pydantic.BaseModel):
    llm_fulfillments: dict[str, LLMFulfillment] = {}
    """Models supplied by the client, keyed by the demand name they fulfill."""
82
+
83
+
84
class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]):
    """Server side of the LLM service extension."""

    @override
    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        """Parse the client's fulfillments and expand ``{platform_url}`` in each ``api_base``.

        The platform client is only looked up, once, if some ``api_base`` contains the placeholder.
        """
        from agentstack_sdk.platform import get_platform_client

        super().handle_incoming_message(message, run_context, request_context)
        if not self.data:
            return

        placeholder = "{platform_url}"
        platform_url: str | None = None
        for fulfillment in self.data.llm_fulfillments.values():
            if placeholder not in fulfillment.api_base:
                continue
            if platform_url is None:
                platform_url = str(get_platform_client().base_url)
            # Literal replacement: ``re.sub`` would treat ``platform_url`` as a replacement
            # template and process any backslash escapes in it.
            fulfillment.api_base = fulfillment.api_base.replace(placeholder, platform_url)
96
+
97
+
98
class LLMServiceExtensionClient(BaseExtensionClient[LLMServiceExtensionSpec, NoneType]):
    """Client side of the LLM service extension."""

    def fulfillment_metadata(self, *, llm_fulfillments: dict[str, LLMFulfillment]) -> dict[str, Any]:
        """Serialize ``llm_fulfillments`` into message metadata keyed by this extension's URI."""
        metadata = LLMServiceExtensionMetadata(llm_fulfillments=llm_fulfillments)
        return {self.spec.URI: metadata.model_dump(mode="json")}
@@ -0,0 +1,193 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import re
7
+ from contextlib import asynccontextmanager
8
+ from types import NoneType
9
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, Self
10
+
11
+ import pydantic
12
+ from a2a.server.agent_execution.context import RequestContext
13
+ from a2a.types import Message as A2AMessage
14
+ from mcp.client.stdio import StdioServerParameters, stdio_client
15
+ from mcp.client.streamable_http import streamablehttp_client
16
+ from typing_extensions import override
17
+
18
+ from agentstack_sdk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer
19
+ from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
20
+ from agentstack_sdk.a2a.extensions.services.platform import PlatformApiExtensionServer
21
+ from agentstack_sdk.platform.client import get_platform_client
22
+ from agentstack_sdk.util.logging import logger
23
+
24
+ if TYPE_CHECKING:
25
+ from agentstack_sdk.server.context import RunContext
26
+
27
# Demand name used when the agent asks for a single MCP server.
_DEFAULT_DEMAND_NAME = "default"

# Transport kinds an MCP fulfillment may use.
_TRANSPORT_TYPES = Literal["streamable_http", "stdio"]

# Allowed transports when a demand does not specify any: streamable HTTP only.
_DEFAULT_ALLOWED_TRANSPORTS: list[_TRANSPORT_TYPES] = ["streamable_http"]
31
+
32
+
33
class StdioTransport(pydantic.BaseModel):
    """MCP server started from ``command``/``args``/``env`` and reached over stdio."""

    type: Literal["stdio"] = "stdio"

    command: str
    args: list[str]
    env: dict[str, str] | None = None


class StreamableHTTPTransport(pydantic.BaseModel):
    """MCP server reached over the streamable HTTP transport at ``url``."""

    type: Literal["streamable_http"] = "streamable_http"

    url: str
    headers: dict[str, str] | None = None
46
+
47
+
48
# All supported transports; pydantic picks the concrete model by the ``type`` field.
_MCPTransportUnion = StdioTransport | StreamableHTTPTransport
MCPTransport = Annotated[_MCPTransportUnion, pydantic.Field(discriminator="type")]


class MCPFulfillment(pydantic.BaseModel):
    """An MCP server supplied by the client for one demand."""

    transport: MCPTransport
53
+
54
+
55
class MCPDemand(pydantic.BaseModel):
    """A request from the agent for the client to provide an MCP server."""

    description: str | None = None
    """
    Short description of how the server will be used, what tools should it contain, etc.
    """

    suggested: tuple[str, ...] = ()
    """
    Identifiers of servers recommended to be used. Usually corresponds to MCP StreamableHTTP URIs.
    """

    # pydantic uses a ``default_factory`` result as-is, so return a fresh copy: handing back the
    # module-level list would make every default demand share (and be able to mutate) it.
    allowed_transports: list[_TRANSPORT_TYPES] = pydantic.Field(
        default_factory=lambda: list(_DEFAULT_ALLOWED_TRANSPORTS)
    )
    """
    Transports allowed for the server. Specifying other transports will result in rejection.
    """
70
+
71
+
72
class MCPServiceExtensionParams(pydantic.BaseModel):
    mcp_demands: dict[str, MCPDemand]
    """MCP servers the agent needs the client to provide, keyed by demand name."""


class MCPServiceExtensionSpec(BaseExtensionSpec[MCPServiceExtensionParams]):
    """Spec of the MCP service extension."""

    URI: str = "https://a2a-extensions.agentstack.beeai.dev/services/mcp/v1"

    @classmethod
    def single_demand(
        cls,
        name: str = _DEFAULT_DEMAND_NAME,
        description: str | None = None,
        suggested: tuple[str, ...] = (),
        allowed_transports: list[_TRANSPORT_TYPES] | None = None,
    ) -> Self:
        """Create a spec with exactly one MCP demand named ``name``.

        A falsy ``allowed_transports`` (``None`` or empty) falls back to the default transports.
        """
        transports = allowed_transports or _DEFAULT_ALLOWED_TRANSPORTS
        demand = MCPDemand(description=description, suggested=suggested, allowed_transports=transports)
        return cls(params=MCPServiceExtensionParams(mcp_demands={name: demand}))


class MCPServiceExtensionMetadata(pydantic.BaseModel):
    mcp_fulfillments: dict[str, MCPFulfillment] = {}
    """Servers supplied by the client, keyed by the demand name they fulfill."""
104
+
105
+
106
def _url_has_prefix(url: str, prefix: str) -> bool:
    """Return True when ``url`` starts with ``prefix`` at a URL component boundary.

    A bare ``str.startswith`` accepts ``http://localhost:8333.evil.example/...`` for the prefix
    ``http://localhost:8333``. So the match must end at the prefix's own trailing ``/``, at the
    end of ``url``, or right before a ``/``, ``?`` or ``#``.
    """
    if not url.startswith(prefix):
        return False
    if prefix.endswith("/") or len(url) == len(prefix):
        return True
    return url[len(prefix)] in "/?#"


class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]):
    """Server side of the MCP service extension.

    Validates client fulfillments against the declared demands, expands a leading
    ``{platform_url}`` in streamable-HTTP URLs, and opens MCP client transports.
    """

    @override
    def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
        """Parse fulfillments, then expand a leading ``{platform_url}`` in streamable-HTTP URLs.

        The platform client is only looked up, once, if some URL actually starts with the placeholder.
        """
        super().handle_incoming_message(message, run_context, request_context)
        if not self.data:
            return

        placeholder = "{platform_url}"
        platform_url: str | None = None
        for name, fulfillment in self.data.mcp_fulfillments.items():
            transport = fulfillment.transport
            if transport.type != "streamable_http" or not transport.url.startswith(placeholder):
                continue
            if platform_url is None:
                platform_url = str(get_platform_client().base_url)
            # Plain prefix replacement; ``re.sub`` would treat ``platform_url`` as a replacement
            # template and process any backslash escapes in it.
            transport.url = platform_url + transport.url.removeprefix(placeholder)
            logger.debug("Expanded %s in MCP fulfillment %r", placeholder, name)

    @override
    def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetadata | None:
        """Parse client metadata and enforce each demand's ``allowed_transports``.

        Raises:
            ValueError: if a fulfillment uses a transport its demand does not allow.
        """
        metadata = super().parse_client_metadata(message)
        if metadata:
            for name, demand in self.spec.params.mcp_demands.items():
                if not (fulfillment := metadata.mcp_fulfillments.get(name)):
                    continue
                if fulfillment.transport.type not in demand.allowed_transports:
                    raise ValueError(f'Transport "{fulfillment.transport.type}" not allowed for demand "{name}"')
        return metadata

    def _get_oauth_server(self):
        """Return the first ``OAuthExtensionServer`` among the dependencies, or ``None``."""
        for dependency in self._dependencies.values():
            if isinstance(dependency, OAuthExtensionServer):
                return dependency
        return None

    def _get_platform_server(self):
        """Return the first ``PlatformApiExtensionServer`` among the dependencies, or ``None``."""
        for dependency in self._dependencies.values():
            if isinstance(dependency, PlatformApiExtensionServer):
                return dependency
        return None

    @asynccontextmanager
    async def create_client(self, demand: str = _DEFAULT_DEMAND_NAME):
        """Open the MCP transport fulfilling ``demand`` and yield its ``(read, write)`` streams.

        Yields ``None`` when the client supplied no fulfillment for ``demand``.

        Raises:
            NotImplementedError: if the fulfillment uses an unknown transport type.
        """
        fulfillment = self.data.mcp_fulfillments.get(demand) if self.data else None

        if not fulfillment:
            yield None
            return

        transport = fulfillment.transport

        if isinstance(transport, StdioTransport):
            async with stdio_client(
                server=StdioServerParameters(command=transport.command, args=transport.args, env=transport.env)
            ) as (
                read,
                write,
            ):
                yield (read, write)
        elif isinstance(transport, StreamableHTTPTransport):
            async with streamablehttp_client(
                url=transport.url,
                headers=transport.headers,
                auth=await self._create_auth(transport),
            ) as (
                read,
                write,
                _,
            ):
                yield (read, write)
        else:
            raise NotImplementedError("Unsupported transport")

    async def _create_auth(self, transport: StreamableHTTPTransport):
        """Choose httpx auth for ``transport``: platform auth for platform URLs, else OAuth, else none."""
        platform = self._get_platform_server()
        if (
            platform
            and platform.data
            and platform.data.base_url
            # Boundary-checked match so platform credentials never reach a look-alike host.
            and _url_has_prefix(transport.url, str(platform.data.base_url))
        ):
            return await platform.create_httpx_auth()
        oauth = self._get_oauth_server()
        if oauth:
            return await oauth.create_httpx_auth(resource_url=pydantic.AnyUrl(transport.url))
        return None
189
+
190
+
191
class MCPServiceExtensionClient(BaseExtensionClient[MCPServiceExtensionSpec, NoneType]):
    """Client side of the MCP service extension."""

    def fulfillment_metadata(self, *, mcp_fulfillments: dict[str, MCPFulfillment]) -> dict[str, Any]:
        """Serialize ``mcp_fulfillments`` into message metadata keyed by this extension's URI."""
        metadata = MCPServiceExtensionMetadata(mcp_fulfillments=mcp_fulfillments)
        return {self.spec.URI: metadata.model_dump(mode="json")}