traia-iatp 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of traia-iatp might be problematic. Click here for more details.
- traia_iatp/README.md +368 -0
- traia_iatp/__init__.py +54 -0
- traia_iatp/cli/__init__.py +5 -0
- traia_iatp/cli/main.py +483 -0
- traia_iatp/client/__init__.py +10 -0
- traia_iatp/client/a2a_client.py +274 -0
- traia_iatp/client/crewai_a2a_tools.py +335 -0
- traia_iatp/client/d402_a2a_client.py +293 -0
- traia_iatp/client/grpc_a2a_tools.py +349 -0
- traia_iatp/client/root_path_a2a_client.py +1 -0
- traia_iatp/contracts/__init__.py +12 -0
- traia_iatp/contracts/iatp_contracts_config.py +263 -0
- traia_iatp/contracts/wallet_creator.py +255 -0
- traia_iatp/core/__init__.py +43 -0
- traia_iatp/core/models.py +172 -0
- traia_iatp/d402/__init__.py +55 -0
- traia_iatp/d402/chains.py +102 -0
- traia_iatp/d402/client.py +150 -0
- traia_iatp/d402/clients/__init__.py +7 -0
- traia_iatp/d402/clients/base.py +218 -0
- traia_iatp/d402/clients/httpx.py +219 -0
- traia_iatp/d402/common.py +114 -0
- traia_iatp/d402/encoding.py +28 -0
- traia_iatp/d402/examples/client_example.py +197 -0
- traia_iatp/d402/examples/server_example.py +171 -0
- traia_iatp/d402/facilitator.py +453 -0
- traia_iatp/d402/fastapi_middleware/__init__.py +6 -0
- traia_iatp/d402/fastapi_middleware/middleware.py +225 -0
- traia_iatp/d402/fastmcp_middleware.py +147 -0
- traia_iatp/d402/mcp_middleware.py +434 -0
- traia_iatp/d402/middleware.py +193 -0
- traia_iatp/d402/models.py +116 -0
- traia_iatp/d402/networks.py +98 -0
- traia_iatp/d402/path.py +43 -0
- traia_iatp/d402/payment_introspection.py +104 -0
- traia_iatp/d402/payment_signing.py +178 -0
- traia_iatp/d402/paywall.py +119 -0
- traia_iatp/d402/starlette_middleware.py +326 -0
- traia_iatp/d402/template.py +1 -0
- traia_iatp/d402/types.py +300 -0
- traia_iatp/mcp/__init__.py +18 -0
- traia_iatp/mcp/client.py +201 -0
- traia_iatp/mcp/d402_mcp_tool_adapter.py +361 -0
- traia_iatp/mcp/mcp_agent_template.py +481 -0
- traia_iatp/mcp/templates/Dockerfile.j2 +80 -0
- traia_iatp/mcp/templates/README.md.j2 +310 -0
- traia_iatp/mcp/templates/cursor-rules.md.j2 +520 -0
- traia_iatp/mcp/templates/deployment_params.json.j2 +20 -0
- traia_iatp/mcp/templates/docker-compose.yml.j2 +32 -0
- traia_iatp/mcp/templates/dockerignore.j2 +47 -0
- traia_iatp/mcp/templates/env.example.j2 +57 -0
- traia_iatp/mcp/templates/gitignore.j2 +77 -0
- traia_iatp/mcp/templates/mcp_health_check.py.j2 +150 -0
- traia_iatp/mcp/templates/pyproject.toml.j2 +32 -0
- traia_iatp/mcp/templates/pyrightconfig.json.j2 +22 -0
- traia_iatp/mcp/templates/run_local_docker.sh.j2 +390 -0
- traia_iatp/mcp/templates/server.py.j2 +175 -0
- traia_iatp/mcp/traia_mcp_adapter.py +543 -0
- traia_iatp/preview_diagrams.html +181 -0
- traia_iatp/registry/__init__.py +26 -0
- traia_iatp/registry/atlas_search_indexes.json +280 -0
- traia_iatp/registry/embeddings.py +298 -0
- traia_iatp/registry/iatp_search_api.py +846 -0
- traia_iatp/registry/mongodb_registry.py +771 -0
- traia_iatp/registry/readmes/ATLAS_SEARCH_INDEXES.md +252 -0
- traia_iatp/registry/readmes/ATLAS_SEARCH_SETUP.md +134 -0
- traia_iatp/registry/readmes/AUTHENTICATION_UPDATE.md +124 -0
- traia_iatp/registry/readmes/EMBEDDINGS_SETUP.md +172 -0
- traia_iatp/registry/readmes/IATP_SEARCH_API_GUIDE.md +257 -0
- traia_iatp/registry/readmes/MONGODB_X509_AUTH.md +208 -0
- traia_iatp/registry/readmes/README.md +251 -0
- traia_iatp/registry/readmes/REFACTORING_SUMMARY.md +191 -0
- traia_iatp/scripts/__init__.py +2 -0
- traia_iatp/scripts/create_wallet.py +244 -0
- traia_iatp/server/__init__.py +15 -0
- traia_iatp/server/a2a_server.py +219 -0
- traia_iatp/server/example_template_usage.py +72 -0
- traia_iatp/server/iatp_server_agent_generator.py +237 -0
- traia_iatp/server/iatp_server_template_generator.py +235 -0
- traia_iatp/server/templates/.dockerignore.j2 +48 -0
- traia_iatp/server/templates/Dockerfile.j2 +49 -0
- traia_iatp/server/templates/README.md +137 -0
- traia_iatp/server/templates/README.md.j2 +425 -0
- traia_iatp/server/templates/__init__.py +1 -0
- traia_iatp/server/templates/__main__.py.j2 +565 -0
- traia_iatp/server/templates/agent.py.j2 +94 -0
- traia_iatp/server/templates/agent_config.json.j2 +22 -0
- traia_iatp/server/templates/agent_executor.py.j2 +279 -0
- traia_iatp/server/templates/docker-compose.yml.j2 +23 -0
- traia_iatp/server/templates/env.example.j2 +84 -0
- traia_iatp/server/templates/gitignore.j2 +78 -0
- traia_iatp/server/templates/grpc_server.py.j2 +218 -0
- traia_iatp/server/templates/pyproject.toml.j2 +78 -0
- traia_iatp/server/templates/run_local_docker.sh.j2 +103 -0
- traia_iatp/server/templates/server.py.j2 +243 -0
- traia_iatp/special_agencies/__init__.py +4 -0
- traia_iatp/special_agencies/registry_search_agency.py +392 -0
- traia_iatp/utils/__init__.py +10 -0
- traia_iatp/utils/docker_utils.py +251 -0
- traia_iatp/utils/general.py +64 -0
- traia_iatp/utils/iatp_utils.py +126 -0
- traia_iatp-0.1.29.dist-info/METADATA +423 -0
- traia_iatp-0.1.29.dist-info/RECORD +107 -0
- traia_iatp-0.1.29.dist-info/WHEEL +5 -0
- traia_iatp-0.1.29.dist-info/entry_points.txt +2 -0
- traia_iatp-0.1.29.dist-info/licenses/LICENSE +21 -0
- traia_iatp-0.1.29.dist-info/top_level.txt +1 -0
traia_iatp/d402/types.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, Optional, Union, Dict, Literal, List
|
|
6
|
+
from typing_extensions import (
|
|
7
|
+
TypedDict,
|
|
8
|
+
) # use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
11
|
+
from pydantic.alias_generators import to_camel
|
|
12
|
+
|
|
13
|
+
from .networks import SupportedNetworks
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Add HTTP request structure types
class HTTPVerbs(str, Enum):
    """HTTP methods accepted as ``HTTPRequestStructure.method``.

    Mixing in ``str`` makes each member compare equal to its plain string
    value (``HTTPVerbs.GET == "GET"``).
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    OPTIONS = "OPTIONS"
    HEAD = "HEAD"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class HTTPInputSchema(BaseModel):
    """Schema for HTTP request input, excluding spec and method which are handled by the middleware"""

    # Wire keys are camelCase (to_camel, e.g. "queryParams"); populate_by_name
    # also accepts the snake_case attribute names, and from_attributes allows
    # validating from arbitrary objects exposing these attributes.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Every part of the request input is optional and defaults to None.
    query_params: Optional[Dict[str, str]] = None
    # Presumably describes how ``body_fields`` are encoded -- confirm with the middleware.
    body_type: Optional[Literal["json", "form-data", "multipart-form-data", "text", "binary"]] = None
    body_fields: Optional[Dict[str, Any]] = None
    header_fields: Optional[Dict[str, Any]] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class HTTPRequestStructure(HTTPInputSchema):
    """Complete HTTP request structure including protocol type and method"""

    # Protocol discriminator: only the literal string "http" validates.
    type: Literal["http"]
    # One of the HTTPVerbs members (inherits the camelCase config of HTTPInputSchema).
    method: HTTPVerbs
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# For now we only support HTTP, but could add MCP and OpenAPI later.
# Until then this is a plain alias; it would presumably become a Union of
# request-structure models once other protocols are added.
RequestStructure = HTTPRequestStructure
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TokenAmount(BaseModel):
    """Represents an amount of tokens in atomic units with asset information"""

    # Kept as a string on the model; must be parsable by int() (checked below).
    amount: str
    # Forward reference: TokenAsset is declared later in this module.
    asset: TokenAsset

    @field_validator("amount")
    def validate_amount(cls, value):
        # Only the int() parse is checked; the original string is returned as-is.
        try:
            int(value)
        except ValueError:
            raise ValueError("amount must be an integer encoded as a string")
        else:
            return value
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TokenAsset(BaseModel):
    """Represents token asset information including EIP-712 domain data and network"""

    address: str
    decimals: int
    # Forward reference: EIP712Domain is declared later in this module.
    eip712: EIP712Domain
    network: Optional[str] = None  # Blockchain network (e.g., "sepolia", "base-sepolia")

    @field_validator("decimals")
    def validate_decimals(cls, value):
        # Inclusive range 0..255, i.e. the range of an unsigned 8-bit integer.
        if not 0 <= value <= 255:
            raise ValueError("decimals must be between 0 and 255")
        return value
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class EIP712Domain(BaseModel):
    """EIP-712 domain information for token signing"""

    # Only the domain's name and version are modelled here; other EIP-712
    # domain fields (chain id, verifying contract) are not part of this model.
    name: str
    version: str
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Price can be either Money (USD string or int) or an explicit TokenAmount.
# NOTE: only ``str`` and ``int`` are members of ``Money``; a float such as
# 0.01 is NOT, so fractional dollar amounts must be passed as strings.
# How an int is interpreted is decided by the price-processing code, not here.
Money = Union[str, int]  # e.g., "$0.01", "0.001"
Price = Union[Money, TokenAmount]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class PaymentRequirements(BaseModel):
    # Wire keys are camelCase via to_camel (e.g. "maxAmountRequired", "payTo");
    # populate_by_name also accepts the snake_case attribute names.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    scheme: str
    network: SupportedNetworks
    # Integer amount encoded as a string (validated below).
    max_amount_required: str
    resource: str
    description: str
    mime_type: str
    output_schema: Optional[Any] = None
    pay_to: str
    max_timeout_seconds: int
    asset: str
    extra: Optional[dict[str, Any]] = None

    @field_validator("max_amount_required")
    def validate_max_amount_required(cls, value):
        """Reject values ``int()`` cannot parse; return the string unchanged."""
        try:
            int(value)
        except ValueError:
            raise ValueError(
                "max_amount_required must be an integer encoded as a string"
            )
        else:
            return value
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Returned by a server as JSON alongside a 402 (Payment Required) response code.
class d402PaymentRequiredResponse(BaseModel):
    # camelCase wire keys via to_camel; snake_case names are accepted too.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    d402_version: int
    # Acceptable payment requirements (presumably the client picks one).
    accepts: list[PaymentRequirements]
    error: str
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class PullFundsAuthorization(BaseModel):
    """
    Authorization data for payment header (wire format).

    This structure is sent in the payment header and includes fields for:
    - EIP-712 signature: wallet, provider, token, amount, deadline, requestPath
    - Transport metadata: valid_after, valid_before (for payment window)

    Note: Only some fields are signed (see IATPWallet.sol PULL_FUNDS_FOR_SETTLEMENT_TYPEHASH)
    """

    # Explicit aliases below take precedence over the to_camel generator;
    # populate_by_name lets callers also use the Python attribute names.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # "from" is a Python keyword, hence the trailing underscore + alias.
    from_: str = Field(alias="from")  # Consumer's IATPWallet address
    to: str  # Provider's IATPWallet address
    value: str  # Payment amount
    valid_after: str = Field(alias="validAfter")  # Not in signature (transport only)
    valid_before: str = Field(alias="validBefore")  # Maps to 'deadline' in signature
    request_path: str = Field(alias="requestPath")  # API path (signed)

    @field_validator("value")
    def validate_value(cls, amount):
        """Reject values ``int()`` cannot parse; return the string unchanged."""
        try:
            int(amount)
        except ValueError:
            raise ValueError("value must be an integer encoded as a string")
        else:
            return amount
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class ExactPaymentPayload(BaseModel):
    """Payment payload with PullFundsForSettlement signature."""
    # Signature as a string; its encoding is not fixed by this model
    # (presumably a hex EIP-712 signature -- confirm with the signing code).
    signature: str
    # Authorization data sent with the signature; per PullFundsAuthorization's
    # docstring only some of its fields are covered by the signature.
    authorization: PullFundsAuthorization
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class VerifyResponse(BaseModel):
    # Result of a facilitator verification call. Wire keys are camelCase.
    is_valid: bool = Field(alias="isValid")
    invalid_reason: Optional[str] = Field(None, alias="invalidReason")
    # Defaults to None: in pydantic v2 ``Optional[str]`` without a default is
    # still a *required* field, so a response omitting "payer" (e.g. for an
    # invalid payment) would otherwise fail validation.
    payer: Optional[str] = None
    payment_uuid: Optional[str] = Field(None, alias="paymentUuid")  # Unique payment identifier from facilitator
    facilitator_fee_percent: Optional[int] = Field(250, alias="facilitatorFeePercent")  # Fee percent from facilitator (default 2.5% = 250 basis points)

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class SettleResponse(BaseModel):
    # camelCase wire keys via to_camel (e.g. "errorReason"); snake_case names
    # are accepted on input as well.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    success: bool
    # Every field other than ``success`` is optional and defaults to None.
    error_reason: Optional[str] = None
    transaction: Optional[str] = None
    network: Optional[str] = None
    payer: Optional[str] = None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# Union of payloads for each scheme. Only one payload type exists today, so
# this is a plain alias rather than a typing.Union.
SchemePayloads = ExactPaymentPayload
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class PaymentPayload(BaseModel):
    # camelCase wire keys via to_camel; snake_case names are accepted too.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    d402_version: int
    scheme: str
    # A plain string here, unlike PaymentRequirements.network (SupportedNetworks).
    network: str
    # Scheme-specific body; SchemePayloads currently aliases ExactPaymentPayload.
    payload: SchemePayloads
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class D402Headers(BaseModel):
    # NOTE(review): no alias/model_config is set, so this validates and
    # serialises under the key "x_payment" (not an "X-PAYMENT"-style header
    # name) -- confirm callers map header names themselves.
    x_payment: str
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class UnsupportedSchemeException(Exception):
    """Exception signalling an unsupported payment scheme."""
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class PaywallConfig(TypedDict, total=False):
    """Configuration for paywall UI customization"""

    # total=False: every key below is optional.
    # NOTE(review): "cdp" presumably refers to Coinbase Developer Platform -- confirm.
    cdp_client_key: str
    app_name: str
    # Presumably a URL or path to the logo image -- confirm with the paywall template.
    app_logo: str
    session_token_endpoint: str
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class DiscoveredResource(BaseModel):
    """A discovery resource represents a discoverable resource in the D402 ecosystem."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    resource: str
    # Regex-restricted: only the exact string "http" validates for now.
    type: str = Field(pattern="^http$")  # Currently only supports 'http'
    # Fields declared with Field() and no default are required.
    d402_version: int = Field(alias="d402Version")
    accepts: List[PaymentRequirements]
    # Parsed by pydantic into a datetime from the wire string.
    last_updated: datetime = Field(
        alias="lastUpdated",
        description="ISO 8601 formatted datetime string with UTC timezone (e.g. 2025-08-09T01:07:04.005Z)",
    )
    metadata: Optional[dict] = None
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class ListDiscoveryResourcesRequest(BaseModel):
    """Request parameters for listing discovery resources."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # All parameters are optional; None means "not specified".
    type: Optional[str] = None
    limit: Optional[int] = None
    offset: Optional[int] = None
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class DiscoveryResourcesPagination(BaseModel):
    """Pagination information for discovery resources responses."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # All three values are required integers.
    limit: int
    offset: int
    total: int
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class ListDiscoveryResourcesResponse(BaseModel):
    """Response from the discovery resources endpoint."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Required; Field() without a default keeps the field mandatory.
    d402_version: int = Field(alias="d402Version")
    items: List[DiscoveredResource]
    pagination: DiscoveryResourcesPagination
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""MCP (Model Context Protocol) integration module."""
|
|
2
|
+
|
|
3
|
+
from .client import MCPClient
|
|
4
|
+
from .mcp_agent_template import MCPServerConfig, MCPAgentBuilder, run_with_mcp_tools, MCPServerInfo
|
|
5
|
+
from .traia_mcp_adapter import TraiaMCPAdapter, create_mcp_adapter
|
|
6
|
+
from .d402_mcp_tool_adapter import D402MCPToolAdapter, create_d402_mcp_adapter
|
|
7
|
+
|
|
8
|
+
# Public API of this package, re-exported from its submodules.
__all__ = [
    # from .client
    "MCPClient",
    # from .mcp_agent_template
    "MCPServerConfig",
    "MCPAgentBuilder",
    "run_with_mcp_tools",
    "MCPServerInfo",
    # from .traia_mcp_adapter
    "TraiaMCPAdapter",
    "create_mcp_adapter",
    # from .d402_mcp_tool_adapter
    "D402MCPToolAdapter",
    "create_d402_mcp_adapter",
]
|
traia_iatp/mcp/client.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""MCP client wrapper for connecting to MCP servers with streamable-http support."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, Dict, Optional, List, AsyncIterator
|
|
6
|
+
import httpx
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
from ..core.models import MCPServer, MCPServerType
|
|
10
|
+
|
|
11
|
+
# Module-level logger named after this module, per the logging convention.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MCPClient:
    """Wrapper for MCP client connections with streamable-http support.

    Endpoints are derived from ``mcp_server.url``:

    * ``GET <url>/tools``  -- tool discovery, performed by :meth:`connect`
    * ``POST <url>/call``  -- tool invocation (JSON reply or SSE stream)
    * ``GET <url>/health`` -- liveness probe used by :meth:`health_check`

    Usable as an async context manager: connects on entry, disconnects on exit.
    """

    def __init__(self, mcp_server: MCPServer):
        self.mcp_server = mcp_server
        # Tool descriptors from the last successful GET /tools.
        self._available_tools: List[Dict[str, Any]] = []
        self._http_client: Optional[httpx.AsyncClient] = None
        # True only after a fully successful connect(); cleared by disconnect().
        self._connected = False

    def _endpoint(self, path: str) -> str:
        """Return ``<server url>/<path>`` with exactly one slash between the parts.

        ``str(self.mcp_server.url)`` may end with a trailing slash (URL types
        can render a bare host as ``http://host/``); naive f-string joining
        would then request ``http://host//tools``.
        """
        return str(self.mcp_server.url).rstrip("/") + "/" + path.lstrip("/")

    @staticmethod
    def _create_http_client() -> httpx.AsyncClient:
        """Build the async HTTP client, preferring HTTP/2.

        httpx raises ImportError for ``http2=True`` when the optional ``h2``
        package is not installed; fall back to HTTP/1.1 instead of failing.
        """
        timeout = httpx.Timeout(30.0, connect=10.0)
        limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
        try:
            # Enable HTTP/2 for better streaming support
            return httpx.AsyncClient(timeout=timeout, limits=limits, http2=True)
        except ImportError:
            logger.warning("h2 package not installed; falling back to HTTP/1.1")
            return httpx.AsyncClient(timeout=timeout, limits=limits)

    async def _ensure_connected(self) -> None:
        """Connect unless a successful connect() left a live HTTP client."""
        if not self._connected or self._http_client is None:
            await self.connect()

    async def connect(self) -> None:
        """Connect to the MCP server using streamable-http.

        Safe to call repeatedly: a previously opened HTTP client is closed
        before a new one is created.

        Raises:
            ValueError: if ``mcp_server.server_type`` is not streamable-http.
            RuntimeError: if fetching or parsing the tool list fails.
        """
        try:
            if self.mcp_server.server_type != MCPServerType.STREAMABLE_HTTP:
                raise ValueError(f"Only streamable-http is supported, got: {self.mcp_server.server_type}")

            await self._connect_streamable_http()

        except Exception as e:
            logger.error(f"Failed to connect to MCP server {self.mcp_server.name}: {e}")
            raise

    async def _connect_streamable_http(self) -> None:
        """Open the HTTP client, fetch ``/tools`` and record the tool list.

        On failure the new client is closed, ``_connected`` stays False and a
        RuntimeError chained to the original error is raised.
        """
        url = str(self.mcp_server.url)

        logger.info(f"Connecting to MCP server {self.mcp_server.name} via streamable-http at {url}")

        # Close a client left over from an earlier connect() instead of leaking it.
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None
        self._connected = False
        self._available_tools = []

        # Initialize HTTP client for persistent connection
        self._http_client = self._create_http_client()

        # Test connection and get available tools
        try:
            response = await self._http_client.get(self._endpoint("tools"))
            response.raise_for_status()
            tools = response.json().get("tools", [])
            # Update capabilities in the MCP server model
            self.mcp_server.capabilities = [tool["name"] for tool in tools]
        except Exception as e:
            await self._http_client.aclose()
            self._http_client = None
            raise RuntimeError(f"Failed to connect to MCP server: {e}") from e

        self._available_tools = tools
        # Only now is the connection fully usable.
        self._connected = True
        logger.info(f"Connected to {self.mcp_server.name}, found {len(self._available_tools)} tools")

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on the MCP server and return its decoded JSON reply.

        Connects (or reconnects) first if needed.

        Raises:
            ValueError: if ``tool_name`` is not among the discovered tools.
            httpx.HTTPStatusError: if the server replies with an error status.
        """
        await self._ensure_connected()

        # Only tools advertised by GET /tools may be called.
        if not any(t["name"] == tool_name for t in self._available_tools):
            raise ValueError(f"Tool {tool_name} not found")

        response = await self._http_client.post(
            self._endpoint("call"),
            json={"name": tool_name, "input": arguments},
        )
        response.raise_for_status()
        return response.json()

    async def call_tool_streaming(self, tool_name: str, arguments: Dict[str, Any]) -> AsyncIterator[Any]:
        """Call a tool and yield each JSON object from its SSE response stream."""
        await self._ensure_connected()

        # Stream the tool call response
        async for chunk in self._stream_tool_call(tool_name, arguments):
            yield chunk

    async def _stream_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> AsyncIterator[Any]:
        """POST a tool call as ``text/event-stream`` and yield decoded ``data`` fields.

        Non-empty ``data`` fields are parsed as JSON; unparsable ones are
        logged and skipped. Other SSE lines (``event``, ``id``, comments) are
        ignored.

        Raises:
            RuntimeError: if the HTTP client has not been created.
            httpx.HTTPStatusError: if the server replies with an error status.
        """
        if not self._http_client:
            raise RuntimeError("HTTP client not initialized")

        # Make streaming request
        async with self._http_client.stream(
            "POST",
            self._endpoint("call"),
            json={"name": tool_name, "input": arguments},
            headers={"Accept": "text/event-stream"},
        ) as response:
            # Fail loudly rather than silently yielding nothing for an error reply.
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue
                # SSE allows both "data:value" and "data: value"; drop one optional space.
                data = line[len("data:"):]
                if data.startswith(" "):
                    data = data[1:]
                if not data:
                    continue
                try:
                    payload = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming data: {data}")
                    continue
                yield payload

    async def list_tools(self) -> List[Dict[str, Any]]:
        """Return the tool descriptors discovered on connect, connecting if needed."""
        await self._ensure_connected()

        return self._available_tools

    async def disconnect(self) -> None:
        """Disconnect from the MCP server and close the HTTP client."""
        self._connected = False
        self._available_tools = []

        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

    async def health_check(self) -> bool:
        """Return True iff connected and ``GET <url>/health`` answers 200.

        Errors are logged and reported as False rather than raised.
        """
        try:
            if not self._connected or not self._http_client:
                return False

            # Try to ping the server
            response = await self._http_client.get(self._endpoint("health"))
            return response.status_code == 200
        except Exception as e:
            logger.warning(f"Health check failed for {self.mcp_server.name}: {e}")
            return False

    async def __aenter__(self):
        """Async context manager entry: connect and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: always disconnect."""
        await self.disconnect()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class MCPToolWrapper:
    """Wrapper to expose MCP tools as CrewAI-compatible tools with connection pooling.

    Instances are async callables: ``await wrapper(**kwargs)`` sends ``kwargs``
    as the tool input. All instances share one class-level pool holding a
    single :class:`MCPClient` per ``"<name>:<url>"`` key.
    """

    # Class-level connection pool, shared by every instance.
    _connection_pool: Dict[str, MCPClient] = {}
    # NOTE(review): created at import time. On Python < 3.10 an asyncio.Lock
    # binds to the event loop current at construction; on 3.10+ it binds to
    # the first loop that contends on it. Using the pool from several event
    # loops (e.g. repeated asyncio.run calls) may therefore fail -- confirm
    # this class is only used from a single loop.
    _pool_lock = asyncio.Lock()

    def __init__(self, mcp_server: MCPServer, tool_name: str, tool_description: str):
        self.mcp_server = mcp_server
        self.tool_name = tool_name
        self.description = tool_description
        # ``name`` mirrors ``tool_name``; presumably read by the agent framework.
        self.name = tool_name

    @classmethod
    async def get_or_create_client(cls, mcp_server: MCPServer) -> MCPClient:
        """Get or create a connected client from the connection pool.

        An existing pooled client is health-checked (under the pool lock) and
        reconnected if the check fails. Connection errors propagate.
        """
        server_key = f"{mcp_server.name}:{mcp_server.url}"

        async with cls._pool_lock:
            if server_key not in cls._connection_pool:
                # Create new client; it is only pooled once connect() succeeds.
                client = MCPClient(mcp_server)
                await client.connect()
                cls._connection_pool[server_key] = client
            else:
                # Check if existing client is healthy
                client = cls._connection_pool[server_key]
                if not await client.health_check():
                    # Reconnect if unhealthy
                    await client.disconnect()
                    await client.connect()

            return cls._connection_pool[server_key]

    async def __call__(self, **kwargs) -> Any:
        """Execute the MCP tool using pooled connection."""
        client = await self.get_or_create_client(self.mcp_server)
        return await client.call_tool(self.tool_name, kwargs)

    @classmethod
    async def cleanup_pool(cls):
        """Disconnect every pooled client and empty the pool.

        A failure while disconnecting one client is logged and does not stop
        the remaining clients from being closed; the pool is always cleared.
        """
        async with cls._pool_lock:
            try:
                for server_key, client in cls._connection_pool.items():
                    try:
                        await client.disconnect()
                    except Exception as e:
                        logger.warning(f"Failed to disconnect pooled MCP client {server_key}: {e}")
            finally:
                cls._connection_pool.clear()
|