camel-ai 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +107 -22
- camel/configs/__init__.py +6 -0
- camel/configs/base_config.py +21 -0
- camel/configs/gemini_config.py +17 -9
- camel/configs/qwen_config.py +91 -0
- camel/configs/yi_config.py +58 -0
- camel/generators.py +93 -0
- camel/interpreters/docker_interpreter.py +5 -0
- camel/interpreters/ipython_interpreter.py +2 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/apify_reader.py +223 -0
- camel/memories/agent_memories.py +24 -1
- camel/messages/base.py +38 -0
- camel/models/__init__.py +4 -0
- camel/models/model_factory.py +6 -0
- camel/models/qwen_model.py +139 -0
- camel/models/yi_model.py +138 -0
- camel/prompts/image_craft.py +8 -0
- camel/prompts/video_description_prompt.py +8 -0
- camel/retrievers/vector_retriever.py +5 -1
- camel/societies/role_playing.py +29 -18
- camel/societies/workforce/base.py +7 -1
- camel/societies/workforce/task_channel.py +10 -0
- camel/societies/workforce/utils.py +6 -0
- camel/societies/workforce/worker.py +2 -0
- camel/storages/vectordb_storages/qdrant.py +147 -24
- camel/tasks/task.py +15 -0
- camel/terminators/base.py +4 -0
- camel/terminators/response_terminator.py +1 -0
- camel/terminators/token_limit_terminator.py +1 -0
- camel/toolkits/__init__.py +4 -1
- camel/toolkits/base.py +9 -0
- camel/toolkits/data_commons_toolkit.py +360 -0
- camel/toolkits/function_tool.py +174 -7
- camel/toolkits/github_toolkit.py +175 -176
- camel/toolkits/google_scholar_toolkit.py +36 -7
- camel/toolkits/notion_toolkit.py +279 -0
- camel/toolkits/search_toolkit.py +164 -36
- camel/types/enums.py +88 -0
- camel/types/unified_model_type.py +10 -0
- camel/utils/commons.py +2 -1
- camel/utils/constants.py +2 -0
- {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/METADATA +129 -79
- {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/RECORD +47 -40
- {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/LICENSE +0 -0
- {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Any, Dict, List, Optional, Union
|
|
16
|
+
|
|
17
|
+
from camel.toolkits.base import BaseToolkit
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataCommonsToolkit(BaseToolkit):
    r"""A class representing a toolkit for Data Commons.

    This class provides methods for querying and retrieving data from the
    Data Commons knowledge graph. It includes functionality for:
    - Executing SPARQL queries
    - Retrieving triples associated with nodes
    - Fetching statistical time series data
    - Analyzing property labels and values
    - Retrieving places within a given place type
    - Obtaining statistical values for specific variables and locations

    All the data are grabbed from the knowledge graph of Data Commons.
    Refer to https://datacommons.org/browser/ for more details.

    Every public method imports the Data Commons client lazily, catches any
    exception raised by the call, logs it through the module logger and
    returns :obj:`None` instead of propagating the error.
    """

    @staticmethod
    def _as_list(values: Union[str, List[str]]) -> List[str]:
        r"""Wrap a lone string in a list; return any other value as is."""
        # The client functions expect an iterable of identifiers. A bare
        # string is itself iterable, so it has to be wrapped to be treated
        # as one identifier rather than as a sequence of characters.
        return [values] if isinstance(values, str) else values

    @staticmethod
    def query_data_commons(
        query_string: str,
    ) -> Optional[List[Dict[str, Any]]]:
        r"""Query the Data Commons knowledge graph using SPARQL.

        Args:
            query_string (str): A SPARQL query string.

        Returns:
            Optional[List[Dict[str, Any]]]: A list of dictionaries, each
                representing a node matching the query conditions if success,
                (default: :obj:`None`) otherwise.

        Note:
            - Only supports a limited subset of SPARQL functionality (ORDER BY,
                DISTINCT, LIMIT).
            - Each variable in the query should have a 'typeOf' condition.
            - The Python SPARQL library currently only supports the V1 version
                of the API.

        Reference:
            https://docs.datacommons.org/api/python/query.html
        """
        import datacommons

        try:
            results = datacommons.query(query_string)

            # Hand back fresh dictionaries rather than the client's own rows.
            return [dict(row) for row in results]

        except Exception as e:
            logger.error(
                f"An error occurred while querying Data Commons: {e!s}"
            )
            return None

    @staticmethod
    def get_triples(
        dcids: Union[str, List[str]], limit: int = 500
    ) -> Optional[Dict[str, List[tuple]]]:
        r"""Retrieve triples associated with nodes.

        Args:
            dcids (Union[str, List[str]]): A single DCID or a list of DCIDs
                to query.
            limit (int): The maximum number of triples per
                combination of property and type. (default: :obj:`500`)

        Returns:
            Optional[Dict[str, List[tuple]]]: A dictionary where keys are
                DCIDs and values are lists of associated triples if success,
                (default: :obj:`None`) otherwise.

        Note:
            - This method does not raise. Any error reported by the client
                (such as an invalid `limit` or a DCID that does not exist in
                the knowledge graph) is logged and :obj:`None` is returned.
            - A single DCID given as a string is wrapped in a list before
                the request is made.

        Reference:
            https://docs.datacommons.org/api/python/triple.html
        """
        import datacommons

        try:
            result = datacommons.get_triples(
                DataCommonsToolkit._as_list(dcids), limit
            )
            return result

        except Exception as e:
            logger.error(f"An error occurred: {e!s}")
            return None

    @staticmethod
    def get_stat_time_series(
        place: str,
        stat_var: str,
        measurement_method: Optional[str] = None,
        observation_period: Optional[str] = None,
        unit: Optional[str] = None,
        scaling_factor: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        r"""Retrieve statistical time series for a place.

        Args:
            place (str): The dcid of the Place to query for.
            stat_var (str): The dcid of the StatisticalVariable.
            measurement_method (str, optional): The technique used for
                measuring a statistical variable. (default: :obj:`None`)
            observation_period (str, optional): The time period over which an
                observation is made. (default: :obj:`None`)
            unit (str, optional): The unit of measurement. (default:
                :obj:`None`)
            scaling_factor (str, optional): Property of statistical variables
                indicating factor by which a measurement is multiplied to fit
                a certain format. (default: :obj:`None`)

        Returns:
            Optional[Dict[str, Any]]: A dictionary containing the statistical
                time series data if success, (default: :obj:`None`) otherwise.

        Reference:
            https://docs.datacommons.org/api/python/stat_series.html
        """
        import datacommons_pandas

        try:
            # Arguments are forwarded positionally, in the client's order.
            result = datacommons_pandas.get_stat_series(
                place,
                stat_var,
                measurement_method,
                observation_period,
                unit,
                scaling_factor,
            )
            return result
        except Exception as e:
            logger.error(
                f"An error occurred while querying Data Commons: {e!s}"
            )
            return None

    @staticmethod
    def get_property_labels(
        dcids: Union[str, List[str]], out: bool = True
    ) -> Optional[Dict[str, List[str]]]:
        r"""Retrieves and analyzes property labels for given DCIDs.

        Args:
            dcids (Union[str, List[str]]): A single Data Commons ID (DCID)
                or a list of DCIDs to analyze.
            out (bool): Direction of properties to retrieve. (default:
                :obj:`True`)

        Returns:
            Optional[Dict[str, List[str]]]: Analysis results for each DCID if
                success, (default: :obj:`None`) otherwise.

        Reference:
            https://docs.datacommons.org/api/python/property_label.html
        """
        import datacommons

        try:
            result = datacommons.get_property_labels(
                DataCommonsToolkit._as_list(dcids), out=out
            )
            return result
        except Exception as e:
            logger.error(
                f"An error occurred while analyzing property labels: {e!s}"
            )
            return None

    @staticmethod
    def get_property_values(
        dcids: Union[str, List[str]],
        prop: str,
        out: Optional[bool] = True,
        value_type: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Optional[Dict[str, Any]]:
        r"""Retrieves and analyzes property values for given DCIDs.

        Args:
            dcids (Union[str, List[str]]): A single Data Commons ID (DCID)
                or a list of DCIDs to analyze.
            prop (str): The property to analyze.
            out (bool, optional): The label's direction. When :obj:`True`,
                only response nodes directed towards the requested node are
                returned. If set to False, will only return response nodes
                directed away from the request node. (default: :obj:`True`)
            value_type (str, optional): The type of the property value to
                filter by. Only applicable if the value refers to a node.
                (default: :obj:`None`)
            limit (int, optional): (≤ 500) Maximum number of values returned
                per node. When left as :obj:`None` the argument is not
                forwarded, so the client's own default applies. (default:
                :obj:`None`)

        Returns:
            Optional[Dict[str, Any]]: Analysis results for each DCID if
                success, (default: :obj:`None`) otherwise.

        Reference:
            https://docs.datacommons.org/api/python/property_value.html
        """
        import datacommons

        # Forward `limit` only when the caller chose one; passing None
        # explicitly would replace the client's default limit.
        limit_kwargs: Dict[str, Any] = {}
        if limit is not None:
            limit_kwargs["limit"] = limit

        try:
            result = datacommons.get_property_values(
                DataCommonsToolkit._as_list(dcids),
                prop,
                out,
                value_type,
                **limit_kwargs,
            )
            return result

        except Exception as e:
            logger.error(
                f"An error occurred while analyzing property values: {e!s}"
            )
            return None

    @staticmethod
    def get_places_in(
        dcids: list, place_type: str
    ) -> Optional[Dict[str, Any]]:
        r"""Retrieves places within a given place type.

        Args:
            dcids (list): A list of Data Commons IDs (DCIDs) to analyze.
            place_type (str): The type of the place to filter by.

        Returns:
            Optional[Dict[str, Any]]: Analysis results for each DCID if
                success, (default: :obj:`None`) otherwise.

        Reference:
            https://docs.datacommons.org/api/python/place_in.html
        """
        import datacommons

        try:
            result = datacommons.get_places_in(
                DataCommonsToolkit._as_list(dcids), place_type
            )
            return result

        except Exception as e:
            logger.error(
                "An error occurred while retrieving places in a given place "
                f"type: {e!s}"
            )
            return None

    @staticmethod
    def get_stat_value(
        place: str,
        stat_var: str,
        date: Optional[str] = None,
        measurement_method: Optional[str] = None,
        observation_period: Optional[str] = None,
        unit: Optional[str] = None,
        scaling_factor: Optional[str] = None,
    ) -> Optional[float]:
        r"""Retrieves the value of a statistical variable for a given place
        and date.

        Args:
            place (str): The DCID of the Place to query for.
            stat_var (str): The DCID of the StatisticalVariable.
            date (str, optional): The preferred date of observation in ISO
                8601 format. If not specified, returns the latest observation.
                (default: :obj:`None`)
            measurement_method (str, optional): The DCID of the preferred
                measurementMethod value. (default: :obj:`None`)
            observation_period (str, optional): The preferred observationPeriod
                value. (default: :obj:`None`)
            unit (str, optional): The DCID of the preferred unit value.
                (default: :obj:`None`)
            scaling_factor (str, optional): The preferred scalingFactor value.
                (default: :obj:`None`)

        Returns:
            Optional[float]: The value of the statistical variable for the
                given place and date if success, (default: :obj:`None`)
                otherwise.

        Reference:
            https://docs.datacommons.org/api/python/stat_value.html
        """
        import datacommons

        try:
            # Arguments are forwarded positionally, in the client's order.
            result = datacommons.get_stat_value(
                place,
                stat_var,
                date,
                measurement_method,
                observation_period,
                unit,
                scaling_factor,
            )
            return result

        except Exception as e:
            logger.error(
                "An error occurred while retrieving the value of a "
                f"statistical variable: {e!s}"
            )
            return None

    @staticmethod
    def get_stat_all(
        places: Union[str, List[str]], stat_vars: Union[str, List[str]]
    ) -> Optional[dict]:
        r"""Retrieves all available statistical data for the given places
        and statistical variables.

        Args:
            places (Union[str, List[str]]): The DCID or DCIDs of the Place
                objects to query for. (Here DCID stands for Data Commons ID,
                the unique identifier assigned to all entities in Data
                Commons.)
            stat_vars (Union[str, List[str]]): The dcid or dcids of the
                StatisticalVariables at
                https://datacommons.org/browser/StatisticalVariable

        Returns:
            Optional[dict]: The data returned by the client for the requested
                places and statistical variables if success, (default:
                :obj:`None`) otherwise.

        Note:
            - A single DCID given as a string is wrapped in a list before
                the request is made.

        Reference:
            https://docs.datacommons.org/api/python/stat_all.html
        """
        import datacommons

        try:
            result = datacommons.get_stat_all(
                DataCommonsToolkit._as_list(places),
                DataCommonsToolkit._as_list(stat_vars),
            )
            return result

        except Exception as e:
            logger.error(
                "An error occurred while retrieving the value of a "
                f"statistical variable: {e!s}"
            )
            return None
|
camel/toolkits/function_tool.py
CHANGED
|
@@ -11,8 +11,9 @@
|
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import logging
|
|
14
15
|
import warnings
|
|
15
|
-
from inspect import Parameter, signature
|
|
16
|
+
from inspect import Parameter, getsource, signature
|
|
16
17
|
from typing import Any, Callable, Dict, Mapping, Optional, Tuple
|
|
17
18
|
|
|
18
19
|
from docstring_parser import parse
|
|
@@ -21,8 +22,13 @@ from jsonschema.validators import Draft202012Validator as JSONValidator
|
|
|
21
22
|
from pydantic import create_model
|
|
22
23
|
from pydantic.fields import FieldInfo
|
|
23
24
|
|
|
25
|
+
from camel.agents import ChatAgent
|
|
26
|
+
from camel.models import BaseModelBackend
|
|
27
|
+
from camel.types import ModelType
|
|
24
28
|
from camel.utils import get_pydantic_object_schema, to_pascal
|
|
25
29
|
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
26
32
|
|
|
27
33
|
def _remove_a_key(d: Dict, remove_key: Any) -> None:
|
|
28
34
|
r"""Remove a key from a dictionary recursively."""
|
|
@@ -143,6 +149,71 @@ def get_openai_tool_schema(func: Callable) -> Dict[str, Any]:
|
|
|
143
149
|
return openai_tool_schema
|
|
144
150
|
|
|
145
151
|
|
|
152
|
+
def generate_docstring(
    code: str,
    model: Optional[BaseModelBackend] = None,
) -> str:
    r"""Generates a docstring for a given function code using LLM.

    This function leverages a language model to generate a
    PEP 8/PEP 257-compliant docstring for a provided Python function.
    The instructions and a worked example are sent to a freshly created
    `ChatAgent`, followed by the function source.

    Args:
        code (str): The source code of the function.
        model (Optional[BaseModelBackend]): An optional language model backend
            instance. It is forwarded unchanged to `ChatAgent`; when it is
            :obj:`None` the agent chooses its own default backend
            (described upstream as gpt-4o-mini). (default: :obj:`None`)

    Returns:
        str: The generated docstring.
    """
    from textwrap import dedent

    # Build the prompt from an indented literal and strip the common
    # indentation, so the text sent to the model does not depend on how
    # this function is indented and the embedded example stays valid Python.
    docstring_prompt = dedent(
        '''
        **Role**: Generate professional Python docstrings conforming to
        PEP 8/PEP 257.

        **Requirements**:
        - Use appropriate format: reST, Google, or NumPy, as needed.
        - Include parameters, return values, and exceptions.
        - Reference any existing docstring in the function and
        retain useful information.

        **Input**: Python function.

        **Output**: Docstring content (plain text, no code markers).

        **Example:**

        Input:
        ```python
        def add(a: int, b: int) -> int:
            return a + b
        ```

        Output:
        Adds two numbers.
        Args:
            a (int): The first number.
            b (int): The second number.

        Returns:
            int: The sum of the two numbers.

        **Task**: Generate a docstring for the function below.

        '''
    )
    # Initialize assistant with system message and model
    assistant_sys_msg = "You are a helpful assistant."
    docstring_assistant = ChatAgent(assistant_sys_msg, model=model)

    # The prompt ends with a blank line, so the source starts on its own line.
    user_msg = docstring_prompt + code

    # Get the response containing the generated docstring
    response = docstring_assistant.step(user_msg)
    return response.msg.content
|
|
215
|
+
|
|
216
|
+
|
|
146
217
|
class FunctionTool:
|
|
147
218
|
r"""An abstraction of a function that OpenAI chat models can call. See
|
|
148
219
|
https://platform.openai.com/docs/api-reference/chat/create.
|
|
@@ -151,23 +222,52 @@ class FunctionTool:
|
|
|
151
222
|
provide a user-defined tool schema to override.
|
|
152
223
|
|
|
153
224
|
Args:
|
|
154
|
-
func (Callable): The function to call.The tool schema is parsed from
|
|
155
|
-
the signature and docstring by default.
|
|
156
|
-
openai_tool_schema (Optional[Dict[str, Any]], optional): A
|
|
157
|
-
|
|
225
|
+
func (Callable): The function to call. The tool schema is parsed from
|
|
226
|
+
the function signature and docstring by default.
|
|
227
|
+
openai_tool_schema (Optional[Dict[str, Any]], optional): A
|
|
228
|
+
user-defined OpenAI tool schema to override the default result.
|
|
158
229
|
(default: :obj:`None`)
|
|
230
|
+
use_schema_assistant (Optional[bool], optional): Whether to enable the
|
|
231
|
+
use of a schema assistant model to automatically generate the
|
|
232
|
+
schema if validation fails or no valid schema is provided.
|
|
233
|
+
(default: :obj:`False`)
|
|
234
|
+
schema_assistant_model (Optional[BaseModelBackend], optional): An
|
|
235
|
+
assistant model (e.g., an LLM model) used to generate the schema
|
|
236
|
+
if `use_schema_assistant` is enabled and no valid schema is
|
|
237
|
+
provided. (default: :obj:`None`)
|
|
238
|
+
schema_generation_max_retries (int, optional): The maximum
|
|
239
|
+
number of attempts to retry schema generation using the schema
|
|
240
|
+
assistant model if the previous attempts fail. (default: 2)
|
|
159
241
|
"""
|
|
160
242
|
|
|
161
243
|
def __init__(
    self,
    func: Callable,
    openai_tool_schema: Optional[Dict[str, Any]] = None,
    use_schema_assistant: Optional[bool] = False,
    schema_assistant_model: Optional[BaseModelBackend] = None,
    schema_generation_max_retries: int = 2,
) -> None:
    r"""Store the function and decide which OpenAI tool schema to use.

    Args:
        func (Callable): The function to call.
        openai_tool_schema (Optional[Dict[str, Any]], optional): A
            user-defined schema. Used as is unless the schema assistant is
            enabled. (default: :obj:`None`)
        use_schema_assistant (Optional[bool], optional): When truthy, the
            schema is produced by `generate_openai_tool_schema` and any
            user-defined schema is ignored. (default: :obj:`False`)
        schema_assistant_model (Optional[BaseModelBackend], optional): The
            model forwarded to `generate_openai_tool_schema`.
            (default: :obj:`None`)
        schema_generation_max_retries (int, optional): The retry count
            forwarded to `generate_openai_tool_schema`. (default: :obj:`2`)

    Raises:
        ValueError: If the schema assistant is enabled and no schema could
            be generated.
    """
    self.func = func

    if not use_schema_assistant:
        # A falsy user schema (None or {}) falls back to the schema parsed
        # from the function signature and docstring.
        self.openai_tool_schema = (
            openai_tool_schema or get_openai_tool_schema(func)
        )
        return

    # The assistant's result replaces everything else, so the default
    # schema is not parsed on this path.
    if openai_tool_schema:
        logger.warning(
            "The user-defined OpenAI tool schema will be "
            "overridden by the schema assistant model."
        )
    schema = self.generate_openai_tool_schema(
        schema_generation_max_retries, schema_assistant_model
    )
    if not schema:
        raise ValueError(
            f"Failed to generate valid schema for "
            f"{self.func.__name__}."
        )
    self.openai_tool_schema = schema
|
|
270
|
+
|
|
171
271
|
@staticmethod
|
|
172
272
|
def validate_openai_tool_schema(
|
|
173
273
|
openai_tool_schema: Dict[str, Any],
|
|
@@ -262,8 +362,8 @@ class FunctionTool:
|
|
|
262
362
|
r"""Sets the schema of the function within the OpenAI tool schema.
|
|
263
363
|
|
|
264
364
|
Args:
|
|
265
|
-
openai_function_schema (Dict[str, Any]): The function schema to
|
|
266
|
-
within the OpenAI tool schema.
|
|
365
|
+
openai_function_schema (Dict[str, Any]): The function schema to
|
|
366
|
+
set within the OpenAI tool schema.
|
|
267
367
|
"""
|
|
268
368
|
self.openai_tool_schema["function"] = openai_function_schema
|
|
269
369
|
|
|
@@ -364,6 +464,71 @@ class FunctionTool:
|
|
|
364
464
|
param_name
|
|
365
465
|
] = value
|
|
366
466
|
|
|
467
|
+
def generate_openai_tool_schema(
    self,
    max_retries: Optional[int] = None,
    schema_assistant_model: Optional[BaseModelBackend] = None,
) -> Dict[str, Any]:
    r"""Generates an OpenAI tool schema for the specified function.

    This method uses a language model (LLM) to generate a docstring from
    the function's source code, assigns it to the function, and then
    builds and validates the OpenAI tool schema from the function. One
    attempt is always made; each failure is followed by another attempt
    until `max_retries` retries have been used.

    Args:
        max_retries (Optional[int], optional): The maximum number of
            retries after the first failed attempt. :obj:`None` and
            negative values are treated as zero, i.e. a single attempt.
            (default: :obj:`None`)
        schema_assistant_model (Optional[BaseModelBackend], optional): An
            optional LLM backend model used for generating the docstring.
            It is forwarded to `generate_docstring`; when it is not
            provided a warning naming the default model is logged.
            (default: :obj:`None`)

    Returns:
        Dict[str, Any]: The generated OpenAI tool schema for the function.

    Raises:
        ValueError: If schema generation or validation still fails once
            the retries are exhausted. The last underlying error is
            attached as the cause.
        OSError: Propagated from `inspect.getsource` when the source of
            the function cannot be retrieved.
        TypeError: Propagated from `inspect.getsource` for objects that
            have no Python source, such as built-ins.

    Note:
        The function's `__doc__` is overwritten on every attempt, including
        attempts whose schema later fails validation.
    """
    if not schema_assistant_model:
        logger.warning(
            "Warning: No model provided. "
            f"Use `{ModelType.GPT_4O_MINI.value}` to generate the schema."
        )
    code = getsource(self.func)

    # Normalise the budget so the first attempt always runs.
    retry_budget = max(max_retries or 0, 0)
    retries_used = 0

    # Retry loop to handle schema generation and validation
    while True:
        try:
            # Generate the docstring and the schema
            docstring = generate_docstring(code, schema_assistant_model)
            self.func.__doc__ = docstring
            schema = get_openai_tool_schema(self.func)
            # Validate the schema
            self.validate_openai_tool_schema(schema)
            return schema

        except Exception as e:
            # Give up only after the initial attempt plus every allowed
            # retry has failed.
            if retries_used >= retry_budget:
                raise ValueError(
                    f"Failed to generate the OpenAI tool Schema after "
                    f"{retry_budget} retries. "
                    f"Please set the OpenAI tool schema for "
                    f"function {self.func.__name__} manually."
                ) from e
            retries_used += 1
            logger.warning("Schema validation failed. Retrying...")
|
|
531
|
+
|
|
367
532
|
@property
|
|
368
533
|
def parameters(self) -> Dict[str, Any]:
|
|
369
534
|
r"""Getter method for the property :obj:`parameters`.
|
|
@@ -397,6 +562,8 @@ warnings.simplefilter('always', DeprecationWarning)
|
|
|
397
562
|
|
|
398
563
|
# Alias for backwards compatibility
|
|
399
564
|
class OpenAIFunction(FunctionTool):
|
|
565
|
+
r"""Alias for backwards compatibility."""
|
|
566
|
+
|
|
400
567
|
def __init__(self, *args, **kwargs):
|
|
401
568
|
PURPLE = '\033[95m'
|
|
402
569
|
RESET = '\033[0m'
|