camel-ai 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +114 -23
- camel/configs/__init__.py +6 -4
- camel/configs/base_config.py +21 -0
- camel/configs/gemini_config.py +17 -9
- camel/configs/qwen_config.py +91 -0
- camel/configs/samba_config.py +1 -38
- camel/configs/yi_config.py +58 -0
- camel/generators.py +93 -0
- camel/interpreters/docker_interpreter.py +5 -0
- camel/interpreters/ipython_interpreter.py +2 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/apify_reader.py +223 -0
- camel/memories/agent_memories.py +24 -1
- camel/messages/base.py +38 -0
- camel/models/__init__.py +4 -0
- camel/models/model_factory.py +6 -0
- camel/models/qwen_model.py +139 -0
- camel/models/samba_model.py +1 -1
- camel/models/yi_model.py +138 -0
- camel/prompts/image_craft.py +8 -0
- camel/prompts/video_description_prompt.py +8 -0
- camel/retrievers/vector_retriever.py +5 -1
- camel/societies/role_playing.py +29 -18
- camel/societies/workforce/base.py +7 -1
- camel/societies/workforce/task_channel.py +10 -0
- camel/societies/workforce/utils.py +6 -0
- camel/societies/workforce/worker.py +2 -0
- camel/storages/vectordb_storages/qdrant.py +147 -24
- camel/tasks/task.py +15 -0
- camel/terminators/base.py +4 -0
- camel/terminators/response_terminator.py +1 -0
- camel/terminators/token_limit_terminator.py +1 -0
- camel/toolkits/__init__.py +4 -1
- camel/toolkits/base.py +9 -0
- camel/toolkits/data_commons_toolkit.py +360 -0
- camel/toolkits/function_tool.py +174 -7
- camel/toolkits/github_toolkit.py +175 -176
- camel/toolkits/google_scholar_toolkit.py +36 -7
- camel/toolkits/notion_toolkit.py +279 -0
- camel/toolkits/search_toolkit.py +164 -36
- camel/types/enums.py +88 -0
- camel/types/unified_model_type.py +10 -0
- camel/utils/commons.py +2 -1
- camel/utils/constants.py +2 -0
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/METADATA +129 -79
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/RECORD +49 -42
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/LICENSE +0 -0
- {camel_ai-0.2.5.dist-info → camel_ai-0.2.7.dist-info}/WHEEL +0 -0

camel/models/qwen_model.py
ADDED

@@ -0,0 +1,139 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from openai import OpenAI, Stream
+
+from camel.configs import QWEN_API_PARAMS, QwenConfig
+from camel.messages import OpenAIMessage
+from camel.models import BaseModelBackend
+from camel.types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ModelType,
+)
+from camel.utils import (
+    BaseTokenCounter,
+    OpenAITokenCounter,
+    api_keys_required,
+)
+
+
+class QwenModel(BaseModelBackend):
+    r"""Qwen API in a unified BaseModelBackend interface.
+
+    Args:
+        model_type (Union[ModelType, str]): Model for which a backend is
+            created, one of Qwen series.
+        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
+            that will be fed into:obj:`openai.ChatCompletion.create()`. If
+            :obj:`None`, :obj:`QwenConfig().as_dict()` will be used.
+            (default: :obj:`None`)
+        api_key (Optional[str], optional): The API key for authenticating with
+            the Qwen service. (default: :obj:`None`)
+        url (Optional[str], optional): The url to the Qwen service.
+            (default: :obj:`https://dashscope.aliyuncs.com/compatible-mode/v1`)
+        token_counter (Optional[BaseTokenCounter], optional): Token counter to
+            use for the model. If not provided, :obj:`OpenAITokenCounter(
+            ModelType.GPT_4O_MINI)` will be used.
+            (default: :obj:`None`)
+    """
+
+    def __init__(
+        self,
+        model_type: Union[ModelType, str],
+        model_config_dict: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
+        token_counter: Optional[BaseTokenCounter] = None,
+    ) -> None:
+        if model_config_dict is None:
+            model_config_dict = QwenConfig().as_dict()
+        api_key = api_key or os.environ.get("QWEN_API_KEY")
+        url = url or os.environ.get(
+            "QWEN_API_BASE_URL",
+            "https://dashscope.aliyuncs.com/compatible-mode/v1",
+        )
+        super().__init__(
+            model_type, model_config_dict, api_key, url, token_counter
+        )
+        self._client = OpenAI(
+            timeout=60,
+            max_retries=3,
+            api_key=self._api_key,
+            base_url=self._url,
+        )
+
+    @api_keys_required("QWEN_API_KEY")
+    def run(
+        self,
+        messages: List[OpenAIMessage],
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        r"""Runs inference of Qwen chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `Stream[ChatCompletionChunk]` in the stream mode.
+        """
+        response = self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **self.model_config_dict,
+        )
+        return response
+
+    @property
+    def token_counter(self) -> BaseTokenCounter:
+        r"""Initialize the token counter for the model backend.
+
+        Returns:
+            OpenAITokenCounter: The token counter following the model's
+                tokenization style.
+        """
+
+        if not self._token_counter:
+            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
+        return self._token_counter
+
+    def check_model_config(self):
+        r"""Check whether the model configuration contains any
+        unexpected arguments to Qwen API.
+
+        Raises:
+            ValueError: If the model configuration dictionary contains any
+                unexpected arguments to Qwen API.
+        """
+        for param in self.model_config_dict:
+            if param not in QWEN_API_PARAMS:
+                raise ValueError(
+                    f"Unexpected argument `{param}` is "
+                    "input into Qwen model backend."
+                )
+
+    @property
+    def stream(self) -> bool:
+        r"""Returns whether the model is in stream mode, which sends partial
+        results each time.
+
+        Returns:
+            bool: Whether the model is in stream mode.
+        """
+        return self.model_config_dict.get('stream', False)

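The new QwenModel backend reads QWEN_API_KEY (and, optionally, QWEN_API_BASE_URL) from the environment and talks to DashScope's OpenAI-compatible endpoint. A minimal usage sketch, assuming "qwen-turbo" is a valid Qwen model id (the signature accepts Union[ModelType, str], so a plain string works) and a real key is available:

import os

from camel.models.qwen_model import QwenModel

os.environ["QWEN_API_KEY"] = "sk-..."  # placeholder; use a real DashScope key

# With model_config_dict omitted, QwenConfig().as_dict() is used, per __init__ above.
model = QwenModel(model_type="qwen-turbo")
completion = model.run(
    [{"role": "user", "content": "Say hello in one sentence."}]
)
print(completion.choices[0].message.content)  # ChatCompletion in non-stream mode
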
camel/models/samba_model.py
CHANGED

@@ -147,7 +147,7 @@ class SambaModel(BaseModelBackend):
     def run(  # type: ignore[misc]
         self, messages: List[OpenAIMessage]
     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
-        r"""Runs SambaNova's
+        r"""Runs SambaNova's service.
 
         Args:
             messages (List[OpenAIMessage]): Message list with the chat history

camel/models/yi_model.py
ADDED

@@ -0,0 +1,138 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from openai import OpenAI, Stream
+
+from camel.configs import YI_API_PARAMS, YiConfig
+from camel.messages import OpenAIMessage
+from camel.models import BaseModelBackend
+from camel.types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ModelType,
+)
+from camel.utils import (
+    BaseTokenCounter,
+    OpenAITokenCounter,
+    api_keys_required,
+)
+
+
+class YiModel(BaseModelBackend):
+    r"""Yi API in a unified BaseModelBackend interface.
+
+    Args:
+        model_type (Union[ModelType, str]): Model for which a backend is
+            created, one of Yi series.
+        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
+            that will be fed into:obj:`openai.ChatCompletion.create()`. If
+            :obj:`None`, :obj:`YiConfig().as_dict()` will be used.
+            (default: :obj:`None`)
+        api_key (Optional[str], optional): The API key for authenticating with
+            the Yi service. (default: :obj:`None`)
+        url (Optional[str], optional): The url to the Yi service.
+            (default: :obj:`https://api.lingyiwanwu.com/v1`)
+        token_counter (Optional[BaseTokenCounter], optional): Token counter to
+            use for the model. If not provided, :obj:`OpenAITokenCounter(
+            ModelType.GPT_4O_MINI)` will be used.
+            (default: :obj:`None`)
+    """
+
+    def __init__(
+        self,
+        model_type: Union[ModelType, str],
+        model_config_dict: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
+        token_counter: Optional[BaseTokenCounter] = None,
+    ) -> None:
+        if model_config_dict is None:
+            model_config_dict = YiConfig().as_dict()
+        api_key = api_key or os.environ.get("YI_API_KEY")
+        url = url or os.environ.get(
+            "YI_API_BASE_URL", "https://api.lingyiwanwu.com/v1"
+        )
+        super().__init__(
+            model_type, model_config_dict, api_key, url, token_counter
+        )
+        self._client = OpenAI(
+            timeout=60,
+            max_retries=3,
+            api_key=self._api_key,
+            base_url=self._url,
+        )
+
+    @api_keys_required("YI_API_KEY")
+    def run(
+        self,
+        messages: List[OpenAIMessage],
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        r"""Runs inference of Yi chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `Stream[ChatCompletionChunk]` in the stream mode.
+        """
+        response = self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **self.model_config_dict,
+        )
+        return response
+
+    @property
+    def token_counter(self) -> BaseTokenCounter:
+        r"""Initialize the token counter for the model backend.
+
+        Returns:
+            OpenAITokenCounter: The token counter following the model's
+                tokenization style.
+        """
+
+        if not self._token_counter:
+            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
+        return self._token_counter
+
+    def check_model_config(self):
+        r"""Check whether the model configuration contains any
+        unexpected arguments to Yi API.
+
+        Raises:
+            ValueError: If the model configuration dictionary contains any
+                unexpected arguments to Yi API.
+        """
+        for param in self.model_config_dict:
+            if param not in YI_API_PARAMS:
+                raise ValueError(
+                    f"Unexpected argument `{param}` is "
+                    "input into Yi model backend."
+                )
+
+    @property
+    def stream(self) -> bool:
+        r"""Returns whether the model is in stream mode, which sends partial
+        results each time.
+
+        Returns:
+            bool: Whether the model is in stream mode.
+        """
+        return self.model_config_dict.get('stream', False)

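YiModel follows the same pattern pointed at the Lingyiwanwu endpoint, and its stream property simply mirrors the 'stream' key of the config dict. A hedged sketch of stream mode, assuming "yi-large" is a valid Yi model id and that 'stream' is accepted by YI_API_PARAMS (consistent with check_model_config above):

import os

from camel.models.yi_model import YiModel

os.environ["YI_API_KEY"] = "..."  # placeholder; use a real Yi API key

model = YiModel(model_type="yi-large", model_config_dict={"stream": True})
assert model.stream  # reads the 'stream' key back from model_config_dict

# In stream mode, run() returns Stream[ChatCompletionChunk] rather than ChatCompletion.
for chunk in model.run([{"role": "user", "content": "Count to three."}]):
    print(chunk.choices[0].delta.content or "", end="")
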
camel/prompts/image_craft.py
CHANGED

@@ -18,6 +18,14 @@ from camel.types import RoleType
 
 
 class ImageCraftPromptTemplateDict(TextPromptDict):
+    r"""A dictionary containing :obj:`TextPrompt` used in the `ImageCraft`
+    task.
+
+    Attributes:
+        ASSISTANT_PROMPT (TextPrompt): A prompt for the AI assistant to create
+            an original image based on the provided descriptive captions.
+    """
+
     ASSISTANT_PROMPT = TextPrompt(
         """You are tasked with creating an original image based on
         the provided descriptive captions. Use your imagination

camel/prompts/video_description_prompt.py
CHANGED

@@ -19,6 +19,14 @@ from camel.types import RoleType
 
 # flake8: noqa :E501
 class VideoDescriptionPromptTemplateDict(TextPromptDict):
+    r"""A dictionary containing :obj:`TextPrompt` used in the `VideoDescription`
+    task.
+
+    Attributes:
+        ASSISTANT_PROMPT (TextPrompt): A prompt for the AI assistant to
+            provide a shot description of the content of the current video.
+    """
+
     ASSISTANT_PROMPT = TextPrompt(
         """You are a master of video analysis.
         Please provide a shot description of the content of the current video."""

camel/retrievers/vector_retriever.py
CHANGED

@@ -76,6 +76,7 @@ class VectorRetriever(BaseRetriever):
         max_characters: int = 500,
         embed_batch: int = 50,
         should_chunk: bool = True,
+        extra_info: Optional[dict] = None,
         **kwargs: Any,
     ) -> None:
         r"""Processes content from local file path, remote URL, string

@@ -93,6 +94,8 @@ class VectorRetriever(BaseRetriever):
             embed_batch (int): Size of batch for embeddings. Defaults to `50`.
             should_chunk (bool): If True, divide the content into chunks,
                 otherwise skip chunking. Defaults to True.
+            extra_info (Optional[dict]): Extra information to be added
+                to the payload. Defaults to None.
             **kwargs (Any): Additional keyword arguments for content parsing.
         """
         from unstructured.documents.elements import Element

@@ -153,12 +156,13 @@ class VectorRetriever(BaseRetriever):
                 chunk_metadata = {"metadata": chunk.metadata.to_dict()}
                 # Remove the 'orig_elements' key if it exists
                 chunk_metadata["metadata"].pop("orig_elements", "")
-
+                extra_info = extra_info or {}
                 chunk_text = {"text": str(chunk)}
                 combined_dict = {
                     **content_path_info,
                     **chunk_metadata,
                     **chunk_text,
+                    **extra_info,
                 }
 
                 records.append(

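The new extra_info dict is merged into each chunk's payload next to the chunk text and metadata, so records can be tagged at ingest time. A sketch under assumptions: the method shown above is VectorRetriever.process, and the retriever's default embedding/storage backends are usable as-is (the defaults typically require an OpenAI key for embeddings):

from camel.retrievers import VectorRetriever

retriever = VectorRetriever()  # assumed: default embedding model and vector storage
retriever.process(
    "https://example.com/whitepaper.pdf",  # illustrative URL
    extra_info={"tenant": "acme", "ingest_run": "2024-10-01"},  # illustrative keys
)
# Each stored record's payload now carries the tenant/ingest_run keys
# alongside the "text" and "metadata" entries built above.
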
camel/societies/role_playing.py
CHANGED

@@ -23,10 +23,10 @@ from camel.agents import (
 from camel.generators import SystemMessageGenerator
 from camel.human import Human
 from camel.messages import BaseMessage
-from camel.models import BaseModelBackend
+from camel.models import BaseModelBackend, ModelFactory
 from camel.prompts import TextPrompt
 from camel.responses import ChatAgentResponse
-from camel.types import RoleType, TaskType
+from camel.types import ModelPlatformType, ModelType, RoleType, TaskType
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)

@@ -55,7 +55,8 @@ class RolePlaying:
             If not specified, set the criteria to improve task performance.
         model (BaseModelBackend, optional): The model backend to use for
             generating responses. If specified, it will override the model in
-            all agents. (default:
+            all agents if not specified in agent-specific kwargs. (default:
+            :obj:`OpenAIModel` with `GPT_4O_MINI`)
         task_type (TaskType, optional): The type of task to perform.
             (default: :obj:`TaskType.AI_SOCIETY`)
         assistant_agent_kwargs (Dict, optional): Additional arguments to pass

@@ -103,16 +104,21 @@ class RolePlaying:
     ) -> None:
         if model is not None:
             logger.warning(
-                "
-                "
-                "through assistant_agent_kwargs, user_agent_kwargs, and "
-                "other agent-specific kwargs."
+                "Model provided globally is set for all agents if not"
+                " already specified in agent_kwargs."
             )
 
         self.with_task_specify = with_task_specify
         self.with_task_planner = with_task_planner
         self.with_critic_in_the_loop = with_critic_in_the_loop
-        self.model =
+        self.model: BaseModelBackend = (
+            model
+            if model is not None
+            else ModelFactory.create(
+                model_platform=ModelPlatformType.DEFAULT,
+                model_type=ModelType.DEFAULT,
+            )
+        )
         self.task_type = task_type
         self.task_prompt = task_prompt
 

@@ -204,8 +210,9 @@ class RolePlaying:
         task_specify_meta_dict.update(extend_task_specify_meta_dict or {})
         if self.model is not None:
             if task_specify_agent_kwargs is None:
-                task_specify_agent_kwargs = {}
-                task_specify_agent_kwargs
+                task_specify_agent_kwargs = {'model': self.model}
+            elif 'model' not in task_specify_agent_kwargs:
+                task_specify_agent_kwargs.update(dict(model=self.model))
         task_specify_agent = TaskSpecifyAgent(
             task_type=self.task_type,
             output_language=output_language,

@@ -237,8 +244,9 @@ class RolePlaying:
         if self.with_task_planner:
             if self.model is not None:
                 if task_planner_agent_kwargs is None:
-                    task_planner_agent_kwargs = {}
-                    task_planner_agent_kwargs
+                    task_planner_agent_kwargs = {'model': self.model}
+                elif 'model' not in task_planner_agent_kwargs:
+                    task_planner_agent_kwargs.update(dict(model=self.model))
             task_planner_agent = TaskPlannerAgent(
                 output_language=output_language,
                 **(task_planner_agent_kwargs or {}),

@@ -332,11 +340,13 @@ class RolePlaying:
         """
         if self.model is not None:
             if assistant_agent_kwargs is None:
-                assistant_agent_kwargs = {}
-                assistant_agent_kwargs
+                assistant_agent_kwargs = {'model': self.model}
+            elif 'model' not in assistant_agent_kwargs:
+                assistant_agent_kwargs.update(dict(model=self.model))
             if user_agent_kwargs is None:
-                user_agent_kwargs = {}
-                user_agent_kwargs
+                user_agent_kwargs = {'model': self.model}
+            elif 'model' not in user_agent_kwargs:
+                user_agent_kwargs.update(dict(model=self.model))
 
         self.assistant_agent = ChatAgent(
             init_assistant_sys_msg,

@@ -394,8 +404,9 @@ class RolePlaying:
         )
         if self.model is not None:
             if critic_kwargs is None:
-                critic_kwargs = {}
-                critic_kwargs
+                critic_kwargs = {'model': self.model}
+            elif 'model' not in critic_kwargs:
+                critic_kwargs.update(dict(model=self.model))
             self.critic = CriticAgent(
                 self.critic_sys_msg,
                 **(critic_kwargs or {}),

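With the changes above, a model passed to RolePlaying is only a fallback: agent kwargs that already carry a 'model' key keep their own backend, and omitting the model entirely now builds a DEFAULT backend via ModelFactory. A hedged sketch of the precedence (role names and task prompt are illustrative):

from camel.models import ModelFactory
from camel.societies import RolePlaying
from camel.types import ModelPlatformType, ModelType

shared_model = ModelFactory.create(
    model_platform=ModelPlatformType.DEFAULT,
    model_type=ModelType.DEFAULT,
)
user_model = ModelFactory.create(  # stand-in for any per-agent backend
    model_platform=ModelPlatformType.DEFAULT,
    model_type=ModelType.DEFAULT,
)
session = RolePlaying(
    assistant_role_name="Python Programmer",
    user_role_name="Stock Trader",
    task_prompt="Develop a trading bot for the stock market",
    model=shared_model,  # fallback for agents without their own model
    user_agent_kwargs={"model": user_model},  # takes precedence over shared_model
)
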
camel/societies/workforce/base.py
CHANGED

@@ -19,6 +19,12 @@ from camel.societies.workforce.utils import check_if_running
 
 
 class BaseNode(ABC):
+    r"""Base class for all nodes in the workforce.
+
+    Args:
+        description (str): Description of the node.
+    """
+
     def __init__(self, description: str) -> None:
         self.node_id = str(id(self))
         self.description = description

@@ -27,7 +33,7 @@ class BaseNode(ABC):
 
     @check_if_running(False)
     def reset(self, *args: Any, **kwargs: Any) -> Any:
-        """Resets the node to its initial state."""
+        r"""Resets the node to its initial state."""
         self._channel = TaskChannel()
         self._running = False
 

camel/societies/workforce/task_channel.py
CHANGED

@@ -84,6 +84,9 @@ class TaskChannel:
         self._task_dict: Dict[str, Packet] = {}
 
     async def get_returned_task_by_publisher(self, publisher_id: str) -> Task:
+        r"""Get a task from the channel that has been returned by the
+        publisher.
+        """
         async with self._condition:
             while True:
                 for task_id in self._task_id_list:

@@ -96,6 +99,9 @@ class TaskChannel:
                 await self._condition.wait()
 
     async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
+        r"""Get a task from the channel that has been assigned to the
+        assignee.
+        """
         async with self._condition:
             while True:
                 for task_id in self._task_id_list:

@@ -147,12 +153,14 @@ class TaskChannel:
             self._condition.notify_all()
 
     async def remove_task(self, task_id: str) -> None:
+        r"""Remove a task from the channel."""
         async with self._condition:
             self._task_id_list.remove(task_id)
             self._task_dict.pop(task_id)
             self._condition.notify_all()
 
     async def get_dependency_ids(self) -> List[str]:
+        r"""Get the IDs of all dependencies in the channel."""
         async with self._condition:
             dependency_ids = []
             for task_id in self._task_id_list:

@@ -162,11 +170,13 @@ class TaskChannel:
             return dependency_ids
 
     async def get_task_by_id(self, task_id: str) -> Task:
+        r"""Get a task from the channel by its ID."""
         async with self._condition:
             if task_id not in self._task_id_list:
                 raise ValueError(f"Task {task_id} not found.")
             return self._task_dict[task_id].task
 
     async def get_channel_debug_info(self) -> str:
+        r"""Get the debug information of the channel."""
         async with self._condition:
             return str(self._task_dict) + '\n' + str(self._task_id_list)

camel/societies/workforce/utils.py
CHANGED

@@ -18,6 +18,8 @@ from pydantic import BaseModel, Field
 
 
 class WorkerConf(BaseModel):
+    r"""The configuration of a worker."""
+
     role: str = Field(
         description="The role of the agent working in the work node."
     )

@@ -31,6 +33,8 @@ class WorkerConf(BaseModel):
 
 
 class TaskResult(BaseModel):
+    r"""The result of a task."""
+
     content: str = Field(description="The result of the task.")
     failed: bool = Field(
         description="Flag indicating whether the task processing failed."

@@ -38,6 +42,8 @@ class TaskResult(BaseModel):
 
 
 class TaskAssignResult(BaseModel):
+    r"""The result of task assignment."""
+
     assignee_id: str = Field(
         description="The ID of the workforce that is assigned to the task."
     )

camel/societies/workforce/worker.py
CHANGED

@@ -110,9 +110,11 @@ class Worker(BaseNode, ABC):
 
     @check_if_running(False)
     async def start(self):
+        r"""Start the worker."""
         await self._listen_to_channel()
 
     @check_if_running(True)
     def stop(self):
+        r"""Stop the worker."""
        self._running = False
        return