camel-ai 0.2.69a6__py3-none-any.whl → 0.2.69a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +220 -7
- camel/memories/context_creators/score_based.py +11 -6
- camel/messages/base.py +2 -2
- camel/societies/workforce/workforce.py +214 -22
- camel/storages/__init__.py +2 -0
- camel/storages/vectordb_storages/__init__.py +2 -0
- camel/storages/vectordb_storages/chroma.py +731 -0
- camel/tasks/task.py +30 -2
- camel/toolkits/__init__.py +2 -1
- camel/toolkits/excel_toolkit.py +814 -69
- camel/toolkits/google_drive_mcp_toolkit.py +73 -0
- camel/toolkits/mcp_toolkit.py +31 -1
- camel/types/enums.py +6 -6
- {camel_ai-0.2.69a6.dist-info → camel_ai-0.2.69a7.dist-info}/METADATA +4 -1
- {camel_ai-0.2.69a6.dist-info → camel_ai-0.2.69a7.dist-info}/RECORD +18 -16
- {camel_ai-0.2.69a6.dist-info → camel_ai-0.2.69a7.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.69a6.dist-info → camel_ai-0.2.69a7.dist-info}/licenses/LICENSE +0 -0
camel/__init__.py
CHANGED
camel/agents/chat_agent.py
CHANGED
|
@@ -406,7 +406,10 @@ class ChatAgent(BaseAgent):
|
|
|
406
406
|
# List of tuples (platform, type)
|
|
407
407
|
resolved_models_list = []
|
|
408
408
|
for model_spec in model_list:
|
|
409
|
-
platform, type_ =
|
|
409
|
+
platform, type_ = ( # type: ignore[index]
|
|
410
|
+
model_spec[0],
|
|
411
|
+
model_spec[1],
|
|
412
|
+
)
|
|
410
413
|
resolved_models_list.append(
|
|
411
414
|
ModelFactory.create(
|
|
412
415
|
model_platform=platform, model_type=type_
|
|
@@ -846,6 +849,185 @@ class ChatAgent(BaseAgent):
|
|
|
846
849
|
except ValidationError:
|
|
847
850
|
return False
|
|
848
851
|
|
|
852
|
+
def _check_tools_strict_compatibility(self) -> bool:
|
|
853
|
+
r"""Check if all tools are compatible with OpenAI strict mode.
|
|
854
|
+
|
|
855
|
+
Returns:
|
|
856
|
+
bool: True if all tools are strict mode compatible,
|
|
857
|
+
False otherwise.
|
|
858
|
+
"""
|
|
859
|
+
tool_schemas = self._get_full_tool_schemas()
|
|
860
|
+
for schema in tool_schemas:
|
|
861
|
+
if not schema.get("function", {}).get("strict", True):
|
|
862
|
+
return False
|
|
863
|
+
return True
|
|
864
|
+
|
|
865
|
+
def _convert_response_format_to_prompt(
    self, response_format: Type[BaseModel]
) -> str:
    r"""Convert a Pydantic response format to a prompt instruction.

    The model's JSON schema is rendered as an example JSON object with one
    placeholder value per top-level property. Entries are comma-separated.
    A property's description, if present, follows its value as a ``//``
    comment.

    Args:
        response_format (Type[BaseModel]): The Pydantic model class.

    Returns:
        str: A prompt instruction requesting the specific format. If the
            schema cannot be built, a generic instruction asking for
            structured JSON is returned instead.
    """
    # Example value shown for each JSON-schema type; any other type
    # (including "string") falls back to a string placeholder.
    placeholders = {
        "array": '["array of values"]',
        "object": "{}",
        "boolean": "true",
        "integer": "0",
        "number": "0",
    }

    try:
        # Get the JSON schema from the Pydantic model
        schema = response_format.model_json_schema()
        definitions = schema.get("$defs", {})
        properties = schema.get("properties", {})

        def _schema_type(info):
            # Pydantic omits "type" for Optional[X] fields (emitted as
            # {"anyOf": [X, {"type": "null"}]}) and for references to
            # nested models/enums ({"$ref": "#/$defs/Name"}). Resolve both
            # so they do not all fall back to the string placeholder.
            declared = info.get("type")
            if isinstance(declared, str):
                return declared
            ref = info.get("$ref")
            if isinstance(ref, str):
                target = definitions.get(ref.rsplit("/", 1)[-1])
                return _schema_type(target) if target else "object"
            for option in info.get("anyOf", []):
                if option.get("type") != "null":
                    return _schema_type(option)
            return "string"

        lines = []
        last_index = len(properties) - 1
        for index, (field_name, field_info) in enumerate(properties.items()):
            value = placeholders.get(
                _schema_type(field_info), '"string value"'
            )
            line = f' "{field_name}": {value}'
            # The separating comma must come before the "//" comment,
            # otherwise it would become part of the comment text.
            if index < last_index:
                line += ","
            description = field_info.get("description", "")
            if description:
                line += f' // {description}'
            lines.append(line + "\n")

        return (
            "\n\nPlease respond in the following JSON format:\n"
            "{\n" + "".join(lines) + "}"
        )

    except Exception as e:
        # Best-effort: never let prompt construction break the step.
        logger.warning(
            f"Failed to convert response_format to prompt: {e}. "
            "Using generic format instruction."
        )
        return (
            "\n\nPlease respond in a structured JSON format "
            "that matches the requested schema."
        )
|
|
923
|
+
|
|
924
|
+
def _handle_response_format_with_non_strict_tools(
    self,
    input_message: Union[BaseMessage, str],
    response_format: Optional[Type[BaseModel]] = None,
) -> Tuple[Union[BaseMessage, str], Optional[Type[BaseModel]], bool]:
    r"""Replace a native ``response_format`` with prompt instructions when
    the registered tools cannot run in OpenAI strict mode.

    Args:
        input_message (Union[BaseMessage, str]): The message about to be
            sent to the model.
        response_format (Optional[Type[BaseModel]]): The structured output
            model requested by the caller, if any.

    Returns:
        Tuple[Union[BaseMessage, str], Optional[Type[BaseModel]], bool]:
            ``(message, response_format, used_prompt_formatting)``. When no
            format was requested, or all tools are strict-compatible, the
            inputs are returned unchanged together with ``False``.
            Otherwise the message has the format instruction appended to
            its content, the format is ``None`` and the flag is ``True``.
    """
    # The short-circuit means the tool-schema check only runs when a
    # structured output was actually requested.
    if response_format is None or self._check_tools_strict_compatibility():
        return input_message, response_format, False

    logger.info(
        "Non-strict tools detected. Converting response_format to "
        "prompt-based formatting."
    )
    instruction = self._convert_response_format_to_prompt(response_format)

    rewritten: Union[BaseMessage, str]
    if isinstance(input_message, str):
        rewritten = input_message + instruction
    else:
        rewritten = input_message.create_new_instance(
            input_message.content + instruction
        )

    # None disables the native response_format (avoiding strict-mode
    # conflicts); True tells the caller to parse the reply manually.
    return rewritten, None, True
|
|
968
|
+
|
|
969
|
+
def _apply_prompt_based_parsing(
    self,
    response: ModelResponse,
    original_response_format: Type[BaseModel],
) -> None:
    r"""Apply manual parsing when using prompt-based formatting.

    Each output message with non-empty content is handled in two stages:

    1. The stripped content is parsed as JSON as a whole.
    2. If that fails, every JSON value that starts at a ``{`` inside the
       text (for example inside a markdown code fence, or surrounded by
       prose) is decoded, in order of position.

    The first candidate that validates against ``original_response_format``
    is stored in ``message.parsed``. When that candidate was extracted from
    surrounding text, ``message.content`` is replaced with the
    re-serialized JSON.

    Parsing is best-effort: failures are logged as warnings and never
    raised.

    Args:
        response (ModelResponse): The model response whose
            ``output_messages`` are parsed in place.
        original_response_format (Type[BaseModel]): The Pydantic model the
            caller originally requested.
    """
    import json

    from pydantic import ValidationError

    decoder = json.JSONDecoder()

    for message in response.output_messages:
        if not message.content:
            continue
        try:
            content = message.content.strip()

            # Try direct parsing first
            try:
                message.parsed = original_response_format.model_validate(
                    json.loads(content)
                )
                continue
            except (json.JSONDecodeError, ValidationError):
                pass

            # Fall back to JSON embedded in text. raw_decode follows
            # arbitrarily nested objects, whereas a regex can only match a
            # fixed brace depth.
            start = content.find("{")
            while start != -1:
                try:
                    candidate, _ = decoder.raw_decode(content, start)
                    message.parsed = original_response_format.model_validate(
                        candidate
                    )
                except (json.JSONDecodeError, ValidationError):
                    start = content.find("{", start + 1)
                    continue
                # Update content to just the JSON for consistency
                message.content = json.dumps(candidate)
                break

            if not message.parsed:
                logger.warning(
                    f"Failed to parse JSON from response: "
                    f"{content[:100]}..."
                )

        except Exception as e:
            logger.warning(f"Error during prompt-based parsing: {e}")
|
|
1030
|
+
|
|
849
1031
|
def _format_response_if_needed(
|
|
850
1032
|
self,
|
|
851
1033
|
response: ModelResponse,
|
|
@@ -932,6 +1114,14 @@ class ChatAgent(BaseAgent):
|
|
|
932
1114
|
except ImportError:
|
|
933
1115
|
pass # Langfuse not available
|
|
934
1116
|
|
|
1117
|
+
# Handle response format compatibility with non-strict tools
|
|
1118
|
+
original_response_format = response_format
|
|
1119
|
+
input_message, response_format, used_prompt_formatting = (
|
|
1120
|
+
self._handle_response_format_with_non_strict_tools(
|
|
1121
|
+
input_message, response_format
|
|
1122
|
+
)
|
|
1123
|
+
)
|
|
1124
|
+
|
|
935
1125
|
# Convert input message to BaseMessage if necessary
|
|
936
1126
|
if isinstance(input_message, str):
|
|
937
1127
|
input_message = BaseMessage.make_user_message(
|
|
@@ -1014,6 +1204,13 @@ class ChatAgent(BaseAgent):
|
|
|
1014
1204
|
break
|
|
1015
1205
|
|
|
1016
1206
|
self._format_response_if_needed(response, response_format)
|
|
1207
|
+
|
|
1208
|
+
# Apply manual parsing if we used prompt-based formatting
|
|
1209
|
+
if used_prompt_formatting and original_response_format:
|
|
1210
|
+
self._apply_prompt_based_parsing(
|
|
1211
|
+
response, original_response_format
|
|
1212
|
+
)
|
|
1213
|
+
|
|
1017
1214
|
self._record_final_output(response.output_messages)
|
|
1018
1215
|
|
|
1019
1216
|
return self._convert_to_chatagent_response(
|
|
@@ -1065,6 +1262,14 @@ class ChatAgent(BaseAgent):
|
|
|
1065
1262
|
except ImportError:
|
|
1066
1263
|
pass # Langfuse not available
|
|
1067
1264
|
|
|
1265
|
+
# Handle response format compatibility with non-strict tools
|
|
1266
|
+
original_response_format = response_format
|
|
1267
|
+
input_message, response_format, used_prompt_formatting = (
|
|
1268
|
+
self._handle_response_format_with_non_strict_tools(
|
|
1269
|
+
input_message, response_format
|
|
1270
|
+
)
|
|
1271
|
+
)
|
|
1272
|
+
|
|
1068
1273
|
if isinstance(input_message, str):
|
|
1069
1274
|
input_message = BaseMessage.make_user_message(
|
|
1070
1275
|
role_name="User", content=input_message
|
|
@@ -1098,6 +1303,11 @@ class ChatAgent(BaseAgent):
|
|
|
1098
1303
|
)
|
|
1099
1304
|
iteration_count += 1
|
|
1100
1305
|
|
|
1306
|
+
# Accumulate API token usage
|
|
1307
|
+
self._update_token_usage_tracker(
|
|
1308
|
+
step_token_usage, response.usage_dict
|
|
1309
|
+
)
|
|
1310
|
+
|
|
1101
1311
|
# Terminate Agent if stop_event is set
|
|
1102
1312
|
if self.stop_event and self.stop_event.is_set():
|
|
1103
1313
|
# Use the _step_terminate to terminate the agent with reason
|
|
@@ -1139,13 +1349,14 @@ class ChatAgent(BaseAgent):
|
|
|
1139
1349
|
break
|
|
1140
1350
|
|
|
1141
1351
|
await self._aformat_response_if_needed(response, response_format)
|
|
1142
|
-
self._record_final_output(response.output_messages)
|
|
1143
1352
|
|
|
1144
|
-
#
|
|
1145
|
-
|
|
1353
|
+
# Apply manual parsing if we used prompt-based formatting
|
|
1354
|
+
if used_prompt_formatting and original_response_format:
|
|
1355
|
+
self._apply_prompt_based_parsing(
|
|
1356
|
+
response, original_response_format
|
|
1357
|
+
)
|
|
1146
1358
|
|
|
1147
|
-
|
|
1148
|
-
self._update_token_usage_tracker(step_token_usage, response.usage_dict)
|
|
1359
|
+
self._record_final_output(response.output_messages)
|
|
1149
1360
|
|
|
1150
1361
|
return self._convert_to_chatagent_response(
|
|
1151
1362
|
response,
|
|
@@ -1924,7 +2135,9 @@ class ChatAgent(BaseAgent):
|
|
|
1924
2135
|
schema for schema in self._external_tool_schemas.values()
|
|
1925
2136
|
],
|
|
1926
2137
|
response_terminators=self.response_terminators,
|
|
1927
|
-
scheduling_strategy=
|
|
2138
|
+
scheduling_strategy=(
|
|
2139
|
+
self.model_backend.scheduling_strategy.__name__
|
|
2140
|
+
),
|
|
1928
2141
|
max_iteration=self.max_iteration,
|
|
1929
2142
|
stop_event=self.stop_event,
|
|
1930
2143
|
)
|
|
@@ -155,16 +155,21 @@ class ScoreBasedContextCreator(BaseContextCreator):
|
|
|
155
155
|
# ======================
|
|
156
156
|
# 6. Truncation Logic with Tool Call Awareness
|
|
157
157
|
# ======================
|
|
158
|
-
logger.warning(
|
|
159
|
-
f"Context truncation required "
|
|
160
|
-
f"({total_tokens} > {self.token_limit}), "
|
|
161
|
-
f"pruning low-score messages."
|
|
162
|
-
)
|
|
163
|
-
|
|
164
158
|
remaining_units = self._truncate_with_tool_call_awareness(
|
|
165
159
|
regular_units, tool_call_groups, system_tokens
|
|
166
160
|
)
|
|
167
161
|
|
|
162
|
+
# Log only after truncation is actually performed so that both
|
|
163
|
+
# the original and the final token counts are visible.
|
|
164
|
+
tokens_after = system_tokens + sum(
|
|
165
|
+
u.num_tokens for u in remaining_units
|
|
166
|
+
)
|
|
167
|
+
logger.warning(
|
|
168
|
+
"Context truncation performed: "
|
|
169
|
+
f"before={total_tokens}, after={tokens_after}, "
|
|
170
|
+
f"limit={self.token_limit}"
|
|
171
|
+
)
|
|
172
|
+
|
|
168
173
|
# ======================
|
|
169
174
|
# 7. Output Assembly
|
|
170
175
|
# ======================
|
camel/messages/base.py
CHANGED
|
@@ -69,7 +69,7 @@ class BaseMessage:
|
|
|
69
69
|
image_detail (Literal["auto", "low", "high"]): Detail level of the
|
|
70
70
|
images associated with the message. (default: :obj:`auto`)
|
|
71
71
|
video_detail (Literal["auto", "low", "high"]): Detail level of the
|
|
72
|
-
videos associated with the message. (default: :obj:`
|
|
72
|
+
videos associated with the message. (default: :obj:`auto`)
|
|
73
73
|
parsed: Optional[Union[Type[BaseModel], dict]]: Optional object which
|
|
74
74
|
is parsed from the content. (default: :obj:`None`)
|
|
75
75
|
"""
|
|
@@ -82,7 +82,7 @@ class BaseMessage:
|
|
|
82
82
|
video_bytes: Optional[bytes] = None
|
|
83
83
|
image_list: Optional[List[Image.Image]] = None
|
|
84
84
|
image_detail: Literal["auto", "low", "high"] = "auto"
|
|
85
|
-
video_detail: Literal["auto", "low", "high"] = "
|
|
85
|
+
video_detail: Literal["auto", "low", "high"] = "auto"
|
|
86
86
|
parsed: Optional[Union[BaseModel, dict]] = None
|
|
87
87
|
|
|
88
88
|
@classmethod
|
|
@@ -19,7 +19,7 @@ import time
|
|
|
19
19
|
import uuid
|
|
20
20
|
from collections import deque
|
|
21
21
|
from enum import Enum
|
|
22
|
-
from typing import Any, Coroutine, Deque, Dict, List, Optional
|
|
22
|
+
from typing import Any, Coroutine, Deque, Dict, List, Optional, Set, Tuple
|
|
23
23
|
|
|
24
24
|
from colorama import Fore
|
|
25
25
|
|
|
@@ -37,6 +37,7 @@ from camel.societies.workforce.role_playing_worker import RolePlayingWorker
|
|
|
37
37
|
from camel.societies.workforce.single_agent_worker import SingleAgentWorker
|
|
38
38
|
from camel.societies.workforce.task_channel import TaskChannel
|
|
39
39
|
from camel.societies.workforce.utils import (
|
|
40
|
+
TaskAssignment,
|
|
40
41
|
TaskAssignResult,
|
|
41
42
|
WorkerConf,
|
|
42
43
|
check_if_running,
|
|
@@ -1164,22 +1165,30 @@ class Workforce(BaseNode):
|
|
|
1164
1165
|
)
|
|
1165
1166
|
return info
|
|
1166
1167
|
|
|
1167
|
-
def
|
|
1168
|
-
|
|
1169
|
-
|
|
1168
|
+
def _get_valid_worker_ids(self) -> set:
|
|
1169
|
+
r"""Get all valid worker IDs from child nodes.
|
|
1170
|
+
|
|
1171
|
+
Returns:
|
|
1172
|
+
set: Set of valid worker IDs that can be assigned tasks.
|
|
1173
|
+
"""
|
|
1174
|
+
valid_worker_ids = {child.node_id for child in self._children}
|
|
1175
|
+
return valid_worker_ids
|
|
1176
|
+
|
|
1177
|
+
def _call_coordinator_for_assignment(
|
|
1178
|
+
self, tasks: List[Task], invalid_ids: Optional[List[str]] = None
|
|
1170
1179
|
) -> TaskAssignResult:
|
|
1171
|
-
r"""
|
|
1180
|
+
r"""Call coordinator agent to assign tasks with optional validation
|
|
1181
|
+
feedback in the case of invalid worker IDs.
|
|
1172
1182
|
|
|
1173
|
-
|
|
1174
|
-
tasks (List[Task]):
|
|
1183
|
+
Args:
|
|
1184
|
+
tasks (List[Task]): Tasks to assign.
|
|
1185
|
+
invalid_ids (List[str], optional): Invalid worker IDs from previous
|
|
1186
|
+
attempt (if any).
|
|
1175
1187
|
|
|
1176
1188
|
Returns:
|
|
1177
|
-
TaskAssignResult: Assignment result
|
|
1178
|
-
with their dependencies.
|
|
1189
|
+
TaskAssignResult: Assignment result from coordinator.
|
|
1179
1190
|
"""
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
# Format tasks information for the prompt
|
|
1191
|
+
# format tasks information for the prompt
|
|
1183
1192
|
tasks_info = ""
|
|
1184
1193
|
for task in tasks:
|
|
1185
1194
|
tasks_info += f"Task ID: {task.id}\n"
|
|
@@ -1188,29 +1197,212 @@ class Workforce(BaseNode):
|
|
|
1188
1197
|
tasks_info += f"Additional Info: {task.additional_info}\n"
|
|
1189
1198
|
tasks_info += "---\n"
|
|
1190
1199
|
|
|
1191
|
-
prompt =
|
|
1192
|
-
|
|
1193
|
-
|
|
1200
|
+
prompt = str(
|
|
1201
|
+
ASSIGN_TASK_PROMPT.format(
|
|
1202
|
+
tasks_info=tasks_info,
|
|
1203
|
+
child_nodes_info=self._get_child_nodes_info(),
|
|
1204
|
+
)
|
|
1194
1205
|
)
|
|
1195
1206
|
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1207
|
+
# add feedback if this is a retry
|
|
1208
|
+
if invalid_ids:
|
|
1209
|
+
valid_worker_ids = list(self._get_valid_worker_ids())
|
|
1210
|
+
feedback = (
|
|
1211
|
+
f"VALIDATION ERROR: The following worker IDs are invalid: "
|
|
1212
|
+
f"{invalid_ids}. "
|
|
1213
|
+
f"VALID WORKER IDS: {valid_worker_ids}. "
|
|
1214
|
+
f"Please reassign ONLY the above tasks using these valid IDs."
|
|
1215
|
+
)
|
|
1216
|
+
prompt = prompt + f"\n\n{feedback}"
|
|
1200
1217
|
|
|
1201
1218
|
response = self.coordinator_agent.step(
|
|
1202
1219
|
prompt, response_format=TaskAssignResult
|
|
1203
1220
|
)
|
|
1221
|
+
|
|
1204
1222
|
if response.msg is None or response.msg.content is None:
|
|
1205
1223
|
logger.error(
|
|
1206
1224
|
"Coordinator agent returned empty response for task assignment"
|
|
1207
1225
|
)
|
|
1208
|
-
# Return empty result as fallback
|
|
1209
1226
|
return TaskAssignResult(assignments=[])
|
|
1210
1227
|
|
|
1211
1228
|
result_dict = json.loads(response.msg.content, parse_int=str)
|
|
1212
|
-
|
|
1213
|
-
|
|
1229
|
+
return TaskAssignResult(**result_dict)
|
|
1230
|
+
|
|
1231
|
+
def _validate_assignments(
    self, assignments: List[TaskAssignment], valid_ids: Set[str]
) -> Tuple[List[TaskAssignment], List[TaskAssignment]]:
    r"""Split task assignments by whether their assignee is a known worker.

    Args:
        assignments (List[TaskAssignment]): Assignments to validate.
        valid_ids (Set[str]): Set of valid worker IDs.

    Returns:
        Tuple[List[TaskAssignment], List[TaskAssignment]]:
            ``(valid_assignments, invalid_assignments)``. Each list keeps
            the input order.
    """
    accepted = []
    rejected = []
    for candidate in assignments:
        bucket = accepted if candidate.assignee_id in valid_ids else rejected
        bucket.append(candidate)
    return accepted, rejected
|
|
1254
|
+
|
|
1255
|
+
def _handle_task_assignment_fallbacks(self, tasks: List[Task]) -> List:
    r"""Give each of the given tasks its own newly created worker.

    Args:
        tasks (List[Task]): Tasks that need new workers.

    Returns:
        List[TaskAssignment]: One assignment per task, in input order. Each
            points at the worker created for that task and has no
            dependencies.
    """
    created_assignments = []
    for pending_task in tasks:
        logger.info(
            f"Creating new worker for unassigned task {pending_task.id}"
        )
        worker_node = self._create_worker_node_for_task(pending_task)
        created_assignments.append(
            TaskAssignment(
                task_id=pending_task.id,
                assignee_id=worker_node.node_id,
                dependencies=[],
            )
        )
    return created_assignments
|
|
1278
|
+
|
|
1279
|
+
def _handle_assignment_retry_and_fallback(
    self,
    invalid_assignments: List[TaskAssignment],
    tasks: List[Task],
    valid_worker_ids: Set[str],
) -> List[TaskAssignment]:
    r"""Called if Coordinator agent fails to assign tasks to valid worker
    IDs. Handles retry assignment and fallback worker creation for invalid
    assignments.

    The affected tasks are sent back to the coordinator once, with the
    invalid IDs as feedback. Retry assignments that name a valid worker are
    kept, at most one per affected task. Any affected task that still has
    no valid assignee afterwards gets a newly created fallback worker. This
    covers three cases: the retry named an invalid worker again, the retry
    left the task out, or the retry returned nothing.

    Args:
        invalid_assignments (List[TaskAssignment]): Invalid assignments to
            retry. If empty, every task in ``tasks`` is retried.
        tasks (List[Task]): Original tasks list for task lookup.
        valid_worker_ids (set): Set of valid worker IDs.

    Returns:
        List[TaskAssignment]: Final assignments for the affected tasks.
    """
    invalid_ids = [a.assignee_id for a in invalid_assignments]

    if invalid_assignments:
        invalid_task_ids = {a.task_id for a in invalid_assignments}
        invalid_tasks = [task for task in tasks if task.id in invalid_task_ids]
    else:
        # handle cases where coordinator returned no assignments at all
        invalid_tasks = list(tasks)  # all tasks need assignment

    if not invalid_tasks:
        # The invalid assignments reference no known task: nothing to retry.
        return []

    if invalid_assignments:
        logger.warning(
            f"Invalid worker IDs detected: {invalid_ids}. "
            f"Retrying assignment for {len(invalid_tasks)} tasks."
        )
    else:
        logger.warning(
            f"Coordinator returned no assignments. "
            f"Retrying assignment for all {len(invalid_tasks)} tasks."
        )

    # retry assignment with feedback
    retry_result = self._call_coordinator_for_assignment(
        invalid_tasks, invalid_ids
    )
    final_assignments = []

    if retry_result.assignments:
        retry_valid, _ = self._validate_assignments(
            retry_result.assignments, valid_worker_ids
        )
        # Accept at most one valid assignment per affected task, and ignore
        # assignments for tasks outside this retry batch.
        pending_ids = {task.id for task in invalid_tasks}
        for assignment in retry_valid:
            if assignment.task_id in pending_ids:
                final_assignments.append(assignment)
                pending_ids.remove(assignment.task_id)
    else:
        # retry failed completely, all invalid tasks need fallback
        logger.warning("Retry assignment failed")

    # Every affected task without a valid retry assignment needs a fallback
    # worker. That includes tasks the retry assigned to an invalid ID and
    # tasks it omitted altogether.
    assigned_ids = {a.task_id for a in final_assignments}
    unassigned_tasks = [
        task for task in invalid_tasks if task.id not in assigned_ids
    ]

    # handle fallback for any remaining unassigned tasks
    if unassigned_tasks:
        logger.warning(
            f"Creating fallback workers for {len(unassigned_tasks)} "
            f"unassigned tasks"
        )
        final_assignments.extend(
            self._handle_task_assignment_fallbacks(unassigned_tasks)
        )

    return final_assignments
|
|
1356
|
+
|
|
1357
|
+
def _find_assignee(
    self,
    tasks: List[Task],
) -> TaskAssignResult:
    r"""Assigns multiple tasks to worker nodes with the best capabilities.

    The coordinator is asked once for a batch assignment, and assignments
    that name an existing worker are accepted. Two kinds of task go through
    retry and fallback so that every task ends up assigned:

    - tasks whose assignment names an unknown worker;
    - tasks the coordinator left out entirely.

    Tasks that already have a valid assignment are never re-assigned.

    Parameters:
        tasks (List[Task]): The tasks to be assigned.

    Returns:
        TaskAssignResult: Assignment result containing task assignments
            with their dependencies.
    """
    self.coordinator_agent.reset()
    valid_worker_ids = self._get_valid_worker_ids()

    logger.debug(
        f"Sending batch assignment request to coordinator "
        f"for {len(tasks)} tasks."
    )

    assignment_result = self._call_coordinator_for_assignment(tasks)

    # validate assignments
    valid_assignments, invalid_assignments = self._validate_assignments(
        assignment_result.assignments, valid_worker_ids
    )

    # Tasks still lacking an assignee: no assignment for them names a
    # valid worker (either all of theirs were invalid, or none was made).
    validly_assigned_ids = {a.task_id for a in valid_assignments}
    problem_tasks = [t for t in tasks if t.id not in validly_assigned_ids]

    # if every task has a valid assignment, return early
    if not problem_tasks:
        return TaskAssignResult(assignments=valid_assignments)

    problem_ids = {t.id for t in problem_tasks}
    relevant_invalid = [
        a for a in invalid_assignments if a.task_id in problem_ids
    ]
    invalid_task_ids = {a.task_id for a in relevant_invalid}
    omitted_tasks = [t for t in problem_tasks if t.id not in invalid_task_ids]

    # Retry (with invalid-ID feedback) the tasks that got an invalid
    # assignee.
    if relevant_invalid:
        valid_assignments.extend(
            self._handle_assignment_retry_and_fallback(
                relevant_invalid,
                [t for t in problem_tasks if t.id in invalid_task_ids],
                valid_worker_ids,
            )
        )

    # Tasks the coordinator skipped. With an empty invalid list, the handler
    # retries every task passed to it, which here means only the omitted
    # ones. Validly assigned tasks are therefore not duplicated.
    if omitted_tasks:
        logger.warning(
            f"Coordinator did not assign {len(omitted_tasks)} task(s): "
            f"{[t.id for t in omitted_tasks]}."
        )
        valid_assignments.extend(
            self._handle_assignment_retry_and_fallback(
                [], omitted_tasks, valid_worker_ids
            )
        )

    return TaskAssignResult(assignments=valid_assignments)
|
|
1214
1406
|
|
|
1215
1407
|
async def _post_task(self, task: Task, assignee_id: str) -> None:
|
|
1216
1408
|
# Record the start time when a task is posted
|
camel/storages/__init__.py
CHANGED
|
@@ -26,6 +26,7 @@ from .vectordb_storages.base import (
|
|
|
26
26
|
VectorDBQueryResult,
|
|
27
27
|
VectorRecord,
|
|
28
28
|
)
|
|
29
|
+
from .vectordb_storages.chroma import ChromaStorage
|
|
29
30
|
from .vectordb_storages.faiss import FaissStorage
|
|
30
31
|
from .vectordb_storages.milvus import MilvusStorage
|
|
31
32
|
from .vectordb_storages.oceanbase import OceanBaseStorage
|
|
@@ -52,4 +53,5 @@ __all__ = [
|
|
|
52
53
|
'Mem0Storage',
|
|
53
54
|
'OceanBaseStorage',
|
|
54
55
|
'WeaviateStorage',
|
|
56
|
+
'ChromaStorage',
|
|
55
57
|
]
|
|
@@ -19,6 +19,7 @@ from .base import (
|
|
|
19
19
|
VectorDBStatus,
|
|
20
20
|
VectorRecord,
|
|
21
21
|
)
|
|
22
|
+
from .chroma import ChromaStorage
|
|
22
23
|
from .faiss import FaissStorage
|
|
23
24
|
from .milvus import MilvusStorage
|
|
24
25
|
from .oceanbase import OceanBaseStorage
|
|
@@ -30,6 +31,7 @@ __all__ = [
|
|
|
30
31
|
'BaseVectorStorage',
|
|
31
32
|
'VectorDBQuery',
|
|
32
33
|
'VectorDBQueryResult',
|
|
34
|
+
'ChromaStorage',
|
|
33
35
|
'QdrantStorage',
|
|
34
36
|
'MilvusStorage',
|
|
35
37
|
"TiDBStorage",
|