camel-ai 0.2.72a8__py3-none-any.whl → 0.2.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai has been flagged as potentially problematic; consult the registry's advisory page for this release for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +140 -345
- camel/memories/agent_memories.py +18 -17
- camel/societies/__init__.py +2 -0
- camel/societies/workforce/prompts.py +36 -10
- camel/societies/workforce/single_agent_worker.py +7 -5
- camel/societies/workforce/workforce.py +6 -4
- camel/storages/key_value_storages/mem0_cloud.py +48 -47
- camel/storages/vectordb_storages/__init__.py +1 -0
- camel/storages/vectordb_storages/surreal.py +100 -150
- camel/toolkits/__init__.py +6 -1
- camel/toolkits/base.py +60 -2
- camel/toolkits/excel_toolkit.py +153 -64
- camel/toolkits/file_write_toolkit.py +67 -0
- camel/toolkits/hybrid_browser_toolkit/config_loader.py +136 -413
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py +131 -1966
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py +1177 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package-lock.json +4356 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package.json +33 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-scripts.js +125 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-session.ts +945 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/config-loader.ts +226 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/hybrid-browser-toolkit.ts +522 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/index.ts +7 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/types.ts +110 -0
- camel/toolkits/hybrid_browser_toolkit/ts/tsconfig.json +26 -0
- camel/toolkits/hybrid_browser_toolkit/ts/websocket-server.js +254 -0
- camel/toolkits/hybrid_browser_toolkit/ws_wrapper.py +582 -0
- camel/toolkits/hybrid_browser_toolkit_py/__init__.py +17 -0
- camel/toolkits/hybrid_browser_toolkit_py/config_loader.py +447 -0
- camel/toolkits/hybrid_browser_toolkit_py/hybrid_browser_toolkit.py +2077 -0
- camel/toolkits/mcp_toolkit.py +341 -46
- camel/toolkits/message_integration.py +719 -0
- camel/toolkits/note_taking_toolkit.py +18 -29
- camel/toolkits/notion_mcp_toolkit.py +234 -0
- camel/toolkits/screenshot_toolkit.py +116 -31
- camel/toolkits/search_toolkit.py +20 -2
- camel/toolkits/slack_toolkit.py +43 -48
- camel/toolkits/terminal_toolkit.py +288 -46
- camel/toolkits/video_analysis_toolkit.py +13 -13
- camel/toolkits/video_download_toolkit.py +11 -11
- camel/toolkits/web_deploy_toolkit.py +207 -12
- camel/types/enums.py +6 -0
- {camel_ai-0.2.72a8.dist-info → camel_ai-0.2.73.dist-info}/METADATA +49 -9
- {camel_ai-0.2.72a8.dist-info → camel_ai-0.2.73.dist-info}/RECORD +53 -36
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/actions.py +0 -0
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/agent.py +0 -0
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/browser_session.py +0 -0
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/snapshot.py +0 -0
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/stealth_script.js +0 -0
- /camel/toolkits/{hybrid_browser_toolkit → hybrid_browser_toolkit_py}/unified_analyzer.js +0 -0
- {camel_ai-0.2.72a8.dist-info → camel_ai-0.2.73.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.72a8.dist-info → camel_ai-0.2.73.dist-info}/licenses/LICENSE +0 -0
|
@@ -68,7 +68,9 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
68
68
|
url: str = "ws://localhost:8000/rpc",
|
|
69
69
|
table: str = "vector_store",
|
|
70
70
|
vector_dim: int = 786,
|
|
71
|
+
vector_type: str = "F64",
|
|
71
72
|
distance: VectorDistance = VectorDistance.COSINE,
|
|
73
|
+
hnsw_effort: int = 40,
|
|
72
74
|
namespace: str = "default",
|
|
73
75
|
database: str = "demo",
|
|
74
76
|
user: str = "root",
|
|
@@ -96,6 +98,8 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
96
98
|
(default: :obj:`"root"`)
|
|
97
99
|
"""
|
|
98
100
|
|
|
101
|
+
from surrealdb import Surreal
|
|
102
|
+
|
|
99
103
|
self.url = url
|
|
100
104
|
self.table = table
|
|
101
105
|
self.ns = namespace
|
|
@@ -103,9 +107,14 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
103
107
|
self.user = user
|
|
104
108
|
self.password = password
|
|
105
109
|
self.vector_dim = vector_dim
|
|
110
|
+
self.vector_type = vector_type
|
|
106
111
|
self.distance = distance
|
|
112
|
+
self._hnsw_effort = hnsw_effort
|
|
113
|
+
self._surreal_client = Surreal(self.url)
|
|
114
|
+
self._surreal_client.signin({"username": user, "password": password})
|
|
115
|
+
self._surreal_client.use(namespace, database)
|
|
116
|
+
|
|
107
117
|
self._check_and_create_table()
|
|
108
|
-
self._surreal_client = None
|
|
109
118
|
|
|
110
119
|
def _table_exists(self) -> bool:
|
|
111
120
|
r"""Check whether the target table exists in the database.
|
|
@@ -113,74 +122,66 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
113
122
|
Returns:
|
|
114
123
|
bool: True if the table exists, False otherwise.
|
|
115
124
|
"""
|
|
116
|
-
|
|
125
|
+
res = self._surreal_client.query("INFO FOR DB;")
|
|
126
|
+
tables = res.get('tables', {})
|
|
127
|
+
logger.debug(f"_table_exists: {res}")
|
|
128
|
+
return self.table in tables
|
|
117
129
|
|
|
118
|
-
|
|
119
|
-
db.signin({"username": self.user, "password": self.password})
|
|
120
|
-
db.use(self.ns, self.db)
|
|
121
|
-
res = db.query_raw("INFO FOR DB;")
|
|
122
|
-
tables = res['result'][0]['result'].get('tables', {})
|
|
123
|
-
return self.table in tables
|
|
124
|
-
|
|
125
|
-
def _get_table_info(self) -> Dict[str, int]:
|
|
130
|
+
def _get_table_info(self) -> dict[str, int | None]:
|
|
126
131
|
r"""Retrieve dimension and record count from the table metadata.
|
|
127
132
|
|
|
128
133
|
Returns:
|
|
129
134
|
Dict[str, int]: A dictionary with 'dim' and 'count' keys.
|
|
130
135
|
"""
|
|
131
|
-
from surrealdb import Surreal # type: ignore[import-not-found]
|
|
132
|
-
|
|
133
136
|
if not self._table_exists():
|
|
134
137
|
return {"dim": self.vector_dim, "count": 0}
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
count = cnt['result'][0]['result'][0]['count']
|
|
151
|
-
except (KeyError, IndexError, TypeError):
|
|
152
|
-
logger.warning(
|
|
153
|
-
"Unexpected result format when counting records: %s", cnt
|
|
154
|
-
)
|
|
155
|
-
count = 0
|
|
156
|
-
|
|
157
|
-
return {"dim": dim, "count": count}
|
|
138
|
+
res = self._surreal_client.query(f"INFO FOR TABLE {self.table};")
|
|
139
|
+
logger.debug(f"_get_table_info: {res}")
|
|
140
|
+
indexes = res.get("indexes", {})
|
|
141
|
+
|
|
142
|
+
dim = self.vector_dim
|
|
143
|
+
idx_def = indexes.get("hnsw_idx")
|
|
144
|
+
if idx_def and isinstance(idx_def, str):
|
|
145
|
+
m = re.search(r"DIMENSION\s+(\d+)", idx_def)
|
|
146
|
+
if m:
|
|
147
|
+
dim = int(m.group(1))
|
|
148
|
+
cnt = self._surreal_client.query(
|
|
149
|
+
f"SELECT COUNT() FROM ONLY {self.table} GROUP ALL LIMIT 1;"
|
|
150
|
+
)
|
|
151
|
+
count = cnt.get("count", 0)
|
|
152
|
+
return {"dim": dim, "count": count}
|
|
158
153
|
|
|
159
154
|
def _create_table(self):
|
|
160
|
-
r"""Define and create the vector storage table with HNSW index.
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
f"
|
|
168
|
-
DEFINE FIELD payload ON {self.table} TYPE object;
|
|
169
|
-
DEFINE FIELD embedding ON {self.table} TYPE array;
|
|
170
|
-
DEFINE INDEX hnsw_idx ON {self.table}
|
|
171
|
-
FIELDS embedding HNSW DIMENSION {self.vector_dim};
|
|
172
|
-
"""
|
|
155
|
+
r"""Define and create the vector storage table with HNSW index.
|
|
156
|
+
|
|
157
|
+
Documentation: https://surrealdb.com/docs/surrealdb/reference-guide/
|
|
158
|
+
vector-search#vector-search-cheat-sheet
|
|
159
|
+
"""
|
|
160
|
+
if self.distance.value not in ["cosine", "euclidean", "manhattan"]:
|
|
161
|
+
raise ValueError(
|
|
162
|
+
f"Unsupported distance metric: {self.distance.value}"
|
|
173
163
|
)
|
|
164
|
+
surql_query = f"""
|
|
165
|
+
DEFINE TABLE {self.table} SCHEMALESS;
|
|
166
|
+
DEFINE FIELD payload ON {self.table} FLEXIBLE TYPE object;
|
|
167
|
+
DEFINE FIELD embedding ON {self.table} TYPE array<float>;
|
|
168
|
+
DEFINE INDEX hnsw_idx ON {self.table}
|
|
169
|
+
FIELDS embedding
|
|
170
|
+
HNSW DIMENSION {self.vector_dim}
|
|
171
|
+
DIST {self.distance.value}
|
|
172
|
+
TYPE {self.vector_type}
|
|
173
|
+
EFC 150 M 12 M0 24;
|
|
174
|
+
"""
|
|
175
|
+
logger.debug(f"_create_table query: {surql_query}")
|
|
176
|
+
res = self._surreal_client.query_raw(surql_query)
|
|
177
|
+
logger.debug(f"_create_table response: {res}")
|
|
178
|
+
if "error" in res:
|
|
179
|
+
raise ValueError(f"Failed to create table: {res['error']}")
|
|
174
180
|
logger.info(f"Table '{self.table}' created successfully.")
|
|
175
181
|
|
|
176
182
|
def _drop_table(self):
|
|
177
183
|
r"""Drop the vector storage table if it exists."""
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
with Surreal(self.url) as db:
|
|
181
|
-
db.signin({"username": self.user, "password": self.password})
|
|
182
|
-
db.use(self.ns, self.db)
|
|
183
|
-
db.query_raw(f"REMOVE TABLE IF EXISTS {self.table};")
|
|
184
|
+
self._surreal_client.query_raw(f"REMOVE TABLE IF EXISTS {self.table};")
|
|
184
185
|
logger.info(f"Table '{self.table}' deleted successfully.")
|
|
185
186
|
|
|
186
187
|
def _check_and_create_table(self):
|
|
@@ -240,62 +241,34 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
240
241
|
List[VectorDBQueryResult]: Ranked list of matching records
|
|
241
242
|
with similarity scores.
|
|
242
243
|
"""
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
VectorDistance.COSINE: "vector::similarity::cosine",
|
|
253
|
-
VectorDistance.EUCLIDEAN: "vector::distance::euclidean",
|
|
254
|
-
VectorDistance.DOT: "vector::dot",
|
|
255
|
-
}.get(self.distance)
|
|
256
|
-
|
|
257
|
-
if not metric_func:
|
|
258
|
-
raise ValueError(f"Unsupported distance metric: {self.distance}")
|
|
259
|
-
|
|
260
|
-
with Surreal(self.url) as db:
|
|
261
|
-
db.signin({"username": self.user, "password": self.password})
|
|
262
|
-
db.use(self.ns, self.db)
|
|
263
|
-
|
|
264
|
-
# Use parameterized query to prevent SQL injection
|
|
265
|
-
sql_query = f"""SELECT payload, embedding,
|
|
266
|
-
{metric_func}(embedding, $query_vec) AS score
|
|
267
|
-
FROM {self.table}
|
|
268
|
-
WHERE embedding <|{query.top_k},{metric}|> $query_vec
|
|
269
|
-
ORDER BY score;
|
|
270
|
-
"""
|
|
271
|
-
|
|
272
|
-
response = db.query_raw(
|
|
273
|
-
sql_query, {"query_vec": query.query_vector}
|
|
274
|
-
)
|
|
244
|
+
surql_query = f"""
|
|
245
|
+
SELECT id, embedding, payload, vector::distance::knn() AS dist
|
|
246
|
+
FROM {self.table}
|
|
247
|
+
WHERE embedding <|{query.top_k},{self._hnsw_effort}|> $vector
|
|
248
|
+
ORDER BY dist;
|
|
249
|
+
"""
|
|
250
|
+
logger.debug(
|
|
251
|
+
f"query surql: {surql_query} with $vector = {query.query_vector}"
|
|
252
|
+
)
|
|
275
253
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
else row["score"]
|
|
295
|
-
),
|
|
296
|
-
)
|
|
297
|
-
for row in results["result"]
|
|
298
|
-
]
|
|
254
|
+
response = self._surreal_client.query(
|
|
255
|
+
surql_query, {"vector": query.query_vector}
|
|
256
|
+
)
|
|
257
|
+
logger.debug(f"query response: {response}")
|
|
258
|
+
|
|
259
|
+
return [
|
|
260
|
+
VectorDBQueryResult(
|
|
261
|
+
record=VectorRecord(
|
|
262
|
+
id=row["id"].id,
|
|
263
|
+
vector=row["embedding"],
|
|
264
|
+
payload=row["payload"],
|
|
265
|
+
),
|
|
266
|
+
similarity=1.0 - row["dist"]
|
|
267
|
+
if self.distance == VectorDistance.COSINE
|
|
268
|
+
else -row["score"],
|
|
269
|
+
)
|
|
270
|
+
for row in response
|
|
271
|
+
]
|
|
299
272
|
|
|
300
273
|
def add(self, records: List[VectorRecord], **kwargs) -> None:
|
|
301
274
|
r"""Insert validated vector records into the SurrealDB table.
|
|
@@ -306,16 +279,10 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
306
279
|
logger.info(
|
|
307
280
|
"Adding %d records to table '%s'.", len(records), self.table
|
|
308
281
|
)
|
|
309
|
-
from surrealdb import Surreal # type: ignore[import-not-found]
|
|
310
|
-
|
|
311
282
|
try:
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
validated_records = self._validate_and_convert_records(records)
|
|
317
|
-
for record in validated_records:
|
|
318
|
-
db.create(self.table, record)
|
|
283
|
+
validated_records = self._validate_and_convert_records(records)
|
|
284
|
+
for record in validated_records:
|
|
285
|
+
self._surreal_client.create(self.table, record)
|
|
319
286
|
|
|
320
287
|
logger.info(
|
|
321
288
|
"Successfully added %d records to table '%s'.",
|
|
@@ -340,32 +307,23 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
340
307
|
ids (Optional[List[str]]): List of record IDs to delete.
|
|
341
308
|
if_all (bool): Whether to delete all records in the table.
|
|
342
309
|
"""
|
|
343
|
-
from surrealdb import
|
|
344
|
-
from surrealdb.data.types.record_id import ( # type: ignore[import-not-found]
|
|
345
|
-
RecordID,
|
|
346
|
-
)
|
|
310
|
+
from surrealdb.data.types.record_id import RecordID
|
|
347
311
|
|
|
348
312
|
try:
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
)
|
|
364
|
-
|
|
365
|
-
for id_str in ids:
|
|
366
|
-
rec = RecordID(self.table, id_str)
|
|
367
|
-
db.delete(rec, **kwargs)
|
|
368
|
-
logger.info(f"Deleted record {rec}")
|
|
313
|
+
if if_all:
|
|
314
|
+
self._surreal_client.delete(self.table, **kwargs)
|
|
315
|
+
logger.info(f"Deleted all records from table '{self.table}'")
|
|
316
|
+
return
|
|
317
|
+
|
|
318
|
+
if not ids:
|
|
319
|
+
raise ValueError(
|
|
320
|
+
"Either `ids` must be provided or `if_all=True`"
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
for id_str in ids:
|
|
324
|
+
rec = RecordID(self.table, id_str)
|
|
325
|
+
self._surreal_client.delete(rec, **kwargs)
|
|
326
|
+
logger.info(f"Deleted record {rec}")
|
|
369
327
|
|
|
370
328
|
except Exception as e:
|
|
371
329
|
logger.exception("Error deleting records from SurrealDB")
|
|
@@ -404,12 +362,4 @@ class SurrealStorage(BaseVectorStorage):
|
|
|
404
362
|
@property
|
|
405
363
|
def client(self) -> "Surreal":
|
|
406
364
|
r"""Provides access to the underlying SurrealDB client."""
|
|
407
|
-
if self._surreal_client is None:
|
|
408
|
-
from surrealdb import Surreal # type: ignore[import-not-found]
|
|
409
|
-
|
|
410
|
-
self._surreal_client = Surreal(self.url)
|
|
411
|
-
self._surreal_client.signin( # type: ignore[attr-defined]
|
|
412
|
-
{"username": self.user, "password": self.password}
|
|
413
|
-
)
|
|
414
|
-
self._surreal_client.use(self.ns, self.db) # type: ignore[attr-defined]
|
|
415
365
|
return self._surreal_client
|
camel/toolkits/__init__.py
CHANGED
|
@@ -31,7 +31,7 @@ from .meshy_toolkit import MeshyToolkit
|
|
|
31
31
|
from .openbb_toolkit import OpenBBToolkit
|
|
32
32
|
from .bohrium_toolkit import BohriumToolkit
|
|
33
33
|
|
|
34
|
-
from .base import BaseToolkit
|
|
34
|
+
from .base import BaseToolkit, RegisteredAgentToolkit
|
|
35
35
|
from .google_maps_toolkit import GoogleMapsToolkit
|
|
36
36
|
from .code_execution import CodeExecutionToolkit
|
|
37
37
|
from .github_toolkit import GithubToolkit
|
|
@@ -87,6 +87,8 @@ from .note_taking_toolkit import NoteTakingToolkit
|
|
|
87
87
|
from .message_agent_toolkit import AgentCommunicationToolkit
|
|
88
88
|
from .web_deploy_toolkit import WebDeployToolkit
|
|
89
89
|
from .screenshot_toolkit import ScreenshotToolkit
|
|
90
|
+
from .message_integration import ToolkitMessageIntegration
|
|
91
|
+
from .notion_mcp_toolkit import NotionMCPToolkit
|
|
90
92
|
|
|
91
93
|
__all__ = [
|
|
92
94
|
'BaseToolkit',
|
|
@@ -162,4 +164,7 @@ __all__ = [
|
|
|
162
164
|
'AgentCommunicationToolkit',
|
|
163
165
|
'WebDeployToolkit',
|
|
164
166
|
'ScreenshotToolkit',
|
|
167
|
+
'RegisteredAgentToolkit',
|
|
168
|
+
'ToolkitMessageIntegration',
|
|
169
|
+
'NotionMCPToolkit',
|
|
165
170
|
]
|
camel/toolkits/base.py
CHANGED
|
@@ -12,11 +12,18 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
|
-
from typing import List, Literal, Optional
|
|
15
|
+
from typing import TYPE_CHECKING, List, Literal, Optional
|
|
16
16
|
|
|
17
|
+
from camel.logger import get_logger
|
|
17
18
|
from camel.toolkits import FunctionTool
|
|
18
19
|
from camel.utils import AgentOpsMeta, with_timeout
|
|
19
20
|
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from camel.agents import ChatAgent
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
logger = get_logger(__name__)
|
|
26
|
+
|
|
20
27
|
|
|
21
28
|
class BaseToolkit(metaclass=AgentOpsMeta):
|
|
22
29
|
r"""Base class for toolkits.
|
|
@@ -41,7 +48,9 @@ class BaseToolkit(metaclass=AgentOpsMeta):
|
|
|
41
48
|
super().__init_subclass__(**kwargs)
|
|
42
49
|
for attr_name, attr_value in cls.__dict__.items():
|
|
43
50
|
if callable(attr_value) and not attr_name.startswith("__"):
|
|
44
|
-
|
|
51
|
+
# Skip methods that have manual timeout management
|
|
52
|
+
if not getattr(attr_value, '_manual_timeout', False):
|
|
53
|
+
setattr(cls, attr_name, with_timeout(attr_value))
|
|
45
54
|
|
|
46
55
|
def get_tools(self) -> List[FunctionTool]:
|
|
47
56
|
r"""Returns a list of FunctionTool objects representing the
|
|
@@ -63,3 +72,52 @@ class BaseToolkit(metaclass=AgentOpsMeta):
|
|
|
63
72
|
the MCP server in.
|
|
64
73
|
"""
|
|
65
74
|
self.mcp.run(mode)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class RegisteredAgentToolkit:
|
|
78
|
+
r"""Mixin class for toolkits that need to register a ChatAgent.
|
|
79
|
+
|
|
80
|
+
This mixin provides a standard interface for toolkits that require
|
|
81
|
+
a reference to a ChatAgent instance. The ChatAgent will check if a
|
|
82
|
+
toolkit has this mixin and automatically register itself.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def __init__(self):
|
|
86
|
+
self._agent: Optional["ChatAgent"] = None
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def agent(self) -> Optional["ChatAgent"]:
|
|
90
|
+
r"""Get the registered ChatAgent instance.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
Optional[ChatAgent]: The registered agent, or None if not
|
|
94
|
+
registered.
|
|
95
|
+
|
|
96
|
+
Note:
|
|
97
|
+
If None is returned, it means the toolkit has not been registered
|
|
98
|
+
with a ChatAgent yet. Make sure to pass this toolkit to a ChatAgent
|
|
99
|
+
via the toolkits parameter during initialization.
|
|
100
|
+
"""
|
|
101
|
+
if self._agent is None:
|
|
102
|
+
logger.warning(
|
|
103
|
+
f"{self.__class__.__name__} does not have a "
|
|
104
|
+
f"registered ChatAgent. "
|
|
105
|
+
f"Please ensure this toolkit is passed to a ChatAgent via the "
|
|
106
|
+
f"'toolkits_to_register_agent' parameter during ChatAgent "
|
|
107
|
+
f"initialization if you want to use the tools that require a "
|
|
108
|
+
f"registered agent."
|
|
109
|
+
)
|
|
110
|
+
return self._agent
|
|
111
|
+
|
|
112
|
+
def register_agent(self, agent: "ChatAgent") -> None:
|
|
113
|
+
r"""Register a ChatAgent with this toolkit.
|
|
114
|
+
|
|
115
|
+
This method allows registering an agent after initialization. The
|
|
116
|
+
ChatAgent will automatically call this method if the toolkit to
|
|
117
|
+
register inherits from RegisteredAgentToolkit.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
agent (ChatAgent): The ChatAgent instance to register.
|
|
121
|
+
"""
|
|
122
|
+
self._agent = agent
|
|
123
|
+
logger.info(f"Agent registered with {self.__class__.__name__}")
|