camel-ai 0.1.5.6__py3-none-any.whl → 0.1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic; consult the package registry's advisory page for more details.

Files changed (133)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +249 -36
  3. camel/agents/critic_agent.py +18 -2
  4. camel/agents/deductive_reasoner_agent.py +16 -4
  5. camel/agents/embodied_agent.py +20 -6
  6. camel/agents/knowledge_graph_agent.py +24 -5
  7. camel/agents/role_assignment_agent.py +13 -1
  8. camel/agents/search_agent.py +16 -5
  9. camel/agents/task_agent.py +20 -5
  10. camel/configs/__init__.py +11 -9
  11. camel/configs/anthropic_config.py +5 -6
  12. camel/configs/base_config.py +50 -4
  13. camel/configs/gemini_config.py +69 -17
  14. camel/configs/groq_config.py +105 -0
  15. camel/configs/litellm_config.py +2 -8
  16. camel/configs/mistral_config.py +78 -0
  17. camel/configs/ollama_config.py +5 -7
  18. camel/configs/openai_config.py +12 -23
  19. camel/configs/vllm_config.py +102 -0
  20. camel/configs/zhipuai_config.py +5 -11
  21. camel/embeddings/__init__.py +2 -0
  22. camel/embeddings/mistral_embedding.py +89 -0
  23. camel/human.py +1 -1
  24. camel/interpreters/__init__.py +2 -0
  25. camel/interpreters/ipython_interpreter.py +167 -0
  26. camel/loaders/__init__.py +2 -0
  27. camel/loaders/firecrawl_reader.py +213 -0
  28. camel/memories/agent_memories.py +1 -4
  29. camel/memories/blocks/chat_history_block.py +6 -2
  30. camel/memories/blocks/vectordb_block.py +3 -1
  31. camel/memories/context_creators/score_based.py +6 -6
  32. camel/memories/records.py +9 -7
  33. camel/messages/base.py +1 -0
  34. camel/models/__init__.py +8 -0
  35. camel/models/anthropic_model.py +7 -2
  36. camel/models/azure_openai_model.py +152 -0
  37. camel/models/base_model.py +9 -2
  38. camel/models/gemini_model.py +14 -2
  39. camel/models/groq_model.py +131 -0
  40. camel/models/litellm_model.py +26 -4
  41. camel/models/mistral_model.py +169 -0
  42. camel/models/model_factory.py +30 -3
  43. camel/models/ollama_model.py +21 -2
  44. camel/models/open_source_model.py +13 -5
  45. camel/models/openai_model.py +7 -2
  46. camel/models/stub_model.py +4 -4
  47. camel/models/vllm_model.py +138 -0
  48. camel/models/zhipuai_model.py +7 -4
  49. camel/prompts/__init__.py +8 -1
  50. camel/prompts/image_craft.py +34 -0
  51. camel/prompts/multi_condition_image_craft.py +34 -0
  52. camel/prompts/task_prompt_template.py +10 -4
  53. camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
  54. camel/responses/agent_responses.py +4 -3
  55. camel/retrievers/auto_retriever.py +2 -2
  56. camel/societies/babyagi_playing.py +6 -4
  57. camel/societies/role_playing.py +16 -8
  58. camel/storages/graph_storages/graph_element.py +10 -14
  59. camel/storages/graph_storages/neo4j_graph.py +5 -0
  60. camel/storages/vectordb_storages/base.py +24 -13
  61. camel/storages/vectordb_storages/milvus.py +1 -1
  62. camel/storages/vectordb_storages/qdrant.py +2 -3
  63. camel/tasks/__init__.py +22 -0
  64. camel/tasks/task.py +408 -0
  65. camel/tasks/task_prompt.py +65 -0
  66. camel/toolkits/__init__.py +39 -0
  67. camel/toolkits/base.py +4 -2
  68. camel/toolkits/code_execution.py +1 -1
  69. camel/toolkits/dalle_toolkit.py +146 -0
  70. camel/toolkits/github_toolkit.py +19 -34
  71. camel/toolkits/google_maps_toolkit.py +368 -0
  72. camel/toolkits/math_toolkit.py +79 -0
  73. camel/toolkits/open_api_toolkit.py +547 -0
  74. camel/{functions → toolkits}/openai_function.py +2 -7
  75. camel/toolkits/retrieval_toolkit.py +76 -0
  76. camel/toolkits/search_toolkit.py +326 -0
  77. camel/toolkits/slack_toolkit.py +308 -0
  78. camel/toolkits/twitter_toolkit.py +522 -0
  79. camel/toolkits/weather_toolkit.py +173 -0
  80. camel/types/enums.py +154 -35
  81. camel/utils/__init__.py +14 -2
  82. camel/utils/async_func.py +1 -1
  83. camel/utils/commons.py +152 -2
  84. camel/utils/constants.py +3 -0
  85. camel/utils/token_counting.py +148 -40
  86. camel/workforce/__init__.py +23 -0
  87. camel/workforce/base.py +50 -0
  88. camel/workforce/manager_node.py +299 -0
  89. camel/workforce/role_playing_node.py +168 -0
  90. camel/workforce/single_agent_node.py +77 -0
  91. camel/workforce/task_channel.py +173 -0
  92. camel/workforce/utils.py +97 -0
  93. camel/workforce/worker_node.py +115 -0
  94. camel/workforce/workforce.py +49 -0
  95. camel/workforce/workforce_prompt.py +125 -0
  96. {camel_ai-0.1.5.6.dist-info → camel_ai-0.1.6.1.dist-info}/METADATA +45 -3
  97. camel_ai-0.1.6.1.dist-info/RECORD +182 -0
  98. camel/functions/__init__.py +0 -51
  99. camel/functions/google_maps_function.py +0 -335
  100. camel/functions/math_functions.py +0 -61
  101. camel/functions/open_api_function.py +0 -508
  102. camel/functions/retrieval_functions.py +0 -61
  103. camel/functions/search_functions.py +0 -298
  104. camel/functions/slack_functions.py +0 -286
  105. camel/functions/twitter_function.py +0 -479
  106. camel/functions/weather_functions.py +0 -144
  107. camel_ai-0.1.5.6.dist-info/RECORD +0 -157
  108. /camel/{functions → toolkits}/open_api_specs/biztoc/__init__.py +0 -0
  109. /camel/{functions → toolkits}/open_api_specs/biztoc/ai-plugin.json +0 -0
  110. /camel/{functions → toolkits}/open_api_specs/biztoc/openapi.yaml +0 -0
  111. /camel/{functions → toolkits}/open_api_specs/coursera/__init__.py +0 -0
  112. /camel/{functions → toolkits}/open_api_specs/coursera/openapi.yaml +0 -0
  113. /camel/{functions → toolkits}/open_api_specs/create_qr_code/__init__.py +0 -0
  114. /camel/{functions → toolkits}/open_api_specs/create_qr_code/openapi.yaml +0 -0
  115. /camel/{functions → toolkits}/open_api_specs/klarna/__init__.py +0 -0
  116. /camel/{functions → toolkits}/open_api_specs/klarna/openapi.yaml +0 -0
  117. /camel/{functions → toolkits}/open_api_specs/nasa_apod/__init__.py +0 -0
  118. /camel/{functions → toolkits}/open_api_specs/nasa_apod/openapi.yaml +0 -0
  119. /camel/{functions → toolkits}/open_api_specs/outschool/__init__.py +0 -0
  120. /camel/{functions → toolkits}/open_api_specs/outschool/ai-plugin.json +0 -0
  121. /camel/{functions → toolkits}/open_api_specs/outschool/openapi.yaml +0 -0
  122. /camel/{functions → toolkits}/open_api_specs/outschool/paths/__init__.py +0 -0
  123. /camel/{functions → toolkits}/open_api_specs/outschool/paths/get_classes.py +0 -0
  124. /camel/{functions → toolkits}/open_api_specs/outschool/paths/search_teachers.py +0 -0
  125. /camel/{functions → toolkits}/open_api_specs/security_config.py +0 -0
  126. /camel/{functions → toolkits}/open_api_specs/speak/__init__.py +0 -0
  127. /camel/{functions → toolkits}/open_api_specs/speak/openapi.yaml +0 -0
  128. /camel/{functions → toolkits}/open_api_specs/web_scraper/__init__.py +0 -0
  129. /camel/{functions → toolkits}/open_api_specs/web_scraper/ai-plugin.json +0 -0
  130. /camel/{functions → toolkits}/open_api_specs/web_scraper/openapi.yaml +0 -0
  131. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/__init__.py +0 -0
  132. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/scraper.py +0 -0
  133. {camel_ai-0.1.5.6.dist-info → camel_ai-0.1.6.1.dist-info}/WHEEL +0 -0
@@ -26,6 +26,8 @@ from PIL import Image
26
26
  from camel.types import ModelType, OpenAIImageType, OpenAIVisionDetailType
27
27
 
28
28
  if TYPE_CHECKING:
29
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
30
+
29
31
  from camel.messages import OpenAIMessage
30
32
 
31
33
  LOW_DETAIL_TOKENS = 85
@@ -37,7 +39,7 @@ EXTRA_TOKENS = 85
37
39
 
38
40
 
39
41
  def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
40
- r"""Parse the message list into a single prompt following model-specifc
42
+ r"""Parse the message list into a single prompt following model-specific
41
43
  formats.
42
44
 
43
45
  Args:
@@ -51,7 +53,12 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
51
53
  system_message = messages[0]["content"]
52
54
 
53
55
  ret: str
54
- if model == ModelType.LLAMA_2 or model == ModelType.LLAMA_3:
56
+ if model in [
57
+ ModelType.LLAMA_2,
58
+ ModelType.LLAMA_3,
59
+ ModelType.GROQ_LLAMA_3_8B,
60
+ ModelType.GROQ_LLAMA_3_70B,
61
+ ]:
55
62
  # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
56
63
  seps = [" ", " </s><s>"]
57
64
  role_map = {"user": "[INST]", "assistant": "[/INST]"}
@@ -74,7 +81,7 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
74
81
  else:
75
82
  ret += role
76
83
  return ret
77
- elif model == ModelType.VICUNA or model == ModelType.VICUNA_16K:
84
+ elif model in [ModelType.VICUNA, ModelType.VICUNA_16K]:
78
85
  seps = [" ", "</s>"]
79
86
  role_map = {"user": "USER", "assistant": "ASSISTANT"}
80
87
 
@@ -132,6 +139,40 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
132
139
  else:
133
140
  ret += '<|im_start|>' + role + '\n'
134
141
  return ret
142
+ elif model == ModelType.GROQ_MIXTRAL_8_7B:
143
+ # Mistral/Mixtral format
144
+ system_prompt = f"<s>[INST] {system_message} [/INST]\n"
145
+ ret = system_prompt
146
+
147
+ for msg in messages[1:]:
148
+ if msg["role"] == "user":
149
+ ret += f"[INST] {msg['content']} [/INST]\n"
150
+ elif msg["role"] == "assistant":
151
+ ret += f"{msg['content']}</s>\n"
152
+
153
+ if not isinstance(msg['content'], str):
154
+ raise ValueError(
155
+ "Currently multimodal context is not "
156
+ "supported by the token counter."
157
+ )
158
+
159
+ return ret.strip()
160
+ elif model in [ModelType.GROQ_GEMMA_7B_IT, ModelType.GROQ_GEMMA_2_9B_IT]:
161
+ # Gemma format
162
+ ret = f"<bos>{system_message}\n"
163
+ for msg in messages:
164
+ if msg["role"] == "user":
165
+ ret += f"Human: {msg['content']}\n"
166
+ elif msg["role"] == "assistant":
167
+ ret += f"Assistant: {msg['content']}\n"
168
+
169
+ if not isinstance(msg['content'], str):
170
+ raise ValueError(
171
+ "Currently multimodal context is not supported by the token counter."
172
+ )
173
+
174
+ ret += "<eos>"
175
+ return ret
135
176
  else:
136
177
  raise ValueError(f"Invalid model type: {model}")
137
178
 
@@ -232,6 +273,7 @@ class OpenAITokenCounter(BaseTokenCounter):
232
273
  model (ModelType): Model type for which tokens will be counted.
233
274
  """
234
275
  self.model: str = model.value_for_tiktoken
276
+ self.model_type = model
235
277
 
236
278
  self.tokens_per_message: int
237
279
  self.tokens_per_name: int
@@ -300,7 +342,7 @@ class OpenAITokenCounter(BaseTokenCounter):
300
342
  base64.b64decode(encoded_image)
301
343
  )
302
344
  image = Image.open(image_bytes)
303
- num_tokens += count_tokens_from_image(
345
+ num_tokens += self._count_tokens_from_image(
304
346
  image, OpenAIVisionDetailType(detail)
305
347
  )
306
348
  if key == "name":
@@ -310,6 +352,45 @@ class OpenAITokenCounter(BaseTokenCounter):
310
352
  num_tokens += 3
311
353
  return num_tokens
312
354
 
355
+ def _count_tokens_from_image(
356
+ self, image: Image.Image, detail: OpenAIVisionDetailType
357
+ ) -> int:
358
+ r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
359
+ resolution model will be treated as :obj:`"high"`. All images with
360
+ :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
361
+ are first scaled to fit within a 2048 x 2048 square, maintaining their
362
+ aspect ratio. Then, they are scaled such that the shortest side of the
363
+ image is 768px long. Finally, we count how many 512px squares the image
364
+ consists of. Each of those squares costs 170 tokens. Another 85 tokens are
365
+ always added to the final total. For more details please refer to `OpenAI
366
+ vision docs <https://platform.openai.com/docs/guides/vision>`_
367
+
368
+ Args:
369
+ image (PIL.Image.Image): Image to count number of tokens.
370
+ detail (OpenAIVisionDetailType): Image detail type to count
371
+ number of tokens.
372
+
373
+ Returns:
374
+ int: Number of tokens for the image given a detail type.
375
+ """
376
+ if detail == OpenAIVisionDetailType.LOW:
377
+ return LOW_DETAIL_TOKENS
378
+
379
+ width, height = image.size
380
+ if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
381
+ scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
382
+ width = int(width / scaling_factor)
383
+ height = int(height / scaling_factor)
384
+
385
+ scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
386
+ scaled_width = int(width / scaling_factor)
387
+ scaled_height = int(height / scaling_factor)
388
+
389
+ h = ceil(scaled_height / SQUARE_PIXELS)
390
+ w = ceil(scaled_width / SQUARE_PIXELS)
391
+ total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
392
+ return total
393
+
313
394
 
314
395
  class AnthropicTokenCounter(BaseTokenCounter):
315
396
  def __init__(self, model_type: ModelType):
@@ -428,41 +509,68 @@ class LiteLLMTokenCounter:
428
509
  return self.completion_cost(completion_response=response)
429
510
 
430
511
 
431
- def count_tokens_from_image(
432
- image: Image.Image, detail: OpenAIVisionDetailType
433
- ) -> int:
434
- r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
435
- resolution model will be treated as :obj:`"high"`. All images with
436
- :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
437
- are first scaled to fit within a 2048 x 2048 square, maintaining their
438
- aspect ratio. Then, they are scaled such that the shortest side of the
439
- image is 768px long. Finally, we count how many 512px squares the image
440
- consists of. Each of those squares costs 170 tokens. Another 85 tokens are
441
- always added to the final total. For more details please refer to `OpenAI
442
- vision docs <https://platform.openai.com/docs/guides/vision>`_
512
+ class MistralTokenCounter(BaseTokenCounter):
513
+ def __init__(self, model_type: ModelType):
514
+ r"""Constructor for the token counter for Mistral models.
443
515
 
444
- Args:
445
- image (PIL.Image.Image): Image to count number of tokens.
446
- detail (OpenAIVisionDetailType): Image detail type to count
447
- number of tokens.
516
+ Args:
517
+ model_type (ModelType): Model type for which tokens will be
518
+ counted.
519
+ """
520
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
448
521
 
449
- Returns:
450
- int: Number of tokens for the image given a detail type.
451
- """
452
- if detail == OpenAIVisionDetailType.LOW:
453
- return LOW_DETAIL_TOKENS
454
-
455
- width, height = image.size
456
- if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
457
- scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
458
- width = int(width / scaling_factor)
459
- height = int(height / scaling_factor)
460
-
461
- scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
462
- scaled_width = int(width / scaling_factor)
463
- scaled_height = int(height / scaling_factor)
464
-
465
- h = ceil(scaled_height / SQUARE_PIXELS)
466
- w = ceil(scaled_width / SQUARE_PIXELS)
467
- total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
468
- return total
522
+ self.model_type = model_type
523
+
524
+ # Determine the model type and set the tokenizer accordingly
525
+ model_name = (
526
+ "codestral-22b"
527
+ if self.model_type
528
+ in {ModelType.MISTRAL_CODESTRAL, ModelType.MISTRAL_CODESTRAL_MAMBA}
529
+ else self.model_type.value
530
+ )
531
+
532
+ self.tokenizer = MistralTokenizer.from_model(model_name)
533
+
534
+ def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
535
+ r"""Count number of tokens in the provided message list using
536
+ loaded tokenizer specific for this type of model.
537
+
538
+ Args:
539
+ messages (List[OpenAIMessage]): Message list with the chat history
540
+ in OpenAI API format.
541
+
542
+ Returns:
543
+ int: Total number of tokens in the messages.
544
+ """
545
+ total_tokens = 0
546
+ for msg in messages:
547
+ tokens = self.tokenizer.encode_chat_completion(
548
+ self._convert_response_from_openai_to_mistral(msg)
549
+ ).tokens
550
+ total_tokens += len(tokens)
551
+ return total_tokens
552
+
553
+ def _convert_response_from_openai_to_mistral(
554
+ self, openai_msg: OpenAIMessage
555
+ ) -> ChatCompletionRequest:
556
+ r"""Convert an OpenAI message to a Mistral ChatCompletionRequest.
557
+
558
+ Args:
559
+ openai_msg (OpenAIMessage): An individual message with OpenAI
560
+ format.
561
+
562
+ Returns:
563
+ ChatCompletionRequest: The converted message in Mistral's request
564
+ format.
565
+ """
566
+
567
+ from mistral_common.protocol.instruct.request import (
568
+ ChatCompletionRequest,
569
+ )
570
+
571
+ mistral_request = ChatCompletionRequest( # type: ignore[type-var]
572
+ model=self.model_type.value,
573
+ messages=[openai_msg],
574
+ )
575
+
576
+ return mistral_request
@@ -0,0 +1,23 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+
15
+ from .base import BaseNode
16
+ from .manager_node import ManagerNode
17
+ from .worker_node import WorkerNode
18
+
19
+ __all__ = [
20
+ "BaseNode",
21
+ "WorkerNode",
22
+ "ManagerNode",
23
+ ]
@@ -0,0 +1,50 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from abc import ABC, abstractmethod
15
+ from typing import Any
16
+
17
+ from camel.workforce.task_channel import TaskChannel
18
+
19
+
20
+ class BaseNode(ABC):
21
+ def __init__(self, description: str) -> None:
22
+ self.node_id = str(id(self))
23
+ self.description = description
24
+ # every node is initialized to use its own channel
25
+ self._channel: TaskChannel = TaskChannel()
26
+ self._running = False
27
+
28
+ def reset(self, *args: Any, **kwargs: Any) -> Any:
29
+ """Resets the node to its initial state."""
30
+ raise NotImplementedError()
31
+
32
+ @abstractmethod
33
+ def set_channel(self, channel: TaskChannel):
34
+ r"""Sets the channel for the node."""
35
+
36
+ @abstractmethod
37
+ async def _listen_to_channel(self):
38
+ r"""Listens to the channel and handle tasks. This method should be
39
+ the main loop for the node.
40
+ """
41
+
42
+ @abstractmethod
43
+ async def start(self):
44
+ r"""Start the node."""
45
+
46
+ @abstractmethod
47
+ def stop(self):
48
+ r"""
49
+ Stop the node.
50
+ """
@@ -0,0 +1,299 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ from collections import deque
18
+ from typing import Deque, Dict, List, Optional
19
+
20
+ from colorama import Fore
21
+
22
+ from camel.agents import ChatAgent
23
+ from camel.messages.base import BaseMessage
24
+ from camel.tasks.task import Task, TaskState
25
+ from camel.workforce.base import BaseNode
26
+ from camel.workforce.single_agent_node import SingleAgentNode
27
+ from camel.workforce.task_channel import TaskChannel
28
+ from camel.workforce.utils import (
29
+ check_if_running,
30
+ parse_assign_task_resp,
31
+ parse_create_node_resp,
32
+ )
33
+ from camel.workforce.worker_node import WorkerNode
34
+ from camel.workforce.workforce_prompt import (
35
+ ASSIGN_TASK_PROMPT,
36
+ CREATE_NODE_PROMPT,
37
+ )
38
+
39
+
40
+ class ManagerNode(BaseNode):
41
+ r"""A node that manages multiple nodes. It will split the task it
42
+ receives into subtasks and assign them to the child nodes under
43
+ it, and also handles the situation when the task fails.
44
+
45
+ Args:
46
+ description (str): Description of the node.
47
+ coordinator_agent_kwargs (Optional[Dict]): Keyword arguments for the
48
+ coordinator agent, e.g. `model`, `api_key`, `tools`, etc.
49
+ task_agent_kwargs (Optional[Dict]): Keyword arguments for the task
50
+ agent, e.g. `model`, `api_key`, `tools`, etc.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ description: str,
56
+ children: List[BaseNode],
57
+ coordinator_agent_kwargs: Optional[Dict] = None,
58
+ task_agent_kwargs: Optional[Dict] = None,
59
+ ) -> None:
60
+ super().__init__(description)
61
+ self._child_listening_tasks: Deque[asyncio.Task] = deque()
62
+ self._children = children
63
+
64
+ coord_agent_sysmsg = BaseMessage.make_assistant_message(
65
+ role_name="Workforce Manager",
66
+ content="You are coordinating a group of workers. A worker can be "
67
+ "a group of agents or a single agent. Each worker is created to"
68
+ " solve a specific kind of task. Your job includes assigning "
69
+ "tasks to a existing worker, creating a new worker for a task, "
70
+ "etc.",
71
+ )
72
+ self.coordinator_agent = ChatAgent(
73
+ coord_agent_sysmsg, **(coordinator_agent_kwargs or {})
74
+ )
75
+
76
+ task_sys_msg = BaseMessage.make_assistant_message(
77
+ role_name="Task Planner",
78
+ content="You are going to compose and decompose tasks.",
79
+ )
80
+ self.task_agent = ChatAgent(task_sys_msg, **(task_agent_kwargs or {}))
81
+
82
+ # if there is one, will set by the workforce class wrapping this
83
+ self._task: Optional[Task] = None
84
+ self._pending_tasks: Deque[Task] = deque()
85
+
86
+ @check_if_running(False)
87
+ def set_main_task(self, task: Task) -> None:
88
+ r"""Set the main task for the node."""
89
+ self._task = task
90
+
91
+ def _get_child_nodes_info(self) -> str:
92
+ r"""Get the information of all the child nodes under this node."""
93
+ return '\n'.join(
94
+ f'{child.node_id}: {child.description}' for child in self._children
95
+ )
96
+
97
+ def _find_assignee(
98
+ self,
99
+ task: Task,
100
+ failed_log: Optional[str] = None,
101
+ ) -> str:
102
+ r"""Assigns a task to a child node if capable, otherwise create a
103
+ new worker node.
104
+
105
+ Parameters:
106
+ task (Task): The task to be assigned.
107
+ failed_log (Optional[str]): Optional log of a previous failed
108
+ attempt.
109
+
110
+ Returns:
111
+ str: ID of the assigned node.
112
+ """
113
+ prompt = ASSIGN_TASK_PROMPT.format(
114
+ content=task.content,
115
+ child_nodes_info=self._get_child_nodes_info(),
116
+ )
117
+ req = BaseMessage.make_user_message(
118
+ role_name="User",
119
+ content=prompt,
120
+ )
121
+ response = self.coordinator_agent.step(req)
122
+ try:
123
+ print(f"{Fore.YELLOW}{response.msg.content}{Fore.RESET}")
124
+ assignee_id = parse_assign_task_resp(response.msg.content)
125
+ except ValueError:
126
+ assignee_id = self._create_worker_node_for_task(task).node_id
127
+ return assignee_id
128
+
129
+ async def _post_task(self, task: Task, assignee_id: str) -> None:
130
+ await self._channel.post_task(task, self.node_id, assignee_id)
131
+
132
+ async def _post_dependency(self, dependency: Task) -> None:
133
+ await self._channel.post_dependency(dependency, self.node_id)
134
+
135
+ def _create_worker_node_for_task(self, task: Task) -> WorkerNode:
136
+ r"""Creates a new worker node for a given task and add it to the
137
+ children list of this node. This is one of the actions that
138
+ the coordinator can take when a task has failed.
139
+
140
+ Args:
141
+ task (Task): The task for which the worker node is created.
142
+
143
+ Returns:
144
+ WorkerNode: The created worker node.
145
+ """
146
+ prompt = CREATE_NODE_PROMPT.format(
147
+ content=task.content,
148
+ child_nodes_info=self._get_child_nodes_info(),
149
+ )
150
+ req = BaseMessage.make_user_message(
151
+ role_name="User",
152
+ content=prompt,
153
+ )
154
+ response = self.coordinator_agent.step(req)
155
+ new_node_conf = parse_create_node_resp(response.msg.content)
156
+
157
+ worker_sysmsg = BaseMessage.make_assistant_message(
158
+ role_name=new_node_conf.role,
159
+ content=new_node_conf.system,
160
+ )
161
+
162
+ # TODO: add a default selection of tools for the worker
163
+ worker = ChatAgent(worker_sysmsg)
164
+
165
+ new_node = SingleAgentNode(
166
+ description=new_node_conf.description,
167
+ worker=worker,
168
+ )
169
+ new_node.set_channel(self._channel)
170
+
171
+ print(
172
+ f"{Fore.GREEN}New worker node {new_node.node_id} created."
173
+ f"{Fore.RESET}"
174
+ )
175
+
176
+ self._children.append(new_node)
177
+ self._child_listening_tasks.append(
178
+ asyncio.create_task(new_node.start())
179
+ )
180
+ return new_node
181
+
182
+ async def _get_returned_task(self) -> Task:
183
+ r"""Get the task that's published by this node and just get returned
184
+ from the assignee.
185
+ """
186
+ return await self._channel.get_returned_task_by_publisher(self.node_id)
187
+
188
+ async def _post_ready_tasks(self) -> None:
189
+ r"""Send all the pending tasks that have all the dependencies met to
190
+ the channel, or directly return if there is none. For now, we will
191
+ directly send the first task in the pending list because all the tasks
192
+ are linearly dependent."""
193
+
194
+ if not self._pending_tasks:
195
+ return
196
+
197
+ ready_task = self._pending_tasks[0]
198
+
199
+ # if the task has failed previously, just compose and send the task
200
+ # to the channel as a dependency
201
+ if ready_task.state == TaskState.FAILED:
202
+ # TODO: the composing of tasks seems not work very well
203
+ ready_task.compose(self.task_agent)
204
+ # remove the subtasks from the channel
205
+ for subtask in ready_task.subtasks:
206
+ await self._channel.remove_task(subtask.id)
207
+ # send the task to the channel as a dependency
208
+ await self._post_dependency(ready_task)
209
+ self._pending_tasks.popleft()
210
+ # try to send the next task in the pending list
211
+ await self._post_ready_tasks()
212
+ else:
213
+ # directly post the task to the channel if it's a new one
214
+ # find a node to assign the task
215
+ assignee_id = self._find_assignee(task=ready_task)
216
+ await self._post_task(ready_task, assignee_id)
217
+
218
+ async def _handle_failed_task(self, task: Task) -> None:
219
+ # remove the failed task from the channel
220
+ await self._channel.remove_task(task.id)
221
+ if task.get_depth() >= 3:
222
+ # create a new WF and reassign
223
+ # TODO: add a state for reassign?
224
+ assignee = self._create_worker_node_for_task(task)
225
+ # print('create_new_assignee:', assignee)
226
+ await self._post_task(task, assignee.node_id)
227
+ else:
228
+ subtasks = task.decompose(self.task_agent)
229
+ # Insert packets at the head of the queue
230
+ self._pending_tasks.extendleft(reversed(subtasks))
231
+ await self._post_ready_tasks()
232
+
233
+ async def _handle_completed_task(self, task: Task) -> None:
234
+ # archive the packet, making it into a dependency
235
+ self._pending_tasks.popleft()
236
+ await self._channel.archive_task(task.id)
237
+ await self._post_ready_tasks()
238
+
239
+ @check_if_running(False)
240
+ def set_channel(self, channel: TaskChannel):
241
+ r"""Set the channel for the node and all the child nodes under it."""
242
+ self._channel = channel
243
+ for child in self._children:
244
+ child.set_channel(channel)
245
+
246
+ @check_if_running(False)
247
+ async def _listen_to_channel(self) -> None:
248
+ r"""Continuously listen to the channel, post task to the channel and
249
+ track the status of posted tasks.
250
+ """
251
+
252
+ self._running = True
253
+ print(f"{Fore.GREEN}Manager node {self.node_id} started.{Fore.RESET}")
254
+
255
+ # if this node is at the top level, it will have an initial task
256
+ # the initial task must be decomposed into subtasks first
257
+ if self._task is not None:
258
+ subtasks = self._task.decompose(self.task_agent)
259
+ self._pending_tasks.extend(subtasks)
260
+ self._task.state = TaskState.FAILED
261
+ self._pending_tasks.append(self._task)
262
+
263
+ # before starting the loop, send ready pending tasks to the channel
264
+ await self._post_ready_tasks()
265
+
266
+ await self._channel.print_channel()
267
+
268
+ while self._task is None or self._pending_tasks:
269
+ returned_task = await self._get_returned_task()
270
+ if returned_task.state == TaskState.DONE:
271
+ await self._handle_completed_task(returned_task)
272
+ elif returned_task.state == TaskState.FAILED:
273
+ await self._handle_failed_task(returned_task)
274
+ else:
275
+ raise ValueError(
276
+ f"Task {returned_task.id} has an unexpected state."
277
+ )
278
+
279
+ # shut down the whole workforce tree
280
+ self.stop()
281
+
282
+ @check_if_running(False)
283
+ async def start(self) -> None:
284
+ r"""Start itself and all the child nodes under it."""
285
+ for child in self._children:
286
+ child_listening_task = asyncio.create_task(child.start())
287
+ self._child_listening_tasks.append(child_listening_task)
288
+ await self._listen_to_channel()
289
+
290
+ @check_if_running(True)
291
+ def stop(self) -> None:
292
+ r"""Stop all the child nodes under it. The node itself will be stopped
293
+ by its parent node.
294
+ """
295
+ for child in self._children:
296
+ child.stop()
297
+ for child_task in self._child_listening_tasks:
298
+ child_task.cancel()
299
+ self._running = False