camel-ai 0.2.59__py3-none-any.whl → 0.2.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +158 -7
- camel/configs/anthropic_config.py +6 -5
- camel/configs/cohere_config.py +1 -1
- camel/configs/mistral_config.py +1 -1
- camel/configs/openai_config.py +3 -0
- camel/configs/reka_config.py +1 -1
- camel/configs/samba_config.py +2 -2
- camel/datagen/cot_datagen.py +29 -34
- camel/datagen/evol_instruct/scorer.py +22 -23
- camel/datagen/evol_instruct/templates.py +46 -46
- camel/datasets/static_dataset.py +144 -0
- camel/embeddings/jina_embedding.py +8 -1
- camel/embeddings/sentence_transformers_embeddings.py +2 -2
- camel/embeddings/vlm_embedding.py +9 -2
- camel/loaders/__init__.py +5 -2
- camel/loaders/chunkr_reader.py +117 -91
- camel/loaders/mistral_reader.py +148 -0
- camel/memories/blocks/chat_history_block.py +1 -2
- camel/memories/records.py +3 -0
- camel/messages/base.py +15 -3
- camel/models/azure_openai_model.py +1 -0
- camel/models/model_factory.py +2 -2
- camel/models/model_manager.py +7 -3
- camel/retrievers/bm25_retriever.py +1 -2
- camel/retrievers/hybrid_retrival.py +2 -2
- camel/societies/workforce/workforce.py +65 -24
- camel/storages/__init__.py +2 -0
- camel/storages/vectordb_storages/__init__.py +2 -0
- camel/storages/vectordb_storages/faiss.py +712 -0
- camel/storages/vectordb_storages/oceanbase.py +1 -2
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/async_browser_toolkit.py +80 -524
- camel/toolkits/bohrium_toolkit.py +318 -0
- camel/toolkits/browser_toolkit.py +221 -541
- camel/toolkits/browser_toolkit_commons.py +568 -0
- camel/toolkits/dalle_toolkit.py +4 -0
- camel/toolkits/excel_toolkit.py +8 -2
- camel/toolkits/file_write_toolkit.py +76 -29
- camel/toolkits/github_toolkit.py +43 -25
- camel/toolkits/image_analysis_toolkit.py +3 -0
- camel/toolkits/jina_reranker_toolkit.py +194 -77
- camel/toolkits/mcp_toolkit.py +134 -16
- camel/toolkits/page_script.js +40 -28
- camel/toolkits/twitter_toolkit.py +6 -1
- camel/toolkits/video_analysis_toolkit.py +3 -0
- camel/toolkits/video_download_toolkit.py +3 -0
- camel/toolkits/wolfram_alpha_toolkit.py +51 -23
- camel/types/enums.py +27 -6
- camel/utils/__init__.py +2 -0
- camel/utils/commons.py +27 -0
- {camel_ai-0.2.59.dist-info → camel_ai-0.2.61.dist-info}/METADATA +17 -9
- {camel_ai-0.2.59.dist-info → camel_ai-0.2.61.dist-info}/RECORD +55 -51
- {camel_ai-0.2.59.dist-info → camel_ai-0.2.61.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.59.dist-info → camel_ai-0.2.61.dist-info}/licenses/LICENSE +0 -0
|
@@ -11,7 +11,11 @@
|
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
-
|
|
14
|
+
import json
|
|
15
|
+
import os
|
|
16
|
+
from typing import Any, Dict, List, Optional
|
|
17
|
+
|
|
18
|
+
import requests
|
|
15
19
|
|
|
16
20
|
from camel.toolkits import FunctionTool
|
|
17
21
|
from camel.toolkits.base import BaseToolkit
|
|
@@ -30,7 +34,9 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
30
34
|
def __init__(
|
|
31
35
|
self,
|
|
32
36
|
timeout: Optional[float] = None,
|
|
37
|
+
model_name: Optional[str] = "jinaai/jina-reranker-m0",
|
|
33
38
|
device: Optional[str] = None,
|
|
39
|
+
use_api: bool = True,
|
|
34
40
|
) -> None:
|
|
35
41
|
r"""Initializes a new instance of the JinaRerankerToolkit class.
|
|
36
42
|
|
|
@@ -38,31 +44,57 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
38
44
|
timeout (Optional[float]): The timeout value for API requests
|
|
39
45
|
in seconds. If None, no timeout is applied.
|
|
40
46
|
(default: :obj:`None`)
|
|
47
|
+
model_name (Optional[str]): The reranker model name. If None,
|
|
48
|
+
will use the default model.
|
|
49
|
+
(default: :obj:`None`)
|
|
41
50
|
device (Optional[str]): Device to load the model on. If None,
|
|
42
51
|
will use CUDA if available, otherwise CPU.
|
|
52
|
+
Only effective when use_api=False.
|
|
43
53
|
(default: :obj:`None`)
|
|
54
|
+
use_api (bool): A flag to switch between local model and API.
|
|
55
|
+
(default: :obj:`True`)
|
|
44
56
|
"""
|
|
45
|
-
import torch
|
|
46
|
-
from transformers import AutoModel
|
|
47
57
|
|
|
48
58
|
super().__init__(timeout=timeout)
|
|
49
59
|
|
|
50
|
-
self.
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
60
|
+
self.use_api = use_api
|
|
61
|
+
self.model_name = model_name
|
|
62
|
+
|
|
63
|
+
if self.use_api:
|
|
64
|
+
self.model = None
|
|
65
|
+
self._api_key = os.environ.get("JINA_API_KEY", "None")
|
|
66
|
+
if self._api_key == "None":
|
|
67
|
+
raise ValueError(
|
|
68
|
+
"Missing or empty required API keys in "
|
|
69
|
+
"environment variables\n"
|
|
70
|
+
"You can obtain the API key from https://jina.ai/reranker/"
|
|
71
|
+
)
|
|
72
|
+
self.url = 'https://api.jina.ai/v1/rerank'
|
|
73
|
+
self.headers = {
|
|
74
|
+
'Content-Type': 'application/json',
|
|
75
|
+
'Accept': 'application/json',
|
|
76
|
+
'Authorization': f'Bearer {self._api_key}',
|
|
77
|
+
}
|
|
78
|
+
else:
|
|
79
|
+
import torch
|
|
80
|
+
from transformers import AutoModel
|
|
81
|
+
|
|
82
|
+
self.model = AutoModel.from_pretrained(
|
|
83
|
+
self.model_name,
|
|
84
|
+
torch_dtype="auto",
|
|
85
|
+
trust_remote_code=True,
|
|
86
|
+
)
|
|
87
|
+
self.device = (
|
|
88
|
+
device
|
|
89
|
+
if device is not None
|
|
90
|
+
else ("cuda" if torch.cuda.is_available() else "cpu")
|
|
91
|
+
)
|
|
92
|
+
self.model.to(self.device)
|
|
93
|
+
self.model.eval()
|
|
62
94
|
|
|
63
95
|
def _sort_documents(
|
|
64
96
|
self, documents: List[str], scores: List[float]
|
|
65
|
-
) -> List[
|
|
97
|
+
) -> List[Dict[str, object]]:
|
|
66
98
|
r"""Sort documents by their scores in descending order.
|
|
67
99
|
|
|
68
100
|
Args:
|
|
@@ -70,7 +102,7 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
70
102
|
scores (List[float]): Corresponding scores for each document.
|
|
71
103
|
|
|
72
104
|
Returns:
|
|
73
|
-
List[
|
|
105
|
+
List[Dict[str, object]]: Sorted list of (document, score) pairs.
|
|
74
106
|
|
|
75
107
|
Raises:
|
|
76
108
|
ValueError: If documents and scores have different lengths.
|
|
@@ -80,14 +112,45 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
80
112
|
doc_score_pairs = list(zip(documents, scores))
|
|
81
113
|
doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
|
|
82
114
|
|
|
83
|
-
|
|
115
|
+
results = [
|
|
116
|
+
{'document': {'text': doc}, 'relevance_score': score}
|
|
117
|
+
for doc, score in doc_score_pairs
|
|
118
|
+
]
|
|
119
|
+
|
|
120
|
+
return results
|
|
121
|
+
|
|
122
|
+
def _call_jina_api(self, data: Dict[str, Any]) -> List[Dict[str, object]]:
|
|
123
|
+
r"""Makes a call to the JINA API for reranking.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
data (Dict[str]): The data to be passed into the api body.
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
List[Dict[str, object]]: A list of dictionary containing
|
|
130
|
+
the reranked documents and their relevance scores.
|
|
131
|
+
"""
|
|
132
|
+
try:
|
|
133
|
+
response = requests.post(
|
|
134
|
+
self.url,
|
|
135
|
+
headers=self.headers,
|
|
136
|
+
data=json.dumps(data),
|
|
137
|
+
timeout=self.timeout,
|
|
138
|
+
)
|
|
139
|
+
response.raise_for_status()
|
|
140
|
+
results = [
|
|
141
|
+
{key: value for key, value in _res.items() if key != 'index'}
|
|
142
|
+
for _res in response.json()['results']
|
|
143
|
+
]
|
|
144
|
+
return results
|
|
145
|
+
except requests.exceptions.RequestException as e:
|
|
146
|
+
raise RuntimeError(f"Failed to get response from Jina AI: {e}")
|
|
84
147
|
|
|
85
148
|
def rerank_text_documents(
|
|
86
149
|
self,
|
|
87
150
|
query: str,
|
|
88
151
|
documents: List[str],
|
|
89
152
|
max_length: int = 1024,
|
|
90
|
-
) -> List[
|
|
153
|
+
) -> List[Dict[str, object]]:
|
|
91
154
|
r"""Reranks text documents based on their relevance to a text query.
|
|
92
155
|
|
|
93
156
|
Args:
|
|
@@ -97,21 +160,34 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
97
160
|
(default: :obj:`1024`)
|
|
98
161
|
|
|
99
162
|
Returns:
|
|
100
|
-
List[
|
|
163
|
+
List[Dict[str, object]]: A list of dictionary containing
|
|
101
164
|
the reranked documents and their relevance scores.
|
|
102
165
|
"""
|
|
103
|
-
import torch
|
|
104
166
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
)
|
|
167
|
+
data = {
|
|
168
|
+
'model': self.model_name,
|
|
169
|
+
'query': query,
|
|
170
|
+
'top_n': len(documents),
|
|
171
|
+
'documents': documents,
|
|
172
|
+
'return_documents': True,
|
|
173
|
+
}
|
|
109
174
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
175
|
+
if self.use_api:
|
|
176
|
+
return self._call_jina_api(data)
|
|
177
|
+
|
|
178
|
+
else:
|
|
179
|
+
import torch
|
|
180
|
+
|
|
181
|
+
if self.model is None:
|
|
182
|
+
raise ValueError(
|
|
183
|
+
"Model has not been initialized or failed to initialize."
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
with torch.inference_mode():
|
|
187
|
+
text_pairs = [[query, doc] for doc in documents]
|
|
188
|
+
scores = self.model.compute_score(
|
|
189
|
+
text_pairs, max_length=max_length, doc_type="text"
|
|
190
|
+
)
|
|
115
191
|
|
|
116
192
|
return self._sort_documents(documents, scores)
|
|
117
193
|
|
|
@@ -120,7 +196,7 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
120
196
|
query: str,
|
|
121
197
|
documents: List[str],
|
|
122
198
|
max_length: int = 2048,
|
|
123
|
-
) -> List[
|
|
199
|
+
) -> List[Dict[str, object]]:
|
|
124
200
|
r"""Reranks image documents based on their relevance to a text query.
|
|
125
201
|
|
|
126
202
|
Args:
|
|
@@ -130,21 +206,33 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
130
206
|
(default: :obj:`2048`)
|
|
131
207
|
|
|
132
208
|
Returns:
|
|
133
|
-
List[
|
|
209
|
+
List[Dict[str, object]]: A list of dictionary containing
|
|
134
210
|
the reranked image URLs/paths and their relevance scores.
|
|
135
211
|
"""
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
212
|
+
data = {
|
|
213
|
+
'model': self.model_name,
|
|
214
|
+
'query': query,
|
|
215
|
+
'top_n': len(documents),
|
|
216
|
+
'documents': documents,
|
|
217
|
+
'return_documents': True,
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
if self.use_api:
|
|
221
|
+
return self._call_jina_api(data)
|
|
222
|
+
|
|
223
|
+
else:
|
|
224
|
+
import torch
|
|
225
|
+
|
|
226
|
+
if self.model is None:
|
|
227
|
+
raise ValueError(
|
|
228
|
+
"Model has not been initialized or failed to initialize."
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
with torch.inference_mode():
|
|
232
|
+
image_pairs = [[query, doc] for doc in documents]
|
|
233
|
+
scores = self.model.compute_score(
|
|
234
|
+
image_pairs, max_length=max_length, doc_type="image"
|
|
235
|
+
)
|
|
148
236
|
|
|
149
237
|
return self._sort_documents(documents, scores)
|
|
150
238
|
|
|
@@ -153,7 +241,7 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
153
241
|
image_query: str,
|
|
154
242
|
documents: List[str],
|
|
155
243
|
max_length: int = 2048,
|
|
156
|
-
) -> List[
|
|
244
|
+
) -> List[Dict[str, object]]:
|
|
157
245
|
r"""Reranks text documents based on their relevance to an image query.
|
|
158
246
|
|
|
159
247
|
Args:
|
|
@@ -163,30 +251,45 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
163
251
|
(default: :obj:`2048`)
|
|
164
252
|
|
|
165
253
|
Returns:
|
|
166
|
-
List[
|
|
254
|
+
List[Dict[str, object]]: A list of dictionary containing
|
|
167
255
|
the reranked documents and their relevance scores.
|
|
168
256
|
"""
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
257
|
+
data = {
|
|
258
|
+
'model': self.model_name,
|
|
259
|
+
'query': image_query,
|
|
260
|
+
'top_n': len(documents),
|
|
261
|
+
'documents': documents,
|
|
262
|
+
'return_documents': True,
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if self.use_api:
|
|
266
|
+
return self._call_jina_api(data)
|
|
267
|
+
|
|
268
|
+
else:
|
|
269
|
+
import torch
|
|
270
|
+
|
|
271
|
+
if self.model is None:
|
|
272
|
+
raise ValueError(
|
|
273
|
+
"Model has not been initialized or failed to initialize."
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
with torch.inference_mode():
|
|
277
|
+
image_pairs = [[image_query, doc] for doc in documents]
|
|
278
|
+
scores = self.model.compute_score(
|
|
279
|
+
image_pairs,
|
|
280
|
+
max_length=max_length,
|
|
281
|
+
query_type="image",
|
|
282
|
+
doc_type="text",
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
return self._sort_documents(documents, scores)
|
|
183
286
|
|
|
184
287
|
def image_query_image_documents(
|
|
185
288
|
self,
|
|
186
289
|
image_query: str,
|
|
187
290
|
documents: List[str],
|
|
188
291
|
max_length: int = 2048,
|
|
189
|
-
) -> List[
|
|
292
|
+
) -> List[Dict[str, object]]:
|
|
190
293
|
r"""Reranks image documents based on their relevance to an image query.
|
|
191
294
|
|
|
192
295
|
Args:
|
|
@@ -196,24 +299,38 @@ class JinaRerankerToolkit(BaseToolkit):
|
|
|
196
299
|
(default: :obj:`2048`)
|
|
197
300
|
|
|
198
301
|
Returns:
|
|
199
|
-
List[
|
|
302
|
+
List[Dict[str, object]]: A list of dictionary containing
|
|
200
303
|
the reranked image URLs/paths and their relevance scores.
|
|
201
304
|
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
305
|
+
data = {
|
|
306
|
+
'model': self.model_name,
|
|
307
|
+
'query': image_query,
|
|
308
|
+
'top_n': len(documents),
|
|
309
|
+
'documents': documents,
|
|
310
|
+
'return_documents': True,
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
if self.use_api:
|
|
314
|
+
return self._call_jina_api(data)
|
|
315
|
+
|
|
316
|
+
else:
|
|
317
|
+
import torch
|
|
318
|
+
|
|
319
|
+
if self.model is None:
|
|
320
|
+
raise ValueError(
|
|
321
|
+
"Model has not been initialized or failed to initialize."
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
with torch.inference_mode():
|
|
325
|
+
image_pairs = [[image_query, doc] for doc in documents]
|
|
326
|
+
scores = self.model.compute_score(
|
|
327
|
+
image_pairs,
|
|
328
|
+
max_length=max_length,
|
|
329
|
+
query_type="image",
|
|
330
|
+
doc_type="image",
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
return self._sort_documents(documents, scores)
|
|
217
334
|
|
|
218
335
|
def get_tools(self) -> List[FunctionTool]:
|
|
219
336
|
r"""Returns a list of FunctionTool objects representing the
|
camel/toolkits/mcp_toolkit.py
CHANGED
|
@@ -34,8 +34,10 @@ from urllib.parse import urlparse
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
35
35
|
from mcp import ClientSession, ListToolsResult, Tool
|
|
36
36
|
|
|
37
|
+
|
|
37
38
|
from camel.logger import get_logger
|
|
38
39
|
from camel.toolkits import BaseToolkit, FunctionTool
|
|
40
|
+
from camel.utils.commons import run_async
|
|
39
41
|
|
|
40
42
|
logger = get_logger(__name__)
|
|
41
43
|
|
|
@@ -84,7 +86,6 @@ class MCPClient(BaseToolkit):
|
|
|
84
86
|
await client.disconnect()
|
|
85
87
|
```
|
|
86
88
|
|
|
87
|
-
|
|
88
89
|
Attributes:
|
|
89
90
|
command_or_url (str): URL for SSE mode or command executable for stdio
|
|
90
91
|
mode. (default: :obj:`None`)
|
|
@@ -96,6 +97,10 @@ class MCPClient(BaseToolkit):
|
|
|
96
97
|
(default: :obj:`None`)
|
|
97
98
|
headers (Dict[str, str]): Headers for the HTTP request.
|
|
98
99
|
(default: :obj:`None`)
|
|
100
|
+
mode (Optional[str]): Connection mode. Can be "sse" for Server-Sent
|
|
101
|
+
Events, "streamable-http" for streaming HTTP,
|
|
102
|
+
or None for stdio mode.
|
|
103
|
+
(default: :obj:`None`)
|
|
99
104
|
strict (Optional[bool]): Whether to enforce strict mode for the
|
|
100
105
|
function call. (default: :obj:`False`)
|
|
101
106
|
"""
|
|
@@ -107,6 +112,7 @@ class MCPClient(BaseToolkit):
|
|
|
107
112
|
env: Optional[Dict[str, str]] = None,
|
|
108
113
|
timeout: Optional[float] = None,
|
|
109
114
|
headers: Optional[Dict[str, str]] = None,
|
|
115
|
+
mode: Optional[str] = None,
|
|
110
116
|
strict: Optional[bool] = False,
|
|
111
117
|
):
|
|
112
118
|
from mcp import Tool
|
|
@@ -118,6 +124,7 @@ class MCPClient(BaseToolkit):
|
|
|
118
124
|
self.env = env or {}
|
|
119
125
|
self.headers = headers or {}
|
|
120
126
|
self.strict = strict
|
|
127
|
+
self.mode = mode
|
|
121
128
|
|
|
122
129
|
self._mcp_tools: List[Tool] = []
|
|
123
130
|
self._session: Optional['ClientSession'] = None
|
|
@@ -133,6 +140,7 @@ class MCPClient(BaseToolkit):
|
|
|
133
140
|
from mcp.client.session import ClientSession
|
|
134
141
|
from mcp.client.sse import sse_client
|
|
135
142
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
143
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
136
144
|
|
|
137
145
|
if self._is_connected:
|
|
138
146
|
logger.warning("Server is already connected")
|
|
@@ -140,16 +148,37 @@ class MCPClient(BaseToolkit):
|
|
|
140
148
|
|
|
141
149
|
try:
|
|
142
150
|
if urlparse(self.command_or_url).scheme in ("http", "https"):
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
+
if self.mode == "sse" or self.mode is None:
|
|
152
|
+
(
|
|
153
|
+
read_stream,
|
|
154
|
+
write_stream,
|
|
155
|
+
) = await self._exit_stack.enter_async_context(
|
|
156
|
+
sse_client(
|
|
157
|
+
self.command_or_url,
|
|
158
|
+
headers=self.headers,
|
|
159
|
+
timeout=self.timeout,
|
|
160
|
+
)
|
|
161
|
+
)
|
|
162
|
+
elif self.mode == "streamable-http":
|
|
163
|
+
try:
|
|
164
|
+
(
|
|
165
|
+
read_stream,
|
|
166
|
+
write_stream,
|
|
167
|
+
_,
|
|
168
|
+
) = await self._exit_stack.enter_async_context(
|
|
169
|
+
streamablehttp_client(
|
|
170
|
+
self.command_or_url,
|
|
171
|
+
headers=self.headers,
|
|
172
|
+
timeout=timedelta(seconds=self.timeout),
|
|
173
|
+
)
|
|
174
|
+
)
|
|
175
|
+
except Exception as e:
|
|
176
|
+
# Handle anyio task group errors
|
|
177
|
+
logger.error(f"Streamable HTTP client error: {e}")
|
|
178
|
+
else:
|
|
179
|
+
raise ValueError(
|
|
180
|
+
f"Invalid mode '{self.mode}' for HTTP URL"
|
|
151
181
|
)
|
|
152
|
-
)
|
|
153
182
|
else:
|
|
154
183
|
command = self.command_or_url
|
|
155
184
|
arguments = self.args
|
|
@@ -192,16 +221,28 @@ class MCPClient(BaseToolkit):
|
|
|
192
221
|
logger.error(f"Failed to connect to MCP server: {e}")
|
|
193
222
|
raise e
|
|
194
223
|
|
|
224
|
+
def connect_sync(self):
|
|
225
|
+
r"""Synchronously connect to the MCP server."""
|
|
226
|
+
return run_async(self.connect)()
|
|
227
|
+
|
|
195
228
|
async def disconnect(self):
|
|
196
229
|
r"""Explicitly disconnect from the MCP server."""
|
|
197
230
|
# If the server is not connected, do nothing
|
|
198
231
|
if not self._is_connected:
|
|
199
232
|
return
|
|
200
233
|
self._is_connected = False
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
234
|
+
|
|
235
|
+
try:
|
|
236
|
+
await self._exit_stack.aclose()
|
|
237
|
+
except Exception as e:
|
|
238
|
+
logger.warning(f"{e}")
|
|
239
|
+
finally:
|
|
240
|
+
self._exit_stack = AsyncExitStack()
|
|
241
|
+
self._session = None
|
|
242
|
+
|
|
243
|
+
def disconnect_sync(self):
|
|
244
|
+
r"""Synchronously disconnect from the MCP server."""
|
|
245
|
+
return run_async(self.disconnect)()
|
|
205
246
|
|
|
206
247
|
@asynccontextmanager
|
|
207
248
|
async def connection(self):
|
|
@@ -217,7 +258,14 @@ class MCPClient(BaseToolkit):
|
|
|
217
258
|
await self.connect()
|
|
218
259
|
yield self
|
|
219
260
|
finally:
|
|
220
|
-
|
|
261
|
+
try:
|
|
262
|
+
await self.disconnect()
|
|
263
|
+
except Exception as e:
|
|
264
|
+
logger.warning(f"Error: {e}")
|
|
265
|
+
|
|
266
|
+
def connection_sync(self):
|
|
267
|
+
r"""Synchronously connect to the MCP server."""
|
|
268
|
+
return run_async(self.connection)()
|
|
221
269
|
|
|
222
270
|
async def list_mcp_tools(self) -> Union[str, "ListToolsResult"]:
|
|
223
271
|
r"""Retrieves the list of available tools from the connected MCP
|
|
@@ -234,6 +282,11 @@ class MCPClient(BaseToolkit):
|
|
|
234
282
|
logger.exception("Failed to list MCP tools")
|
|
235
283
|
raise e
|
|
236
284
|
|
|
285
|
+
def list_mcp_tools_sync(self) -> Union[str, "ListToolsResult"]:
|
|
286
|
+
r"""Synchronously list the available tools from the connected MCP
|
|
287
|
+
server."""
|
|
288
|
+
return run_async(self.list_mcp_tools)()
|
|
289
|
+
|
|
237
290
|
def generate_function_from_mcp_tool(self, mcp_tool: "Tool") -> Callable:
|
|
238
291
|
r"""Dynamically generates a Python callable function corresponding to
|
|
239
292
|
a given MCP tool.
|
|
@@ -355,6 +408,10 @@ class MCPClient(BaseToolkit):
|
|
|
355
408
|
|
|
356
409
|
return dynamic_function
|
|
357
410
|
|
|
411
|
+
def generate_function_from_mcp_tool_sync(self, mcp_tool: "Tool") -> Any:
|
|
412
|
+
r"""Synchronously generate a function from an MCP tool."""
|
|
413
|
+
return run_async(self.generate_function_from_mcp_tool)(mcp_tool)
|
|
414
|
+
|
|
358
415
|
def _build_tool_schema(self, mcp_tool: "Tool") -> Dict[str, Any]:
|
|
359
416
|
input_schema = mcp_tool.inputSchema
|
|
360
417
|
properties = input_schema.get("properties", {})
|
|
@@ -428,6 +485,10 @@ class MCPClient(BaseToolkit):
|
|
|
428
485
|
|
|
429
486
|
return await self._session.call_tool(tool_name, tool_args)
|
|
430
487
|
|
|
488
|
+
def call_tool_sync(self, tool_name: str, tool_args: Dict[str, Any]) -> Any:
|
|
489
|
+
r"""Synchronously call a tool."""
|
|
490
|
+
return run_async(self.call_tool)(tool_name, tool_args)
|
|
491
|
+
|
|
431
492
|
@property
|
|
432
493
|
def session(self) -> Optional["ClientSession"]:
|
|
433
494
|
return self._session
|
|
@@ -440,6 +501,7 @@ class MCPClient(BaseToolkit):
|
|
|
440
501
|
env: Optional[Dict[str, str]] = None,
|
|
441
502
|
timeout: Optional[float] = None,
|
|
442
503
|
headers: Optional[Dict[str, str]] = None,
|
|
504
|
+
mode: Optional[str] = None,
|
|
443
505
|
) -> "MCPClient":
|
|
444
506
|
r"""Factory method that creates and connects to the MCP server.
|
|
445
507
|
|
|
@@ -457,6 +519,10 @@ class MCPClient(BaseToolkit):
|
|
|
457
519
|
(default: :obj:`None`)
|
|
458
520
|
headers (Optional[Dict[str, str]]): Headers for the HTTP request.
|
|
459
521
|
(default: :obj:`None`)
|
|
522
|
+
mode (Optional[str]): Connection mode. Can be "sse" for
|
|
523
|
+
Server-Sent Events, "streamable-http" for
|
|
524
|
+
streaming HTTP, or None for stdio mode.
|
|
525
|
+
(default: :obj:`None`)
|
|
460
526
|
|
|
461
527
|
Returns:
|
|
462
528
|
MCPClient: A fully initialized and connected MCPClient instance.
|
|
@@ -470,6 +536,7 @@ class MCPClient(BaseToolkit):
|
|
|
470
536
|
env=env,
|
|
471
537
|
timeout=timeout,
|
|
472
538
|
headers=headers,
|
|
539
|
+
mode=mode,
|
|
473
540
|
)
|
|
474
541
|
try:
|
|
475
542
|
await client.connect()
|
|
@@ -480,6 +547,21 @@ class MCPClient(BaseToolkit):
|
|
|
480
547
|
logger.error(f"Failed to initialize MCPClient: {e}")
|
|
481
548
|
raise RuntimeError(f"Failed to initialize MCPClient: {e}") from e
|
|
482
549
|
|
|
550
|
+
@classmethod
|
|
551
|
+
def create_sync(
|
|
552
|
+
self,
|
|
553
|
+
command_or_url: str,
|
|
554
|
+
args: Optional[List[str]] = None,
|
|
555
|
+
env: Optional[Dict[str, str]] = None,
|
|
556
|
+
timeout: Optional[float] = None,
|
|
557
|
+
headers: Optional[Dict[str, str]] = None,
|
|
558
|
+
mode: Optional[str] = None,
|
|
559
|
+
) -> "MCPClient":
|
|
560
|
+
r"""Synchronously create and connect to the MCP server."""
|
|
561
|
+
return run_async(self.create)(
|
|
562
|
+
command_or_url, args, env, timeout, headers, mode
|
|
563
|
+
)
|
|
564
|
+
|
|
483
565
|
async def __aenter__(self) -> "MCPClient":
|
|
484
566
|
r"""Async context manager entry point. Automatically connects to the
|
|
485
567
|
MCP server when used in an async with statement.
|
|
@@ -490,12 +572,35 @@ class MCPClient(BaseToolkit):
|
|
|
490
572
|
await self.connect()
|
|
491
573
|
return self
|
|
492
574
|
|
|
493
|
-
|
|
575
|
+
def __enter__(self) -> "MCPClient":
|
|
576
|
+
r"""Synchronously enter the async context manager."""
|
|
577
|
+
return run_async(self.__aenter__)()
|
|
578
|
+
|
|
579
|
+
async def __aexit__(self) -> None:
|
|
494
580
|
r"""Async context manager exit point. Automatically disconnects from
|
|
495
581
|
the MCP server when exiting an async with statement.
|
|
582
|
+
|
|
583
|
+
Returns:
|
|
584
|
+
None
|
|
496
585
|
"""
|
|
497
586
|
await self.disconnect()
|
|
498
587
|
|
|
588
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
589
|
+
r"""Synchronously exit the async context manager.
|
|
590
|
+
|
|
591
|
+
Args:
|
|
592
|
+
exc_type (Optional[Type[Exception]]): The type of exception that
|
|
593
|
+
occurred during the execution of the with statement.
|
|
594
|
+
exc_val (Optional[Exception]): The exception that occurred during
|
|
595
|
+
the execution of the with statement.
|
|
596
|
+
exc_tb (Optional[TracebackType]): The traceback of the exception
|
|
597
|
+
that occurred during the execution of the with statement.
|
|
598
|
+
|
|
599
|
+
Returns:
|
|
600
|
+
None
|
|
601
|
+
"""
|
|
602
|
+
return run_async(self.__aexit__)()
|
|
603
|
+
|
|
499
604
|
|
|
500
605
|
class MCPToolkit(BaseToolkit):
|
|
501
606
|
r"""MCPToolkit provides a unified interface for managing multiple
|
|
@@ -679,6 +784,7 @@ class MCPToolkit(BaseToolkit):
|
|
|
679
784
|
env={**os.environ, **cfg.get("env", {})},
|
|
680
785
|
timeout=cfg.get("timeout", None),
|
|
681
786
|
headers=headers,
|
|
787
|
+
mode=cfg.get("mode", None),
|
|
682
788
|
strict=strict,
|
|
683
789
|
)
|
|
684
790
|
all_servers.append(server)
|
|
@@ -707,6 +813,10 @@ class MCPToolkit(BaseToolkit):
|
|
|
707
813
|
logger.error(f"Failed to connect to one or more MCP servers: {e}")
|
|
708
814
|
raise e
|
|
709
815
|
|
|
816
|
+
def connect_sync(self):
|
|
817
|
+
r"""Synchronously connect to all MCP servers."""
|
|
818
|
+
return run_async(self.connect)()
|
|
819
|
+
|
|
710
820
|
async def disconnect(self):
|
|
711
821
|
r"""Explicitly disconnect from all MCP servers."""
|
|
712
822
|
if not self._connected:
|
|
@@ -716,6 +826,10 @@ class MCPToolkit(BaseToolkit):
|
|
|
716
826
|
await server.disconnect()
|
|
717
827
|
self._connected = False
|
|
718
828
|
|
|
829
|
+
def disconnect_sync(self):
|
|
830
|
+
r"""Synchronously disconnect from all MCP servers."""
|
|
831
|
+
return run_async(self.disconnect)()
|
|
832
|
+
|
|
719
833
|
@asynccontextmanager
|
|
720
834
|
async def connection(self) -> AsyncGenerator["MCPToolkit", None]:
|
|
721
835
|
r"""Async context manager that simultaneously establishes connections
|
|
@@ -730,6 +844,10 @@ class MCPToolkit(BaseToolkit):
|
|
|
730
844
|
finally:
|
|
731
845
|
await self.disconnect()
|
|
732
846
|
|
|
847
|
+
def connection_sync(self):
|
|
848
|
+
r"""Synchronously connect to all MCP servers."""
|
|
849
|
+
return run_async(self.connection)()
|
|
850
|
+
|
|
733
851
|
def is_connected(self) -> bool:
|
|
734
852
|
r"""Checks if all the managed servers are connected.
|
|
735
853
|
|