camel-ai 0.1.5.6__py3-none-any.whl → 0.1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +249 -36
- camel/agents/critic_agent.py +18 -2
- camel/agents/deductive_reasoner_agent.py +16 -4
- camel/agents/embodied_agent.py +20 -6
- camel/agents/knowledge_graph_agent.py +24 -5
- camel/agents/role_assignment_agent.py +13 -1
- camel/agents/search_agent.py +16 -5
- camel/agents/task_agent.py +20 -5
- camel/configs/__init__.py +11 -9
- camel/configs/anthropic_config.py +5 -6
- camel/configs/base_config.py +50 -4
- camel/configs/gemini_config.py +69 -17
- camel/configs/groq_config.py +105 -0
- camel/configs/litellm_config.py +2 -8
- camel/configs/mistral_config.py +78 -0
- camel/configs/ollama_config.py +5 -7
- camel/configs/openai_config.py +12 -23
- camel/configs/vllm_config.py +102 -0
- camel/configs/zhipuai_config.py +5 -11
- camel/embeddings/__init__.py +2 -0
- camel/embeddings/mistral_embedding.py +89 -0
- camel/human.py +1 -1
- camel/interpreters/__init__.py +2 -0
- camel/interpreters/ipython_interpreter.py +167 -0
- camel/loaders/__init__.py +2 -0
- camel/loaders/firecrawl_reader.py +213 -0
- camel/memories/agent_memories.py +1 -4
- camel/memories/blocks/chat_history_block.py +6 -2
- camel/memories/blocks/vectordb_block.py +3 -1
- camel/memories/context_creators/score_based.py +6 -6
- camel/memories/records.py +9 -7
- camel/messages/base.py +1 -0
- camel/models/__init__.py +8 -0
- camel/models/anthropic_model.py +7 -2
- camel/models/azure_openai_model.py +152 -0
- camel/models/base_model.py +9 -2
- camel/models/gemini_model.py +14 -2
- camel/models/groq_model.py +131 -0
- camel/models/litellm_model.py +26 -4
- camel/models/mistral_model.py +169 -0
- camel/models/model_factory.py +30 -3
- camel/models/ollama_model.py +21 -2
- camel/models/open_source_model.py +13 -5
- camel/models/openai_model.py +7 -2
- camel/models/stub_model.py +4 -4
- camel/models/vllm_model.py +138 -0
- camel/models/zhipuai_model.py +7 -4
- camel/prompts/__init__.py +8 -1
- camel/prompts/image_craft.py +34 -0
- camel/prompts/multi_condition_image_craft.py +34 -0
- camel/prompts/task_prompt_template.py +10 -4
- camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
- camel/responses/agent_responses.py +4 -3
- camel/retrievers/auto_retriever.py +2 -2
- camel/societies/babyagi_playing.py +6 -4
- camel/societies/role_playing.py +16 -8
- camel/storages/graph_storages/graph_element.py +10 -14
- camel/storages/graph_storages/neo4j_graph.py +5 -0
- camel/storages/vectordb_storages/base.py +24 -13
- camel/storages/vectordb_storages/milvus.py +1 -1
- camel/storages/vectordb_storages/qdrant.py +2 -3
- camel/tasks/__init__.py +22 -0
- camel/tasks/task.py +408 -0
- camel/tasks/task_prompt.py +65 -0
- camel/toolkits/__init__.py +39 -0
- camel/toolkits/base.py +4 -2
- camel/toolkits/code_execution.py +1 -1
- camel/toolkits/dalle_toolkit.py +146 -0
- camel/toolkits/github_toolkit.py +19 -34
- camel/toolkits/google_maps_toolkit.py +368 -0
- camel/toolkits/math_toolkit.py +79 -0
- camel/toolkits/open_api_toolkit.py +547 -0
- camel/{functions → toolkits}/openai_function.py +2 -7
- camel/toolkits/retrieval_toolkit.py +76 -0
- camel/toolkits/search_toolkit.py +326 -0
- camel/toolkits/slack_toolkit.py +308 -0
- camel/toolkits/twitter_toolkit.py +522 -0
- camel/toolkits/weather_toolkit.py +173 -0
- camel/types/enums.py +154 -35
- camel/utils/__init__.py +14 -2
- camel/utils/async_func.py +1 -1
- camel/utils/commons.py +152 -2
- camel/utils/constants.py +3 -0
- camel/utils/token_counting.py +148 -40
- camel/workforce/__init__.py +23 -0
- camel/workforce/base.py +50 -0
- camel/workforce/manager_node.py +299 -0
- camel/workforce/role_playing_node.py +168 -0
- camel/workforce/single_agent_node.py +77 -0
- camel/workforce/task_channel.py +173 -0
- camel/workforce/utils.py +97 -0
- camel/workforce/worker_node.py +115 -0
- camel/workforce/workforce.py +49 -0
- camel/workforce/workforce_prompt.py +125 -0
- {camel_ai-0.1.5.6.dist-info → camel_ai-0.1.6.1.dist-info}/METADATA +45 -3
- camel_ai-0.1.6.1.dist-info/RECORD +182 -0
- camel/functions/__init__.py +0 -51
- camel/functions/google_maps_function.py +0 -335
- camel/functions/math_functions.py +0 -61
- camel/functions/open_api_function.py +0 -508
- camel/functions/retrieval_functions.py +0 -61
- camel/functions/search_functions.py +0 -298
- camel/functions/slack_functions.py +0 -286
- camel/functions/twitter_function.py +0 -479
- camel/functions/weather_functions.py +0 -144
- camel_ai-0.1.5.6.dist-info/RECORD +0 -157
- /camel/{functions → toolkits}/open_api_specs/biztoc/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/biztoc/ai-plugin.json +0 -0
- /camel/{functions → toolkits}/open_api_specs/biztoc/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/coursera/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/coursera/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/create_qr_code/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/create_qr_code/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/klarna/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/klarna/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/nasa_apod/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/nasa_apod/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/ai-plugin.json +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/paths/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/paths/get_classes.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/outschool/paths/search_teachers.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/security_config.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/speak/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/speak/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/web_scraper/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/web_scraper/ai-plugin.json +0 -0
- /camel/{functions → toolkits}/open_api_specs/web_scraper/openapi.yaml +0 -0
- /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/__init__.py +0 -0
- /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/scraper.py +0 -0
- {camel_ai-0.1.5.6.dist-info → camel_ai-0.1.6.1.dist-info}/WHEEL +0 -0
camel/utils/token_counting.py
CHANGED
|
@@ -26,6 +26,8 @@ from PIL import Image
|
|
|
26
26
|
from camel.types import ModelType, OpenAIImageType, OpenAIVisionDetailType
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
29
|
+
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
30
|
+
|
|
29
31
|
from camel.messages import OpenAIMessage
|
|
30
32
|
|
|
31
33
|
LOW_DETAIL_TOKENS = 85
|
|
@@ -37,7 +39,7 @@ EXTRA_TOKENS = 85
|
|
|
37
39
|
|
|
38
40
|
|
|
39
41
|
def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
|
|
40
|
-
r"""Parse the message list into a single prompt following model-
|
|
42
|
+
r"""Parse the message list into a single prompt following model-specific
|
|
41
43
|
formats.
|
|
42
44
|
|
|
43
45
|
Args:
|
|
@@ -51,7 +53,12 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
|
|
|
51
53
|
system_message = messages[0]["content"]
|
|
52
54
|
|
|
53
55
|
ret: str
|
|
54
|
-
if model
|
|
56
|
+
if model in [
|
|
57
|
+
ModelType.LLAMA_2,
|
|
58
|
+
ModelType.LLAMA_3,
|
|
59
|
+
ModelType.GROQ_LLAMA_3_8B,
|
|
60
|
+
ModelType.GROQ_LLAMA_3_70B,
|
|
61
|
+
]:
|
|
55
62
|
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
|
56
63
|
seps = [" ", " </s><s>"]
|
|
57
64
|
role_map = {"user": "[INST]", "assistant": "[/INST]"}
|
|
@@ -74,7 +81,7 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
|
|
|
74
81
|
else:
|
|
75
82
|
ret += role
|
|
76
83
|
return ret
|
|
77
|
-
elif model
|
|
84
|
+
elif model in [ModelType.VICUNA, ModelType.VICUNA_16K]:
|
|
78
85
|
seps = [" ", "</s>"]
|
|
79
86
|
role_map = {"user": "USER", "assistant": "ASSISTANT"}
|
|
80
87
|
|
|
@@ -132,6 +139,40 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
|
|
|
132
139
|
else:
|
|
133
140
|
ret += '<|im_start|>' + role + '\n'
|
|
134
141
|
return ret
|
|
142
|
+
elif model == ModelType.GROQ_MIXTRAL_8_7B:
|
|
143
|
+
# Mistral/Mixtral format
|
|
144
|
+
system_prompt = f"<s>[INST] {system_message} [/INST]\n"
|
|
145
|
+
ret = system_prompt
|
|
146
|
+
|
|
147
|
+
for msg in messages[1:]:
|
|
148
|
+
if msg["role"] == "user":
|
|
149
|
+
ret += f"[INST] {msg['content']} [/INST]\n"
|
|
150
|
+
elif msg["role"] == "assistant":
|
|
151
|
+
ret += f"{msg['content']}</s>\n"
|
|
152
|
+
|
|
153
|
+
if not isinstance(msg['content'], str):
|
|
154
|
+
raise ValueError(
|
|
155
|
+
"Currently multimodal context is not "
|
|
156
|
+
"supported by the token counter."
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
return ret.strip()
|
|
160
|
+
elif model in [ModelType.GROQ_GEMMA_7B_IT, ModelType.GROQ_GEMMA_2_9B_IT]:
|
|
161
|
+
# Gemma format
|
|
162
|
+
ret = f"<bos>{system_message}\n"
|
|
163
|
+
for msg in messages:
|
|
164
|
+
if msg["role"] == "user":
|
|
165
|
+
ret += f"Human: {msg['content']}\n"
|
|
166
|
+
elif msg["role"] == "assistant":
|
|
167
|
+
ret += f"Assistant: {msg['content']}\n"
|
|
168
|
+
|
|
169
|
+
if not isinstance(msg['content'], str):
|
|
170
|
+
raise ValueError(
|
|
171
|
+
"Currently multimodal context is not supported by the token counter."
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
ret += "<eos>"
|
|
175
|
+
return ret
|
|
135
176
|
else:
|
|
136
177
|
raise ValueError(f"Invalid model type: {model}")
|
|
137
178
|
|
|
@@ -232,6 +273,7 @@ class OpenAITokenCounter(BaseTokenCounter):
|
|
|
232
273
|
model (ModelType): Model type for which tokens will be counted.
|
|
233
274
|
"""
|
|
234
275
|
self.model: str = model.value_for_tiktoken
|
|
276
|
+
self.model_type = model
|
|
235
277
|
|
|
236
278
|
self.tokens_per_message: int
|
|
237
279
|
self.tokens_per_name: int
|
|
@@ -300,7 +342,7 @@ class OpenAITokenCounter(BaseTokenCounter):
|
|
|
300
342
|
base64.b64decode(encoded_image)
|
|
301
343
|
)
|
|
302
344
|
image = Image.open(image_bytes)
|
|
303
|
-
num_tokens +=
|
|
345
|
+
num_tokens += self._count_tokens_from_image(
|
|
304
346
|
image, OpenAIVisionDetailType(detail)
|
|
305
347
|
)
|
|
306
348
|
if key == "name":
|
|
@@ -310,6 +352,45 @@ class OpenAITokenCounter(BaseTokenCounter):
|
|
|
310
352
|
num_tokens += 3
|
|
311
353
|
return num_tokens
|
|
312
354
|
|
|
355
|
+
def _count_tokens_from_image(
    self, image: Image.Image, detail: OpenAIVisionDetailType
) -> int:
    r"""Estimate the token cost of one image for an OpenAI vision model.

    Images with :obj:`"low"` detail cost a flat ``LOW_DETAIL_TOKENS``.
    Any other detail (``"auto"`` is treated like ``"high"``) is priced by
    tiling:

    1. If either side exceeds ``FIT_SQUARE_PIXELS`` (2048 x 2048 square
       per the OpenAI docs), shrink the image to fit, keeping its aspect
       ratio.
    2. Rescale so the shorter side equals ``SHORTEST_SIDE_PIXELS``
       (768px per the OpenAI docs).
    3. Count the ``SQUARE_PIXELS`` tiles (512px squares) that cover the
       result. Each tile costs ``SQUARE_TOKENS``, and ``EXTRA_TOKENS`` is
       added once.

    See the `OpenAI vision docs
    <https://platform.openai.com/docs/guides/vision>`_ for the reference
    algorithm.

    Args:
        image (PIL.Image.Image): Image whose token cost is estimated.
        detail (OpenAIVisionDetailType): Requested image detail level.

    Returns:
        int: Estimated number of tokens for the image at that detail.
    """
    if detail == OpenAIVisionDetailType.LOW:
        return LOW_DETAIL_TOKENS

    img_w, img_h = image.size

    # Step 1: only shrink when the image does not fit the bounding square.
    if max(img_w, img_h) > FIT_SQUARE_PIXELS:
        fit_ratio = max(img_w, img_h) / FIT_SQUARE_PIXELS
        img_w = int(img_w / fit_ratio)
        img_h = int(img_h / fit_ratio)

    # Step 2: normalise the shorter side. There is no guard here, so an
    # image whose shorter side is below SHORTEST_SIDE_PIXELS is scaled UP.
    # NOTE(review): OpenAI's published examples only ever downscale;
    # confirm whether small images should be up-scaled.
    side_ratio = min(img_w, img_h) / SHORTEST_SIDE_PIXELS
    final_w = int(img_w / side_ratio)
    final_h = int(img_h / side_ratio)

    # Step 3: price the tile grid that covers the rescaled image.
    tile_rows = ceil(final_h / SQUARE_PIXELS)
    tile_cols = ceil(final_w / SQUARE_PIXELS)
    return EXTRA_TOKENS + SQUARE_TOKENS * tile_rows * tile_cols
|
|
393
|
+
|
|
313
394
|
|
|
314
395
|
class AnthropicTokenCounter(BaseTokenCounter):
|
|
315
396
|
def __init__(self, model_type: ModelType):
|
|
@@ -428,41 +509,68 @@ class LiteLLMTokenCounter:
|
|
|
428
509
|
return self.completion_cost(completion_response=response)
|
|
429
510
|
|
|
430
511
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
|
|
435
|
-
resolution model will be treated as :obj:`"high"`. All images with
|
|
436
|
-
:obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
|
|
437
|
-
are first scaled to fit within a 2048 x 2048 square, maintaining their
|
|
438
|
-
aspect ratio. Then, they are scaled such that the shortest side of the
|
|
439
|
-
image is 768px long. Finally, we count how many 512px squares the image
|
|
440
|
-
consists of. Each of those squares costs 170 tokens. Another 85 tokens are
|
|
441
|
-
always added to the final total. For more details please refer to `OpenAI
|
|
442
|
-
vision docs <https://platform.openai.com/docs/guides/vision>`_
|
|
512
|
+
class MistralTokenCounter(BaseTokenCounter):
    r"""Token counter for Mistral models backed by ``mistral_common``."""

    def __init__(self, model_type: ModelType):
        r"""Build the tokenizer matching the given Mistral model type.

        Args:
            model_type (ModelType): Model type for which tokens will be
                counted.
        """
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

        self.model_type = model_type

        # Both Codestral variants are tokenized with the "codestral-22b"
        # tokenizer; any other model type is looked up by its own value.
        codestral_family = {
            ModelType.MISTRAL_CODESTRAL,
            ModelType.MISTRAL_CODESTRAL_MAMBA,
        }
        if self.model_type in codestral_family:
            tokenizer_name = "codestral-22b"
        else:
            tokenizer_name = self.model_type.value

        self.tokenizer = MistralTokenizer.from_model(tokenizer_name)

    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
        r"""Count the tokens in a message list with the Mistral tokenizer.

        Each message is wrapped in its own one-message chat-completion
        request and encoded separately; the per-message token counts are
        summed. An empty list yields ``0``.

        NOTE(review): because every message is encoded as a standalone
        request, any per-request control tokens are counted once per
        message. ``mistral_common``'s request validation may also reject a
        request whose only message is not a user message. TODO confirm
        against the installed ``mistral_common`` version.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat
                history in OpenAI API format.

        Returns:
            int: Total number of tokens in the messages.
        """
        return sum(
            len(
                self.tokenizer.encode_chat_completion(
                    self._convert_response_from_openai_to_mistral(message)
                ).tokens
            )
            for message in messages
        )

    def _convert_response_from_openai_to_mistral(
        self, openai_msg: OpenAIMessage
    ) -> ChatCompletionRequest:
        r"""Wrap one OpenAI-format message in a Mistral chat request.

        Args:
            openai_msg (OpenAIMessage): A single message in OpenAI format.

        Returns:
            ChatCompletionRequest: A request for ``self.model_type`` whose
                only message is ``openai_msg``.
        """
        from mistral_common.protocol.instruct.request import (
            ChatCompletionRequest,
        )

        return ChatCompletionRequest(  # type: ignore[type-var]
            model=self.model_type.value,
            messages=[openai_msg],
        )
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
|
|
15
|
+
from .base import BaseNode
|
|
16
|
+
from .manager_node import ManagerNode
|
|
17
|
+
from .worker_node import WorkerNode
|
|
18
|
+
|
|
19
|
+
# Public API of the ``camel.workforce`` package.
__all__ = ["BaseNode", "WorkerNode", "ManagerNode"]
|
camel/workforce/base.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from camel.workforce.task_channel import TaskChannel
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseNode(ABC):
    r"""Abstract base class for every node of a workforce.

    Args:
        description (str): Human-readable description of the node.
    """

    def __init__(self, description: str) -> None:
        # The object's id() (as a string) serves as the node identifier.
        self.node_id = str(id(self))
        self.description = description
        # Each node starts with a private channel of its own; ``set_channel``
        # can later replace it with a shared one.
        self._channel: TaskChannel = TaskChannel()
        self._running = False

    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Reset the node to its initial state.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_channel(self, channel: TaskChannel):
        r"""Attach the node to the given task channel.

        Args:
            channel (TaskChannel): The channel the node should use.
        """

    @abstractmethod
    async def _listen_to_channel(self):
        r"""Receive and handle tasks from the channel.

        Implementations use this coroutine as the node's main loop.
        """

    @abstractmethod
    async def start(self):
        r"""Start running the node."""

    @abstractmethod
    def stop(self):
        r"""Stop running the node."""
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
from collections import deque
|
|
18
|
+
from typing import Deque, Dict, List, Optional
|
|
19
|
+
|
|
20
|
+
from colorama import Fore
|
|
21
|
+
|
|
22
|
+
from camel.agents import ChatAgent
|
|
23
|
+
from camel.messages.base import BaseMessage
|
|
24
|
+
from camel.tasks.task import Task, TaskState
|
|
25
|
+
from camel.workforce.base import BaseNode
|
|
26
|
+
from camel.workforce.single_agent_node import SingleAgentNode
|
|
27
|
+
from camel.workforce.task_channel import TaskChannel
|
|
28
|
+
from camel.workforce.utils import (
|
|
29
|
+
check_if_running,
|
|
30
|
+
parse_assign_task_resp,
|
|
31
|
+
parse_create_node_resp,
|
|
32
|
+
)
|
|
33
|
+
from camel.workforce.worker_node import WorkerNode
|
|
34
|
+
from camel.workforce.workforce_prompt import (
|
|
35
|
+
ASSIGN_TASK_PROMPT,
|
|
36
|
+
CREATE_NODE_PROMPT,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ManagerNode(BaseNode):
    r"""A node that manages multiple nodes. It splits the task it receives
    into subtasks, assigns them to the child nodes under it, and handles
    the situation when a task fails.

    Args:
        description (str): Description of the node.
        children (List[BaseNode]): The child nodes managed by this node.
            The list is copied, so worker nodes created later by this
            manager are not appended to the caller's list.
        coordinator_agent_kwargs (Optional[Dict]): Keyword arguments for the
            coordinator agent, e.g. `model`, `api_key`, `tools`, etc.
        task_agent_kwargs (Optional[Dict]): Keyword arguments for the task
            agent, e.g. `model`, `api_key`, `tools`, etc.
    """

    def __init__(
        self,
        description: str,
        children: List[BaseNode],
        coordinator_agent_kwargs: Optional[Dict] = None,
        task_agent_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__(description)
        # asyncio tasks running the children's `start()` coroutines; they
        # are cancelled in `stop()`.
        self._child_listening_tasks: Deque[asyncio.Task] = deque()
        # Copy the list: `_create_worker_node_for_task` appends to it, and
        # that must not mutate a list object owned by the caller.
        self._children: List[BaseNode] = list(children)

        coord_agent_sysmsg = BaseMessage.make_assistant_message(
            role_name="Workforce Manager",
            content="You are coordinating a group of workers. A worker can be "
            "a group of agents or a single agent. Each worker is created to"
            " solve a specific kind of task. Your job includes assigning "
            "tasks to a existing worker, creating a new worker for a task, "
            "etc.",
        )
        self.coordinator_agent = ChatAgent(
            coord_agent_sysmsg, **(coordinator_agent_kwargs or {})
        )

        task_sys_msg = BaseMessage.make_assistant_message(
            role_name="Task Planner",
            content="You are going to compose and decompose tasks.",
        )
        self.task_agent = ChatAgent(task_sys_msg, **(task_agent_kwargs or {}))

        # The main task, if there is one; set via `set_main_task` by the
        # workforce class wrapping this node.
        self._task: Optional[Task] = None
        # Tasks waiting to be handled, in order. `_post_ready_tasks` posts
        # the head without removing it; `_handle_completed_task` pops it.
        self._pending_tasks: Deque[Task] = deque()

    @check_if_running(False)
    def set_main_task(self, task: Task) -> None:
        r"""Set the main task for the node.

        Args:
            task (Task): The top-level task this manager should solve.
        """
        self._task = task

    def _get_child_nodes_info(self) -> str:
        r"""Describe all child nodes, one ``"<node_id>: <description>"``
        line per child, for use in coordinator prompts.
        """
        return '\n'.join(
            f'{child.node_id}: {child.description}' for child in self._children
        )

    def _find_assignee(
        self,
        task: Task,
        failed_log: Optional[str] = None,
    ) -> str:
        r"""Ask the coordinator which child node should take a task. If the
        answer cannot be parsed, or names a node that is not a child of this
        manager, create a new worker node for the task instead.

        Args:
            task (Task): The task to be assigned.
            failed_log (Optional[str]): Optional log of a previous failed
                attempt. Currently unused; kept for interface compatibility.

        Returns:
            str: ID of the assigned node. This is always the ID of one of
                this manager's child nodes.
        """
        prompt = ASSIGN_TASK_PROMPT.format(
            content=task.content,
            child_nodes_info=self._get_child_nodes_info(),
        )
        req = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.coordinator_agent.step(req)
        print(f"{Fore.YELLOW}{response.msg.content}{Fore.RESET}")

        assignee_id: Optional[str]
        try:
            assignee_id = parse_assign_task_resp(response.msg.content)
        except ValueError:
            assignee_id = None

        # The coordinator is an LLM and may return an ID that matches no
        # child. A task posted to such an ID is never picked up, and the
        # manager would wait for it forever. Fall back to a dedicated worker.
        known_ids = {child.node_id for child in self._children}
        if assignee_id not in known_ids:
            assignee_id = self._create_worker_node_for_task(task).node_id
        return assignee_id

    async def _post_task(self, task: Task, assignee_id: str) -> None:
        r"""Publish a task on the channel, addressed to ``assignee_id``."""
        await self._channel.post_task(task, self.node_id, assignee_id)

    async def _post_dependency(self, dependency: Task) -> None:
        r"""Publish a task on the channel as a dependency of this node."""
        await self._channel.post_dependency(dependency, self.node_id)

    def _create_worker_node_for_task(self, task: Task) -> WorkerNode:
        r"""Create a new worker node for a given task and add it to the
        children of this node. This is one of the actions the coordinator
        can take when a task has failed.

        The new node shares this manager's channel, and its `start()`
        coroutine is scheduled with :func:`asyncio.create_task`. This method
        must therefore be called while an event loop is running.

        Args:
            task (Task): The task for which the worker node is created.

        Returns:
            WorkerNode: The created worker node.
        """
        prompt = CREATE_NODE_PROMPT.format(
            content=task.content,
            child_nodes_info=self._get_child_nodes_info(),
        )
        req = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.coordinator_agent.step(req)
        new_node_conf = parse_create_node_resp(response.msg.content)

        worker_sysmsg = BaseMessage.make_assistant_message(
            role_name=new_node_conf.role,
            content=new_node_conf.system,
        )

        # TODO: add a default selection of tools for the worker
        worker = ChatAgent(worker_sysmsg)

        new_node = SingleAgentNode(
            description=new_node_conf.description,
            worker=worker,
        )
        new_node.set_channel(self._channel)

        print(
            f"{Fore.GREEN}New worker node {new_node.node_id} created."
            f"{Fore.RESET}"
        )

        self._children.append(new_node)
        self._child_listening_tasks.append(
            asyncio.create_task(new_node.start())
        )
        return new_node

    async def _get_returned_task(self) -> Task:
        r"""Wait for a task published by this node to be returned by its
        assignee, then return it.
        """
        return await self._channel.get_returned_task_by_publisher(self.node_id)

    async def _post_ready_tasks(self) -> None:
        r"""Send all the pending tasks that have all their dependencies met
        to the channel, or return directly if there are none. For now, only
        the first pending task is considered, because all the tasks are
        linearly dependent.

        A head task in the FAILED state is a parent whose subtasks have been
        processed. It is composed, its subtasks are removed from the channel,
        and it is posted as a dependency and dropped from the pending queue.
        Any other head task is assigned to a node and posted, but stays at
        the head until it comes back.
        """
        if not self._pending_tasks:
            return

        ready_task = self._pending_tasks[0]

        # if the task has failed previously, just compose and send the task
        # to the channel as a dependency
        if ready_task.state == TaskState.FAILED:
            # TODO: the composing of tasks seems not work very well
            ready_task.compose(self.task_agent)
            # remove the subtasks from the channel
            for subtask in ready_task.subtasks:
                await self._channel.remove_task(subtask.id)
            # send the task to the channel as a dependency
            await self._post_dependency(ready_task)
            self._pending_tasks.popleft()
            # try to send the next task in the pending list
            await self._post_ready_tasks()
        else:
            # directly post the task to the channel if it's a new one
            # find a node to assign the task
            assignee_id = self._find_assignee(task=ready_task)
            await self._post_task(ready_task, assignee_id)

    async def _handle_failed_task(self, task: Task) -> None:
        r"""React to a task that came back in the FAILED state.

        The task is removed from the channel. If its depth is 3 or more, it
        is reassigned to a newly created worker node. Otherwise it is
        decomposed, and its subtasks are queued ahead of it. The failed task
        stays in the pending queue, so it is composed once its subtasks are
        done (see `_post_ready_tasks`).

        Args:
            task (Task): The failed task.
        """
        # remove the failed task from the channel
        await self._channel.remove_task(task.id)
        if task.get_depth() >= 3:
            # Deep enough already: stop decomposing and hand the task to a
            # freshly created worker instead.
            # TODO: add a state for reassign?
            assignee = self._create_worker_node_for_task(task)
            await self._post_task(task, assignee.node_id)
        else:
            subtasks = task.decompose(self.task_agent)
            # Insert the subtasks at the head of the queue, keeping their
            # order, in front of the failed parent.
            self._pending_tasks.extendleft(reversed(subtasks))
            await self._post_ready_tasks()

    async def _handle_completed_task(self, task: Task) -> None:
        r"""React to a task that came back in the DONE state.

        The head of the pending queue is dropped, the task is archived on
        the channel, and the next ready task is posted.

        Args:
            task (Task): The completed task.
        """
        # archive the task, making it into a dependency
        self._pending_tasks.popleft()
        await self._channel.archive_task(task.id)
        await self._post_ready_tasks()

    @check_if_running(False)
    def set_channel(self, channel: TaskChannel):
        r"""Set the channel for the node and all the child nodes under it.

        Args:
            channel (TaskChannel): The channel to share with the children.
        """
        self._channel = channel
        for child in self._children:
            child.set_channel(channel)

    @check_if_running(False)
    async def _listen_to_channel(self) -> None:
        r"""Continuously listen to the channel, post tasks to the channel
        and track the status of posted tasks.

        Raises:
            ValueError: If a returned task is neither DONE nor FAILED.
        """
        self._running = True
        print(f"{Fore.GREEN}Manager node {self.node_id} started.{Fore.RESET}")

        # if this node is at the top level, it will have an initial task;
        # the initial task must be decomposed into subtasks first
        if self._task is not None:
            subtasks = self._task.decompose(self.task_agent)
            self._pending_tasks.extend(subtasks)
            # Mark the main task FAILED and queue it after its subtasks, so
            # `_post_ready_tasks` composes it once they are all done.
            self._task.state = TaskState.FAILED
            self._pending_tasks.append(self._task)

        # before starting the loop, send ready pending tasks to the channel
        await self._post_ready_tasks()

        await self._channel.print_channel()

        while self._task is None or self._pending_tasks:
            returned_task = await self._get_returned_task()
            if returned_task.state == TaskState.DONE:
                await self._handle_completed_task(returned_task)
            elif returned_task.state == TaskState.FAILED:
                await self._handle_failed_task(returned_task)
            else:
                raise ValueError(
                    f"Task {returned_task.id} has an unexpected state."
                )

        # shut down the whole workforce tree
        self.stop()

    @check_if_running(False)
    async def start(self) -> None:
        r"""Start all the child nodes under it, then run this node's own
        listening loop.
        """
        for child in self._children:
            child_listening_task = asyncio.create_task(child.start())
            self._child_listening_tasks.append(child_listening_task)
        await self._listen_to_channel()

    @check_if_running(True)
    def stop(self) -> None:
        r"""Stop all the child nodes under it and cancel their listening
        tasks. The node itself will be stopped by its parent node.
        """
        for child in self._children:
            child.stop()
        for child_task in self._child_listening_tasks:
            child_task.cancel()
        self._running = False
|