traia-iatp 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of traia-iatp might be problematic. Click here for more details.

Files changed (107) hide show
  1. traia_iatp/README.md +368 -0
  2. traia_iatp/__init__.py +54 -0
  3. traia_iatp/cli/__init__.py +5 -0
  4. traia_iatp/cli/main.py +483 -0
  5. traia_iatp/client/__init__.py +10 -0
  6. traia_iatp/client/a2a_client.py +274 -0
  7. traia_iatp/client/crewai_a2a_tools.py +335 -0
  8. traia_iatp/client/d402_a2a_client.py +293 -0
  9. traia_iatp/client/grpc_a2a_tools.py +349 -0
  10. traia_iatp/client/root_path_a2a_client.py +1 -0
  11. traia_iatp/contracts/__init__.py +12 -0
  12. traia_iatp/contracts/iatp_contracts_config.py +263 -0
  13. traia_iatp/contracts/wallet_creator.py +255 -0
  14. traia_iatp/core/__init__.py +43 -0
  15. traia_iatp/core/models.py +172 -0
  16. traia_iatp/d402/__init__.py +55 -0
  17. traia_iatp/d402/chains.py +102 -0
  18. traia_iatp/d402/client.py +150 -0
  19. traia_iatp/d402/clients/__init__.py +7 -0
  20. traia_iatp/d402/clients/base.py +218 -0
  21. traia_iatp/d402/clients/httpx.py +219 -0
  22. traia_iatp/d402/common.py +114 -0
  23. traia_iatp/d402/encoding.py +28 -0
  24. traia_iatp/d402/examples/client_example.py +197 -0
  25. traia_iatp/d402/examples/server_example.py +171 -0
  26. traia_iatp/d402/facilitator.py +453 -0
  27. traia_iatp/d402/fastapi_middleware/__init__.py +6 -0
  28. traia_iatp/d402/fastapi_middleware/middleware.py +225 -0
  29. traia_iatp/d402/fastmcp_middleware.py +147 -0
  30. traia_iatp/d402/mcp_middleware.py +434 -0
  31. traia_iatp/d402/middleware.py +193 -0
  32. traia_iatp/d402/models.py +116 -0
  33. traia_iatp/d402/networks.py +98 -0
  34. traia_iatp/d402/path.py +43 -0
  35. traia_iatp/d402/payment_introspection.py +104 -0
  36. traia_iatp/d402/payment_signing.py +178 -0
  37. traia_iatp/d402/paywall.py +119 -0
  38. traia_iatp/d402/starlette_middleware.py +326 -0
  39. traia_iatp/d402/template.py +1 -0
  40. traia_iatp/d402/types.py +300 -0
  41. traia_iatp/mcp/__init__.py +18 -0
  42. traia_iatp/mcp/client.py +201 -0
  43. traia_iatp/mcp/d402_mcp_tool_adapter.py +361 -0
  44. traia_iatp/mcp/mcp_agent_template.py +481 -0
  45. traia_iatp/mcp/templates/Dockerfile.j2 +80 -0
  46. traia_iatp/mcp/templates/README.md.j2 +310 -0
  47. traia_iatp/mcp/templates/cursor-rules.md.j2 +520 -0
  48. traia_iatp/mcp/templates/deployment_params.json.j2 +20 -0
  49. traia_iatp/mcp/templates/docker-compose.yml.j2 +32 -0
  50. traia_iatp/mcp/templates/dockerignore.j2 +47 -0
  51. traia_iatp/mcp/templates/env.example.j2 +57 -0
  52. traia_iatp/mcp/templates/gitignore.j2 +77 -0
  53. traia_iatp/mcp/templates/mcp_health_check.py.j2 +150 -0
  54. traia_iatp/mcp/templates/pyproject.toml.j2 +32 -0
  55. traia_iatp/mcp/templates/pyrightconfig.json.j2 +22 -0
  56. traia_iatp/mcp/templates/run_local_docker.sh.j2 +390 -0
  57. traia_iatp/mcp/templates/server.py.j2 +175 -0
  58. traia_iatp/mcp/traia_mcp_adapter.py +543 -0
  59. traia_iatp/preview_diagrams.html +181 -0
  60. traia_iatp/registry/__init__.py +26 -0
  61. traia_iatp/registry/atlas_search_indexes.json +280 -0
  62. traia_iatp/registry/embeddings.py +298 -0
  63. traia_iatp/registry/iatp_search_api.py +846 -0
  64. traia_iatp/registry/mongodb_registry.py +771 -0
  65. traia_iatp/registry/readmes/ATLAS_SEARCH_INDEXES.md +252 -0
  66. traia_iatp/registry/readmes/ATLAS_SEARCH_SETUP.md +134 -0
  67. traia_iatp/registry/readmes/AUTHENTICATION_UPDATE.md +124 -0
  68. traia_iatp/registry/readmes/EMBEDDINGS_SETUP.md +172 -0
  69. traia_iatp/registry/readmes/IATP_SEARCH_API_GUIDE.md +257 -0
  70. traia_iatp/registry/readmes/MONGODB_X509_AUTH.md +208 -0
  71. traia_iatp/registry/readmes/README.md +251 -0
  72. traia_iatp/registry/readmes/REFACTORING_SUMMARY.md +191 -0
  73. traia_iatp/scripts/__init__.py +2 -0
  74. traia_iatp/scripts/create_wallet.py +244 -0
  75. traia_iatp/server/__init__.py +15 -0
  76. traia_iatp/server/a2a_server.py +219 -0
  77. traia_iatp/server/example_template_usage.py +72 -0
  78. traia_iatp/server/iatp_server_agent_generator.py +237 -0
  79. traia_iatp/server/iatp_server_template_generator.py +235 -0
  80. traia_iatp/server/templates/.dockerignore.j2 +48 -0
  81. traia_iatp/server/templates/Dockerfile.j2 +49 -0
  82. traia_iatp/server/templates/README.md +137 -0
  83. traia_iatp/server/templates/README.md.j2 +425 -0
  84. traia_iatp/server/templates/__init__.py +1 -0
  85. traia_iatp/server/templates/__main__.py.j2 +565 -0
  86. traia_iatp/server/templates/agent.py.j2 +94 -0
  87. traia_iatp/server/templates/agent_config.json.j2 +22 -0
  88. traia_iatp/server/templates/agent_executor.py.j2 +279 -0
  89. traia_iatp/server/templates/docker-compose.yml.j2 +23 -0
  90. traia_iatp/server/templates/env.example.j2 +84 -0
  91. traia_iatp/server/templates/gitignore.j2 +78 -0
  92. traia_iatp/server/templates/grpc_server.py.j2 +218 -0
  93. traia_iatp/server/templates/pyproject.toml.j2 +78 -0
  94. traia_iatp/server/templates/run_local_docker.sh.j2 +103 -0
  95. traia_iatp/server/templates/server.py.j2 +243 -0
  96. traia_iatp/special_agencies/__init__.py +4 -0
  97. traia_iatp/special_agencies/registry_search_agency.py +392 -0
  98. traia_iatp/utils/__init__.py +10 -0
  99. traia_iatp/utils/docker_utils.py +251 -0
  100. traia_iatp/utils/general.py +64 -0
  101. traia_iatp/utils/iatp_utils.py +126 -0
  102. traia_iatp-0.1.29.dist-info/METADATA +423 -0
  103. traia_iatp-0.1.29.dist-info/RECORD +107 -0
  104. traia_iatp-0.1.29.dist-info/WHEEL +5 -0
  105. traia_iatp-0.1.29.dist-info/entry_points.txt +2 -0
  106. traia_iatp-0.1.29.dist-info/licenses/LICENSE +21 -0
  107. traia_iatp-0.1.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,300 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from enum import Enum
5
+ from typing import Any, Optional, Union, Dict, Literal, List
6
+ from typing_extensions import (
7
+ TypedDict,
8
+ ) # use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12
9
+
10
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
11
+ from pydantic.alias_generators import to_camel
12
+
13
+ from .networks import SupportedNetworks
14
+
15
+
16
# HTTP request structure types


class HTTPVerbs(str, Enum):
    """HTTP methods that a request structure may declare.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (``HTTPVerbs.GET == "GET"``).
    """

    # Safe / read-style methods
    GET = "GET"
    # Write-style methods
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    # Metadata methods
    OPTIONS = "OPTIONS"
    HEAD = "HEAD"
25
+
26
+
27
class HTTPInputSchema(BaseModel):
    """Input part of an HTTP request description.

    Excludes ``spec`` and ``method``, which are handled by the middleware.
    Fields are aliased to camelCase (``queryParams``, ``bodyType``, ...) and
    may also be populated by their snake_case names.
    """

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Query-string parameters as name -> value.
    query_params: Optional[Dict[str, str]] = None
    # How the request body is encoded, if there is one.
    body_type: Optional[
        Literal["json", "form-data", "multipart-form-data", "text", "binary"]
    ] = None
    # Body fields as name -> value.
    body_fields: Optional[Dict[str, Any]] = None
    # Header fields as name -> value.
    header_fields: Optional[Dict[str, Any]] = None
42
+
43
+
44
class HTTPRequestStructure(HTTPInputSchema):
    """Complete HTTP request description: the input schema plus protocol type and method."""

    # Protocol discriminator; only the literal "http" is accepted.
    type: Literal["http"]
    # HTTP method of the request.
    method: HTTPVerbs


# Only HTTP is supported for now; MCP and OpenAPI structures could be added later.
RequestStructure = HTTPRequestStructure
53
+
54
+
55
# These models are defined dependency-first (EIP712Domain -> TokenAsset ->
# TokenAmount). Each model's annotations can then be resolved when its class is
# created, instead of the model staying incomplete until pydantic's deferred
# rebuild on first use.


class EIP712Domain(BaseModel):
    """EIP-712 domain information used for token signing."""

    name: str
    version: str


class TokenAsset(BaseModel):
    """Token asset information: address, decimals, EIP-712 domain data and network."""

    address: str
    decimals: int
    eip712: EIP712Domain
    # Blockchain network (e.g. "sepolia", "base-sepolia"); optional.
    network: Optional[str] = None

    @field_validator("decimals")
    @classmethod
    def validate_decimals(cls, v):
        """Reject ``decimals`` values outside the inclusive range 0..255.

        Raises:
            ValueError: If ``v`` is negative or greater than 255.
        """
        if v < 0 or v > 255:
            raise ValueError("decimals must be between 0 and 255")
        return v


class TokenAmount(BaseModel):
    """An amount of tokens in atomic units, together with its asset information."""

    # Integer amount in atomic units, encoded as a string.
    amount: str
    asset: TokenAsset

    @field_validator("amount")
    @classmethod
    def validate_amount(cls, v):
        """Check that ``amount`` parses with ``int()``.

        NOTE(review): ``int()`` also accepts "-5", " 5 " and "1_000"; confirm
        whether negative or non-canonical amounts should be rejected here.

        Raises:
            ValueError: If ``v`` is not an integer string.
        """
        try:
            int(v)
        except ValueError:
            raise ValueError("amount must be an integer encoded as a string")
        return v
90
+
91
+
92
# A price is either Money (a USD amount) or an explicit TokenAmount.
# Money is a USD amount given as a string (e.g. "$0.01", "0.001") or as an int.
# NOTE(review): a float such as 0.01 is not covered by this alias even though it
# is a natural USD value; confirm whether ``float`` should be included.
Money = Union[str, int]
Price = Union[Money, TokenAmount]
95
+
96
+
97
class PaymentRequirements(BaseModel):
    """One accepted way to pay for a resource, as advertised by the server.

    Fields are aliased to camelCase on the wire (``maxAmountRequired``,
    ``payTo``, ...) and can also be populated by their snake_case names.
    """

    scheme: str
    network: SupportedNetworks
    # Integer amount encoded as a string; validated below.
    max_amount_required: str
    resource: str
    description: str
    mime_type: str
    output_schema: Optional[Any] = None
    pay_to: str
    max_timeout_seconds: int
    asset: str
    extra: Optional[dict[str, Any]] = None

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    @field_validator("max_amount_required")
    @classmethod
    def validate_max_amount_required(cls, value):
        """Require ``max_amount_required`` to parse with ``int()``.

        Raises:
            ValueError: If ``value`` is not an integer string.
        """
        try:
            int(value)
        except ValueError:
            raise ValueError(
                "max_amount_required must be an integer encoded as a string"
            )
        else:
            return value
125
+
126
+
127
class d402PaymentRequiredResponse(BaseModel):
    """JSON body a server returns alongside an HTTP 402 response code.

    Lists every payment option the server accepts (``accepts``) together with
    an error message; keys are camelCase on the wire (``d402Version``).
    """

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    d402_version: int
    accepts: list[PaymentRequirements]
    error: str
138
+
139
+
140
class PullFundsAuthorization(BaseModel):
    """Authorization data carried in the payment header (wire format).

    Contains the fields covered by the EIP-712 signature (wallet, provider,
    token, amount, deadline, requestPath). It also contains transport metadata
    (``valid_after`` / ``valid_before``) describing the payment window.

    Only some of these fields are signed; see
    ``PULL_FUNDS_FOR_SETTLEMENT_TYPEHASH`` in IATPWallet.sol.
    """

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Consumer's IATPWallet address ("from" is a Python keyword, hence the alias).
    from_: str = Field(alias="from")
    # Provider's IATPWallet address.
    to: str
    # Payment amount, integer encoded as a string.
    value: str
    # Transport only; not part of the signature.
    valid_after: str = Field(alias="validAfter")
    # Corresponds to ``deadline`` in the signature.
    valid_before: str = Field(alias="validBefore")
    # API path; signed.
    request_path: str = Field(alias="requestPath")

    @field_validator("value")
    @classmethod
    def validate_value(cls, value):
        """Require ``value`` to parse with ``int()``.

        Raises:
            ValueError: If ``value`` is not an integer string.
        """
        try:
            int(value)
        except ValueError:
            raise ValueError("value must be an integer encoded as a string")
        else:
            return value
170
+
171
+
172
class ExactPaymentPayload(BaseModel):
    """Payment payload: a PullFundsForSettlement signature plus the data it authorizes."""

    # Signature over the authorization.
    signature: str
    # The authorization data that was signed.
    authorization: PullFundsAuthorization
176
+
177
+
178
class VerifyResponse(BaseModel):
    """Result of a facilitator verifying a payment payload.

    Keys are camelCase on the wire (``isValid``, ``invalidReason``,
    ``paymentUuid``, ``facilitatorFeePercent``); snake_case names are also
    accepted when populating.
    """

    is_valid: bool = Field(alias="isValid")
    invalid_reason: Optional[str] = Field(None, alias="invalidReason")
    # With pydantic v2, an Optional field that has no default is still
    # *required*, so a response omitting "payer" (e.g. for an invalid payment)
    # failed validation. Default to None like the other optional fields.
    payer: Optional[str] = None
    # Unique payment identifier from the facilitator.
    payment_uuid: Optional[str] = Field(None, alias="paymentUuid")
    # Fee percent from the facilitator in basis points (default 250 = 2.5%).
    facilitator_fee_percent: Optional[int] = Field(250, alias="facilitatorFeePercent")

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )
190
+
191
+
192
class SettleResponse(BaseModel):
    """Result of settling a payment; all fields except ``success`` are optional."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    success: bool
    # Serialized as "errorReason".
    error_reason: Optional[str] = None
    transaction: Optional[str] = None
    network: Optional[str] = None
    payer: Optional[str] = None
204
+
205
+
206
# Payload type per payment scheme. Only the "exact" payload exists today, so
# this "union" has a single member.
SchemePayloads = ExactPaymentPayload


class PaymentPayload(BaseModel):
    """Top-level payment payload: protocol version, scheme, network and scheme payload."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Serialized as "d402Version".
    d402_version: int
    scheme: str
    network: str
    payload: SchemePayloads
221
+
222
+
223
class D402Headers(BaseModel):
    """Container for the payment header value."""

    # Raw payment header value.
    x_payment: str
225
+
226
+
227
class UnsupportedSchemeException(Exception):
    """Raised when a payment scheme is not supported."""
229
+
230
+
231
class PaywallConfig(TypedDict, total=False):
    """Optional settings for customizing the paywall UI; every key may be omitted."""

    # Client key for CDP.
    cdp_client_key: str
    # Display name and logo of the application.
    app_name: str
    app_logo: str
    # Endpoint used to obtain a session token.
    session_token_endpoint: str
238
+
239
+
240
class DiscoveredResource(BaseModel):
    """A discoverable resource in the D402 ecosystem and the payments it accepts."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    resource: str
    # Must match "^http$"; only 'http' resources are supported for now.
    type: str = Field(..., pattern="^http$")
    d402_version: int = Field(..., alias="d402Version")
    accepts: List[PaymentRequirements]
    last_updated: datetime = Field(
        ...,
        alias="lastUpdated",
        description="ISO 8601 formatted datetime string with UTC timezone (e.g. 2025-08-09T01:07:04.005Z)",
    )
    # Free-form extra data, if any.
    metadata: Optional[dict] = None
259
+
260
+
261
class ListDiscoveryResourcesRequest(BaseModel):
    """Filter and paging parameters for listing discovery resources; all optional."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Resource type filter.
    type: Optional[str] = None
    # Page size and start offset.
    limit: Optional[int] = None
    offset: Optional[int] = None
273
+
274
+
275
class DiscoveryResourcesPagination(BaseModel):
    """Paging information returned with a discovery resources listing."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    # Page size, start offset, and total number of matching items.
    limit: int
    offset: int
    total: int
287
+
288
+
289
class ListDiscoveryResourcesResponse(BaseModel):
    """Response body of the discovery resources endpoint: one page of resources."""

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    d402_version: int = Field(..., alias="d402Version")
    items: List[DiscoveredResource]
    pagination: DiscoveryResourcesPagination
@@ -0,0 +1,18 @@
1
+ """MCP (Model Context Protocol) integration module."""
2
+
3
+ from .client import MCPClient
4
+ from .mcp_agent_template import MCPServerConfig, MCPAgentBuilder, run_with_mcp_tools, MCPServerInfo
5
+ from .traia_mcp_adapter import TraiaMCPAdapter, create_mcp_adapter
6
+ from .d402_mcp_tool_adapter import D402MCPToolAdapter, create_d402_mcp_adapter
7
+
8
# Public API of the MCP integration package, grouped by source submodule.
__all__ = [
    # from .client
    "MCPClient",
    # from .mcp_agent_template
    "MCPServerConfig",
    "MCPAgentBuilder",
    "run_with_mcp_tools",
    "MCPServerInfo",
    # from .traia_mcp_adapter
    "TraiaMCPAdapter",
    "create_mcp_adapter",
    # from .d402_mcp_tool_adapter
    "D402MCPToolAdapter",
    "create_d402_mcp_adapter",
]
@@ -0,0 +1,201 @@
1
+ """MCP client wrapper for connecting to MCP servers with streamable-http support."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import Any, Dict, Optional, List, AsyncIterator
6
+ import httpx
7
+ import json
8
+
9
+ from ..core.models import MCPServer, MCPServerType
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class MCPClient:
    """Client for a single MCP server reached over streamable-http.

    Between ``connect()`` and ``disconnect()`` the client keeps one
    ``httpx.AsyncClient`` open. It caches the tool list returned by
    ``GET <url>/tools`` and posts tool invocations to ``<url>/call``. Used as an
    async context manager, it connects on entry and disconnects on exit.
    """

    def __init__(self, mcp_server: MCPServer):
        """Store the server description; no network I/O happens here.

        Args:
            mcp_server: Server metadata. This client reads ``url``, ``name`` and
                ``server_type`` and writes ``capabilities`` after connecting.
        """
        self.mcp_server = mcp_server
        self._available_tools: List[Dict[str, Any]] = []
        self._http_client: Optional[httpx.AsyncClient] = None
        self._connected = False

    def _endpoint(self, path: str) -> str:
        """Return ``<server url>/<path>`` joined with exactly one slash.

        ``str(url)`` may or may not end with "/" (depending on how the URL
        value was built or normalised), so both sides are stripped. This avoids
        paths such as ``http://host//tools``.
        """
        return str(self.mcp_server.url).rstrip("/") + "/" + path.lstrip("/")

    async def _close_http_client(self) -> None:
        """Close and drop the underlying HTTP client, if one is open."""
        client, self._http_client = self._http_client, None
        if client is not None:
            await client.aclose()

    async def connect(self) -> None:
        """Connect to the MCP server using streamable-http and load its tools.

        Raises:
            ValueError: If the server type is not streamable-http.
            RuntimeError: If the tool list could not be fetched.
        """
        try:
            if self.mcp_server.server_type != MCPServerType.STREAMABLE_HTTP:
                raise ValueError(f"Only streamable-http is supported, got: {self.mcp_server.server_type}")

            await self._connect_streamable_http()

        except Exception as e:
            logger.error(f"Failed to connect to MCP server {self.mcp_server.name}: {e}")
            raise

    async def _connect_streamable_http(self) -> None:
        """Open the HTTP client and fetch ``GET <url>/tools``.

        On success: caches the tool list, sets ``mcp_server.capabilities`` to the
        tool names, and marks the client connected. On failure: closes the HTTP
        client, leaves the client disconnected, and raises ``RuntimeError``
        chained to the original error.
        """
        url = str(self.mcp_server.url)

        logger.info(f"Connecting to MCP server {self.mcp_server.name} via streamable-http at {url}")

        # A repeated connect() on an already-connected client must not leak the
        # previous AsyncClient and its connection pool.
        await self._close_http_client()
        self._connected = False

        # Initialize HTTP client for persistent connection
        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=10.0),
            limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            http2=True,  # Enable HTTP/2 for better streaming support
        )

        # Test connection and get available tools
        try:
            response = await self._http_client.get(self._endpoint("tools"))
            response.raise_for_status()
            tools = response.json().get("tools", [])
            # Compute capabilities before flagging success, so a malformed tool
            # entry cannot leave the client marked connected without an HTTP client.
            capabilities = [tool["name"] for tool in tools]
        except Exception as e:
            await self._close_http_client()
            self._available_tools = []
            raise RuntimeError(f"Failed to connect to MCP server: {e}") from e

        self._available_tools = tools
        self.mcp_server.capabilities = capabilities
        self._connected = True
        logger.info(f"Connected to {self.mcp_server.name}, found {len(self._available_tools)} tools")

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Invoke *tool_name* with *arguments* and return the decoded JSON reply.

        Connects first if needed.

        Raises:
            ValueError: If the tool is not in the cached tool list.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        if not self._connected or not self._http_client:
            # Reconnect if needed
            await self.connect()

        if not any(t["name"] == tool_name for t in self._available_tools):
            raise ValueError(f"Tool {tool_name} not found")

        response = await self._http_client.post(
            self._endpoint("call"),
            json={"name": tool_name, "input": arguments},
        )
        response.raise_for_status()
        return response.json()

    async def call_tool_streaming(self, tool_name: str, arguments: Dict[str, Any]) -> AsyncIterator[Any]:
        """Invoke a tool and yield each decoded server-sent ``data:`` event.

        Connects first if needed, using the same condition as ``call_tool``.
        """
        if not self._connected or not self._http_client:
            await self.connect()

        async for chunk in self._stream_tool_call(tool_name, arguments):
            yield chunk

    async def _stream_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> AsyncIterator[Any]:
        """POST to ``<url>/call`` asking for an event stream and yield parsed events.

        Lines that do not start with ``"data: "`` are ignored, as are empty data
        lines. Data that is not valid JSON is logged and skipped.

        Raises:
            RuntimeError: If no HTTP client is open.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        if not self._http_client:
            raise RuntimeError("HTTP client not initialized")

        prefix = "data: "
        async with self._http_client.stream(
            "POST",
            self._endpoint("call"),
            json={"name": tool_name, "input": arguments},
            headers={"Accept": "text/event-stream"},
        ) as response:
            # Surface HTTP errors instead of silently yielding nothing.
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith(prefix):
                    continue
                data = line[len(prefix):]
                if not data:
                    continue
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming data: {data}")
                    continue
                # Yield outside the try so exceptions thrown into the generator
                # by the consumer are never mistaken for parse errors.
                yield event

    async def list_tools(self) -> List[Dict[str, Any]]:
        """Return the cached tool list, connecting first if not connected."""
        if not self._connected:
            await self.connect()

        return self._available_tools

    async def disconnect(self) -> None:
        """Mark the client disconnected, clear the tool cache and close the HTTP client."""
        self._connected = False
        self._available_tools = []
        await self._close_http_client()

    async def health_check(self) -> bool:
        """Return True if connected and ``GET <url>/health`` answers 200.

        Any error is logged and reported as unhealthy (False).
        """
        try:
            if not self._connected or not self._http_client:
                return False

            response = await self._http_client.get(self._endpoint("health"))
            return response.status_code == 200
        except Exception as e:
            logger.warning(f"Health check failed for {self.mcp_server.name}: {e}")
            return False

    async def __aenter__(self):
        """Connect and return this client."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect; exceptions from the ``async with`` body are not suppressed."""
        await self.disconnect()
154
+
155
+
156
class MCPToolWrapper:
    """Expose one MCP tool as a CrewAI-compatible callable with pooled connections.

    All wrappers share a class-level pool keyed by ``"<name>:<url>"``, so tools
    on the same server reuse one ``MCPClient``.
    """

    # Class-level connection pool: "<server name>:<server url>" -> client.
    _connection_pool: Dict[str, MCPClient] = {}
    # Guards the pool. NOTE(review): this lock is created at import time and
    # shared by the whole process. asyncio locks bind to the event loop that
    # first waits on them, so using the pool from several event loops (e.g.
    # repeated asyncio.run calls under contention) may fail. Confirm usage.
    _pool_lock = asyncio.Lock()

    def __init__(self, mcp_server: MCPServer, tool_name: str, tool_description: str):
        """Describe the tool; no connection is made until the wrapper is called.

        Args:
            mcp_server: Server hosting the tool.
            tool_name: Tool name on the server; also used as ``self.name``.
            tool_description: Human-readable description exposed as ``self.description``.
        """
        self.mcp_server = mcp_server
        self.tool_name = tool_name
        self.description = tool_description
        self.name = tool_name

    @classmethod
    async def get_or_create_client(cls, mcp_server: MCPServer) -> MCPClient:
        """Return a pooled client for *mcp_server*, creating or reconnecting it as needed.

        A new client is added to the pool only after ``connect()`` succeeds. An
        existing client is health-checked first and, if unhealthy, disconnected
        and reconnected in place. Errors from ``connect()`` propagate.
        """
        server_key = f"{mcp_server.name}:{mcp_server.url}"

        async with cls._pool_lock:
            client = cls._connection_pool.get(server_key)
            if client is None:
                client = MCPClient(mcp_server)
                await client.connect()
                cls._connection_pool[server_key] = client
            elif not await client.health_check():
                # Reconnect if unhealthy
                await client.disconnect()
                await client.connect()
            return client

    async def __call__(self, **kwargs) -> Any:
        """Run the tool with *kwargs* as its input, using a pooled connection."""
        client = await self.get_or_create_client(self.mcp_server)
        return await client.call_tool(self.tool_name, kwargs)

    @classmethod
    async def cleanup_pool(cls):
        """Disconnect every pooled client and empty the pool.

        Every client is given a chance to disconnect, and the pool is always
        cleared, even if some ``disconnect()`` calls fail. Previously the first
        failure aborted the loop, so later clients were never closed. After
        cleanup, the first failure (if any) is re-raised.
        """
        errors = []
        async with cls._pool_lock:
            try:
                for client in list(cls._connection_pool.values()):
                    try:
                        await client.disconnect()
                    except Exception as exc:
                        errors.append(exc)
            finally:
                cls._connection_pool.clear()
        if errors:
            raise errors[0]