camel-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai has been flagged as potentially problematic; see the package's registry page for details.

Files changed (99):
  1. camel/__init__.py +1 -11
  2. camel/agents/__init__.py +5 -5
  3. camel/agents/chat_agent.py +124 -63
  4. camel/agents/critic_agent.py +28 -17
  5. camel/agents/deductive_reasoner_agent.py +235 -0
  6. camel/agents/embodied_agent.py +92 -40
  7. camel/agents/role_assignment_agent.py +27 -17
  8. camel/agents/task_agent.py +60 -34
  9. camel/agents/tool_agents/base.py +0 -1
  10. camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
  11. camel/configs.py +119 -7
  12. camel/embeddings/__init__.py +2 -0
  13. camel/embeddings/base.py +3 -2
  14. camel/embeddings/openai_embedding.py +3 -3
  15. camel/embeddings/sentence_transformers_embeddings.py +65 -0
  16. camel/functions/__init__.py +13 -3
  17. camel/functions/google_maps_function.py +335 -0
  18. camel/functions/math_functions.py +7 -7
  19. camel/functions/openai_function.py +344 -42
  20. camel/functions/search_functions.py +100 -35
  21. camel/functions/twitter_function.py +484 -0
  22. camel/functions/weather_functions.py +36 -23
  23. camel/generators.py +65 -46
  24. camel/human.py +17 -11
  25. camel/interpreters/__init__.py +25 -0
  26. camel/interpreters/base.py +49 -0
  27. camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
  28. camel/interpreters/interpreter_error.py +19 -0
  29. camel/interpreters/subprocess_interpreter.py +190 -0
  30. camel/loaders/__init__.py +22 -0
  31. camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
  32. camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
  33. camel/memories/__init__.py +17 -7
  34. camel/memories/agent_memories.py +156 -0
  35. camel/memories/base.py +97 -32
  36. camel/memories/blocks/__init__.py +21 -0
  37. camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
  38. camel/memories/blocks/vectordb_block.py +101 -0
  39. camel/memories/context_creators/__init__.py +3 -2
  40. camel/memories/context_creators/score_based.py +32 -20
  41. camel/memories/records.py +6 -5
  42. camel/messages/__init__.py +2 -2
  43. camel/messages/base.py +99 -16
  44. camel/messages/func_message.py +7 -4
  45. camel/models/__init__.py +4 -2
  46. camel/models/anthropic_model.py +132 -0
  47. camel/models/base_model.py +3 -2
  48. camel/models/model_factory.py +10 -8
  49. camel/models/open_source_model.py +25 -13
  50. camel/models/openai_model.py +9 -10
  51. camel/models/stub_model.py +6 -5
  52. camel/prompts/__init__.py +7 -5
  53. camel/prompts/ai_society.py +21 -14
  54. camel/prompts/base.py +54 -47
  55. camel/prompts/code.py +22 -14
  56. camel/prompts/evaluation.py +8 -5
  57. camel/prompts/misalignment.py +26 -19
  58. camel/prompts/object_recognition.py +35 -0
  59. camel/prompts/prompt_templates.py +14 -8
  60. camel/prompts/role_description_prompt_template.py +16 -10
  61. camel/prompts/solution_extraction.py +9 -5
  62. camel/prompts/task_prompt_template.py +24 -21
  63. camel/prompts/translation.py +9 -5
  64. camel/responses/agent_responses.py +5 -2
  65. camel/retrievers/__init__.py +24 -0
  66. camel/retrievers/auto_retriever.py +319 -0
  67. camel/retrievers/base.py +64 -0
  68. camel/retrievers/bm25_retriever.py +149 -0
  69. camel/retrievers/vector_retriever.py +166 -0
  70. camel/societies/__init__.py +1 -1
  71. camel/societies/babyagi_playing.py +56 -32
  72. camel/societies/role_playing.py +188 -133
  73. camel/storages/__init__.py +18 -0
  74. camel/storages/graph_storages/__init__.py +23 -0
  75. camel/storages/graph_storages/base.py +82 -0
  76. camel/storages/graph_storages/graph_element.py +74 -0
  77. camel/storages/graph_storages/neo4j_graph.py +582 -0
  78. camel/storages/key_value_storages/base.py +1 -2
  79. camel/storages/key_value_storages/in_memory.py +1 -2
  80. camel/storages/key_value_storages/json.py +8 -13
  81. camel/storages/vectordb_storages/__init__.py +33 -0
  82. camel/storages/vectordb_storages/base.py +202 -0
  83. camel/storages/vectordb_storages/milvus.py +396 -0
  84. camel/storages/vectordb_storages/qdrant.py +371 -0
  85. camel/terminators/__init__.py +1 -1
  86. camel/terminators/base.py +2 -3
  87. camel/terminators/response_terminator.py +21 -12
  88. camel/terminators/token_limit_terminator.py +5 -3
  89. camel/types/__init__.py +12 -6
  90. camel/types/enums.py +86 -13
  91. camel/types/openai_types.py +10 -5
  92. camel/utils/__init__.py +18 -13
  93. camel/utils/commons.py +242 -81
  94. camel/utils/token_counting.py +135 -15
  95. {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/METADATA +116 -74
  96. camel_ai-0.1.3.dist-info/RECORD +101 -0
  97. {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/WHEEL +1 -1
  98. camel/memories/context_creators/base.py +0 -72
  99. camel_ai-0.1.1.dist-info/RECORD +0 -75
camel/__init__.py CHANGED
@@ -11,18 +11,8 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
- import camel.agents
15
- import camel.configs
16
- import camel.generators
17
- import camel.messages
18
- import camel.prompts
19
- import camel.types
20
- import camel.utils
21
- import camel.functions
22
- import camel.memories
23
- import camel.storages
24
14
 
25
- __version__ = '0.1.1'
15
+ __version__ = '0.1.3'
26
16
 
27
17
  __all__ = [
28
18
  '__version__',
camel/agents/__init__.py CHANGED
@@ -13,17 +13,17 @@
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
14
  from .base import BaseAgent
15
15
  from .chat_agent import ChatAgent
16
+ from .critic_agent import CriticAgent
17
+ from .embodied_agent import EmbodiedAgent
18
+ from .role_assignment_agent import RoleAssignmentAgent
16
19
  from .task_agent import (
17
- TaskSpecifyAgent,
18
- TaskPlannerAgent,
19
20
  TaskCreationAgent,
21
+ TaskPlannerAgent,
20
22
  TaskPrioritizationAgent,
23
+ TaskSpecifyAgent,
21
24
  )
22
- from .critic_agent import CriticAgent
23
25
  from .tool_agents.base import BaseToolAgent
24
26
  from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
25
- from .embodied_agent import EmbodiedAgent
26
- from .role_assignment_agent import RoleAssignmentAgent
27
27
 
28
28
  __all__ = [
29
29
  'BaseAgent',
camel/agents/chat_agent.py CHANGED
@@ -11,18 +11,17 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from __future__ import annotations
15
+
14
16
  import json
15
17
  from collections import defaultdict
16
18
  from dataclasses import dataclass
17
- from typing import Any, Callable, Dict, List, Optional, Tuple
18
-
19
- from openai import Stream
19
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
20
20
 
21
- from camel.agents import BaseAgent
22
- from camel.configs import BaseConfig, ChatGPTConfig
23
- from camel.functions import OpenAIFunction
21
+ from camel.agents.base import BaseAgent
22
+ from camel.configs import ChatGPTConfig, ChatGPTVisionConfig
24
23
  from camel.memories import (
25
- BaseMemory,
24
+ AgentMemory,
26
25
  ChatHistoryMemory,
27
26
  MemoryRecord,
28
27
  ScoreBasedContextCreator,
@@ -30,7 +29,6 @@ from camel.memories import (
30
29
  from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
31
30
  from camel.models import BaseModelBackend, ModelFactory
32
31
  from camel.responses import ChatAgentResponse
33
- from camel.terminators import ResponseTerminator
34
32
  from camel.types import (
35
33
  ChatCompletion,
36
34
  ChatCompletionChunk,
@@ -40,6 +38,13 @@ from camel.types import (
40
38
  )
41
39
  from camel.utils import get_model_encoding
42
40
 
41
+ if TYPE_CHECKING:
42
+ from openai import Stream
43
+
44
+ from camel.configs import BaseConfig
45
+ from camel.functions import OpenAIFunction
46
+ from camel.terminators import ResponseTerminator
47
+
43
48
 
44
49
  @dataclass(frozen=True)
45
50
  class FunctionCallingRecord:
@@ -51,6 +56,7 @@ class FunctionCallingRecord:
51
56
  the function.
52
57
  result (Any): The execution result of calling this function.
53
58
  """
59
+
54
60
  func_name: str
55
61
  args: Dict[str, Any]
56
62
  result: Any
@@ -62,9 +68,11 @@ class FunctionCallingRecord:
62
68
  str: Modified string to represent the function calling.
63
69
  """
64
70
 
65
- return (f"Function Execution: {self.func_name}\n"
66
- f"\tArgs: {self.args}\n"
67
- f"\tResult: {self.result}")
71
+ return (
72
+ f"Function Execution: {self.func_name}\n"
73
+ f"\tArgs: {self.args}\n"
74
+ f"\tResult: {self.result}"
75
+ )
68
76
 
69
77
 
70
78
  class ChatAgent(BaseAgent):
@@ -76,13 +84,13 @@ class ChatAgent(BaseAgent):
76
84
  responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
77
85
  model_config (BaseConfig, optional): Configuration options for the
78
86
  LLM model. (default: :obj:`None`)
79
- memory (BaseMemory, optional): The agent memory for managing chat
87
+ memory (AgentMemory, optional): The agent memory for managing chat
80
88
  messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
81
89
  (default: :obj:`None`)
82
90
  message_window_size (int, optional): The maximum number of previous
83
91
  messages to include in the context window. If `None`, no windowing
84
92
  is performed. (default: :obj:`None`)
85
- token_limit (int, optional): The maxinum number of tokens in a context.
93
+ token_limit (int, optional): The maximum number of tokens in a context.
86
94
  The context will be automatically pruned to fulfill the limitation.
87
95
  If `None`, it will be set according to the backend model.
88
96
  (default: :obj:`None`)
@@ -100,14 +108,13 @@ class ChatAgent(BaseAgent):
100
108
  system_message: BaseMessage,
101
109
  model_type: Optional[ModelType] = None,
102
110
  model_config: Optional[BaseConfig] = None,
103
- memory: Optional[BaseMemory] = None,
111
+ memory: Optional[AgentMemory] = None,
104
112
  message_window_size: Optional[int] = None,
105
113
  token_limit: Optional[int] = None,
106
114
  output_language: Optional[str] = None,
107
115
  function_list: Optional[List[OpenAIFunction]] = None,
108
116
  response_terminators: Optional[List[ResponseTerminator]] = None,
109
117
  ) -> None:
110
-
111
118
  self.orig_sys_message: BaseMessage = system_message
112
119
  self.system_message = system_message
113
120
  self.role_name: str = system_message.role_name
@@ -116,24 +123,48 @@ class ChatAgent(BaseAgent):
116
123
  if self.output_language is not None:
117
124
  self.set_output_language(self.output_language)
118
125
 
119
- self.model_type: ModelType = (model_type if model_type is not None else
120
- ModelType.GPT_3_5_TURBO)
126
+ self.model_type: ModelType = (
127
+ model_type if model_type is not None else ModelType.GPT_3_5_TURBO
128
+ )
121
129
 
122
130
  self.func_dict: Dict[str, Callable] = {}
123
131
  if function_list is not None:
124
132
  for func in function_list:
125
- self.func_dict[func.name] = func.func
126
- self.model_config = model_config or ChatGPTConfig()
133
+ self.func_dict[func.get_function_name()] = func.func
134
+
135
+ self.model_config: BaseConfig
136
+ if self.model_type == ModelType.GPT_4_TURBO_VISION:
137
+ if model_config is not None and not isinstance(
138
+ model_config, ChatGPTVisionConfig
139
+ ):
140
+ raise ValueError(
141
+ "Please use `ChatGPTVisionConfig` as "
142
+ "the `model_config` when `model_type` "
143
+ "is `GPT_4_TURBO_VISION`"
144
+ )
145
+ self.model_config = model_config or ChatGPTVisionConfig()
146
+ else:
147
+ if model_config is not None and isinstance(
148
+ model_config, ChatGPTVisionConfig
149
+ ):
150
+ raise ValueError(
151
+ "Please don't use `ChatGPTVisionConfig` as "
152
+ "the `model_config` when `model_type` "
153
+ "is not `GPT_4_TURBO_VISION`"
154
+ )
155
+ self.model_config = model_config or ChatGPTConfig()
127
156
 
128
157
  self.model_backend: BaseModelBackend = ModelFactory.create(
129
- self.model_type, self.model_config.__dict__)
158
+ self.model_type, self.model_config.__dict__
159
+ )
130
160
  self.model_token_limit = token_limit or self.model_backend.token_limit
131
161
  context_creator = ScoreBasedContextCreator(
132
162
  self.model_backend.token_counter,
133
163
  self.model_token_limit,
134
164
  )
135
- self.memory: BaseMemory = memory or ChatHistoryMemory(
136
- context_creator, window_size=message_window_size)
165
+ self.memory: AgentMemory = memory or ChatHistoryMemory(
166
+ context_creator, window_size=message_window_size
167
+ )
137
168
 
138
169
  self.terminated: bool = False
139
170
  self.response_terminators = response_terminators or []
@@ -180,8 +211,9 @@ class ChatAgent(BaseAgent):
180
211
  """
181
212
  return len(self.func_dict) > 0
182
213
 
183
- def update_memory(self, message: BaseMessage,
184
- role: OpenAIBackendRole) -> None:
214
+ def update_memory(
215
+ self, message: BaseMessage, role: OpenAIBackendRole
216
+ ) -> None:
185
217
  r"""Updates the agent memory with a new message.
186
218
 
187
219
  Args:
@@ -204,15 +236,21 @@ class ChatAgent(BaseAgent):
204
236
  BaseMessage: The updated system message object.
205
237
  """
206
238
  self.output_language = output_language
207
- content = (self.orig_sys_message.content +
208
- ("\nRegardless of the input language, "
209
- f"you must output text in {output_language}."))
239
+ content = self.orig_sys_message.content + (
240
+ "\nRegardless of the input language, "
241
+ f"you must output text in {output_language}."
242
+ )
210
243
  self.system_message = self.system_message.create_new_instance(content)
211
244
  return self.system_message
212
245
 
213
- def get_info(self, id: Optional[str], usage: Optional[Dict[str, int]],
214
- termination_reasons: List[str], num_tokens: int,
215
- called_funcs: List[FunctionCallingRecord]) -> Dict[str, Any]:
246
+ def get_info(
247
+ self,
248
+ id: Optional[str],
249
+ usage: Optional[Dict[str, int]],
250
+ termination_reasons: List[str],
251
+ num_tokens: int,
252
+ called_funcs: List[FunctionCallingRecord],
253
+ ) -> Dict[str, Any]:
216
254
  r"""Returns a dictionary containing information about the chat session.
217
255
 
218
256
  Args:
@@ -241,8 +279,9 @@ class ChatAgent(BaseAgent):
241
279
  r"""Initializes the stored messages list with the initial system
242
280
  message.
243
281
  """
244
- system_record = MemoryRecord(self.system_message,
245
- OpenAIBackendRole.SYSTEM)
282
+ system_record = MemoryRecord(
283
+ self.system_message, OpenAIBackendRole.SYSTEM
284
+ )
246
285
  self.memory.clear()
247
286
  self.memory.write_record(system_record)
248
287
 
@@ -287,29 +326,36 @@ class ChatAgent(BaseAgent):
287
326
  try:
288
327
  openai_messages, num_tokens = self.memory.get_context()
289
328
  except RuntimeError as e:
290
- return self.step_token_exceed(e.args[1], called_funcs,
291
- "max_tokens_exceeded")
329
+ return self.step_token_exceed(
330
+ e.args[1], called_funcs, "max_tokens_exceeded"
331
+ )
292
332
 
293
333
  # Obtain the model's response
294
334
  response = self.model_backend.run(openai_messages)
295
335
 
296
336
  if isinstance(response, ChatCompletion):
297
337
  output_messages, finish_reasons, usage_dict, response_id = (
298
- self.handle_batch_response(response))
338
+ self.handle_batch_response(response)
339
+ )
299
340
  else:
300
341
  output_messages, finish_reasons, usage_dict, response_id = (
301
- self.handle_stream_response(response, num_tokens))
342
+ self.handle_stream_response(response, num_tokens)
343
+ )
302
344
 
303
- if (self.is_function_calling_enabled()
304
- and finish_reasons[0] == 'function_call'
305
- and isinstance(response, ChatCompletion)):
345
+ if (
346
+ self.is_function_calling_enabled()
347
+ and finish_reasons[0] == 'function_call'
348
+ and isinstance(response, ChatCompletion)
349
+ ):
306
350
  # Do function calling
307
351
  func_assistant_msg, func_result_msg, func_record = (
308
- self.step_function_call(response))
352
+ self.step_function_call(response)
353
+ )
309
354
 
310
355
  # Update the messages
311
- self.update_memory(func_assistant_msg,
312
- OpenAIBackendRole.ASSISTANT)
356
+ self.update_memory(
357
+ func_assistant_msg, OpenAIBackendRole.ASSISTANT
358
+ )
313
359
  self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
314
360
 
315
361
  # Record the function calling
@@ -326,9 +372,13 @@ class ChatAgent(BaseAgent):
326
372
  ]
327
373
  # Terminate the agent if any of the terminator terminates
328
374
  self.terminated, termination_reason = next(
329
- ((terminated, termination_reason)
330
- for terminated, termination_reason in termination
331
- if terminated), (False, None))
375
+ (
376
+ (terminated, termination_reason)
377
+ for terminated, termination_reason in termination
378
+ if terminated
379
+ ),
380
+ (False, None),
381
+ )
332
382
  # For now only retain the first termination reason
333
383
  if self.terminated and termination_reason is not None:
334
384
  finish_reasons = [termination_reason] * len(finish_reasons)
@@ -368,8 +418,9 @@ class ChatAgent(BaseAgent):
368
418
  finish_reasons = [
369
419
  str(choice.finish_reason) for choice in response.choices
370
420
  ]
371
- usage = (response.usage.model_dump()
372
- if response.usage is not None else {})
421
+ usage = (
422
+ response.usage.model_dump() if response.usage is not None else {}
423
+ )
373
424
  return (
374
425
  output_messages,
375
426
  finish_reasons,
@@ -408,10 +459,12 @@ class ChatAgent(BaseAgent):
408
459
  content_dict[index] += delta.content
409
460
  else:
410
461
  finish_reasons_dict[index] = choice.finish_reason
411
- chat_message = BaseMessage(role_name=self.role_name,
412
- role_type=self.role_type,
413
- meta_dict=dict(),
414
- content=content_dict[index])
462
+ chat_message = BaseMessage(
463
+ role_name=self.role_name,
464
+ role_type=self.role_type,
465
+ meta_dict=dict(),
466
+ content=content_dict[index],
467
+ )
415
468
  output_messages.append(chat_message)
416
469
  finish_reasons = [
417
470
  finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
@@ -419,9 +472,12 @@ class ChatAgent(BaseAgent):
419
472
  usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
420
473
  return output_messages, finish_reasons, usage_dict, response_id
421
474
 
422
- def step_token_exceed(self, num_tokens: int,
423
- called_funcs: List[FunctionCallingRecord],
424
- termination_reason: str) -> ChatAgentResponse:
475
+ def step_token_exceed(
476
+ self,
477
+ num_tokens: int,
478
+ called_funcs: List[FunctionCallingRecord],
479
+ termination_reason: str,
480
+ ) -> ChatAgentResponse:
425
481
  r"""Return trivial response containing number of tokens and information
426
482
  of called functions when the number of tokens exceeds.
427
483
 
@@ -455,8 +511,9 @@ class ChatAgent(BaseAgent):
455
511
  def step_function_call(
456
512
  self,
457
513
  response: ChatCompletion,
458
- ) -> Tuple[FunctionCallingMessage, FunctionCallingMessage,
459
- FunctionCallingRecord]:
514
+ ) -> Tuple[
515
+ FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
516
+ ]:
460
517
  r"""Execute the function with arguments following the model's response.
461
518
 
462
519
  Args:
@@ -477,7 +534,7 @@ class ChatAgent(BaseAgent):
477
534
  func = self.func_dict[func_name]
478
535
 
479
536
  args_str: str = choice.message.function_call.arguments
480
- args = json.loads(args_str.replace("\'", "\""))
537
+ args = json.loads(args_str.replace("'", "\""))
481
538
 
482
539
  # Pass the extracted arguments to the indicated function
483
540
  try:
@@ -485,7 +542,8 @@ class ChatAgent(BaseAgent):
485
542
  except Exception:
486
543
  raise ValueError(
487
544
  f"Execution of function {func.__name__} failed with "
488
- f"arguments being {args}.")
545
+ f"arguments being {args}."
546
+ )
489
547
 
490
548
  assist_msg = FunctionCallingMessage(
491
549
  role_name=self.role_name,
@@ -508,8 +566,9 @@ class ChatAgent(BaseAgent):
508
566
  func_record = FunctionCallingRecord(func_name, args, result)
509
567
  return assist_msg, func_msg, func_record
510
568
 
511
- def get_usage_dict(self, output_messages: List[BaseMessage],
512
- prompt_tokens: int) -> Dict[str, int]:
569
+ def get_usage_dict(
570
+ self, output_messages: List[BaseMessage], prompt_tokens: int
571
+ ) -> Dict[str, int]:
513
572
  r"""Get usage dictionary when using the stream mode.
514
573
 
515
574
  Args:
@@ -523,9 +582,11 @@ class ChatAgent(BaseAgent):
523
582
  completion_tokens = 0
524
583
  for message in output_messages:
525
584
  completion_tokens += len(encoding.encode(message.content))
526
- usage_dict = dict(completion_tokens=completion_tokens,
527
- prompt_tokens=prompt_tokens,
528
- total_tokens=completion_tokens + prompt_tokens)
585
+ usage_dict = dict(
586
+ completion_tokens=completion_tokens,
587
+ prompt_tokens=prompt_tokens,
588
+ total_tokens=completion_tokens + prompt_tokens,
589
+ )
529
590
  return usage_dict
530
591
 
531
592
  def __repr__(self) -> str:
camel/agents/critic_agent.py CHANGED
@@ -17,8 +17,8 @@ from typing import Any, Dict, Optional, Sequence
17
17
 
18
18
  from colorama import Fore
19
19
 
20
- from camel.agents import ChatAgent
21
- from camel.memories import BaseMemory
20
+ from camel.agents.chat_agent import ChatAgent
21
+ from camel.memories import AgentMemory
22
22
  from camel.messages import BaseMessage
23
23
  from camel.responses import ChatAgentResponse
24
24
  from camel.types import ModelType
@@ -50,15 +50,19 @@ class CriticAgent(ChatAgent):
50
50
  system_message: BaseMessage,
51
51
  model_type: ModelType = ModelType.GPT_3_5_TURBO,
52
52
  model_config: Optional[Any] = None,
53
- memory: Optional[BaseMemory] = None,
53
+ memory: Optional[AgentMemory] = None,
54
54
  message_window_size: int = 6,
55
55
  retry_attempts: int = 2,
56
56
  verbose: bool = False,
57
57
  logger_color: Any = Fore.MAGENTA,
58
58
  ) -> None:
59
- super().__init__(system_message, model_type=model_type,
60
- model_config=model_config, memory=memory,
61
- message_window_size=message_window_size)
59
+ super().__init__(
60
+ system_message,
61
+ model_type=model_type,
62
+ model_config=model_config,
63
+ memory=memory,
64
+ message_window_size=message_window_size,
65
+ )
62
66
  self.options_dict: Dict[str, str] = dict()
63
67
  self.retry_attempts = retry_attempts
64
68
  self.verbose = verbose
@@ -77,13 +81,15 @@ class CriticAgent(ChatAgent):
77
81
  flatten_options = (
78
82
  f"> Proposals from "
79
83
  f"{messages[0].role_name} ({messages[0].role_type}). "
80
- "Please choose an option:\n")
84
+ "Please choose an option:\n"
85
+ )
81
86
  for index, option in enumerate(options):
82
87
  flatten_options += f"Option {index + 1}:\n{option}\n\n"
83
88
  self.options_dict[str(index + 1)] = option
84
89
  format = (
85
90
  f"Please first enter your choice ([1-{len(self.options_dict)}]) "
86
- "and then your explanation and comparison: ")
91
+ "and then your explanation and comparison: "
92
+ )
87
93
  return flatten_options + format
88
94
 
89
95
  def get_option(self, input_message: BaseMessage) -> str:
@@ -110,8 +116,10 @@ class CriticAgent(ChatAgent):
110
116
  critic_msg = critic_response.msg
111
117
  self.record_message(critic_msg)
112
118
  if self.verbose:
113
- print_text_animated(self.logger_color + "\n> Critic response: "
114
- f"\x1b[3m{critic_msg.content}\x1b[0m\n")
119
+ print_text_animated(
120
+ self.logger_color + "\n> Critic response: "
121
+ f"\x1b[3m{critic_msg.content}\x1b[0m\n"
122
+ )
115
123
  choice = self.parse_critic(critic_msg)
116
124
 
117
125
  if choice in self.options_dict:
@@ -121,13 +129,15 @@ class CriticAgent(ChatAgent):
121
129
  role_name=input_message.role_name,
122
130
  role_type=input_message.role_type,
123
131
  meta_dict=input_message.meta_dict,
124
- content="> Invalid choice. Please choose again.\n" +
125
- msg_content,
132
+ content="> Invalid choice. Please choose again.\n"
133
+ + msg_content,
126
134
  )
127
135
  i += 1
128
- warnings.warn("Critic failed to get a valid option. "
129
- f"After {self.retry_attempts} attempts. "
130
- "Returning a random option.")
136
+ warnings.warn(
137
+ "Critic failed to get a valid option. "
138
+ f"After {self.retry_attempts} attempts. "
139
+ "Returning a random option."
140
+ )
131
141
  return random.choice(list(self.options_dict.values()))
132
142
 
133
143
  def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
@@ -168,8 +178,9 @@ class CriticAgent(ChatAgent):
168
178
 
169
179
  flatten_options = self.flatten_options(input_messages)
170
180
  if self.verbose:
171
- print_text_animated(self.logger_color +
172
- f"\x1b[3m{flatten_options}\x1b[0m\n")
181
+ print_text_animated(
182
+ self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
183
+ )
173
184
  input_msg = meta_chat_message.create_new_instance(flatten_options)
174
185
 
175
186
  option = self.get_option(input_msg)