waldiez-0.3.11-py3-none-any.whl → waldiez-0.4.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release.


This version of waldiez might be problematic; see the registry's page for this release for more details.

Files changed (66)
  1. waldiez/_version.py +1 -1
  2. waldiez/cli.py +1 -3
  3. waldiez/exporting/agent/agent_exporter.py +26 -15
  4. waldiez/exporting/agent/utils/__init__.py +2 -4
  5. waldiez/exporting/agent/utils/captain_agent.py +250 -0
  6. waldiez/exporting/agent/utils/swarm_agent.py +12 -7
  7. waldiez/exporting/base/utils/comments.py +1 -0
  8. waldiez/exporting/chats/utils/swarm.py +1 -1
  9. waldiez/exporting/flow/flow_exporter.py +5 -6
  10. waldiez/exporting/flow/utils/__init__.py +3 -6
  11. waldiez/exporting/flow/utils/def_main.py +5 -4
  12. waldiez/exporting/flow/utils/flow_content.py +38 -0
  13. waldiez/exporting/flow/utils/importing_utils.py +64 -29
  14. waldiez/exporting/skills/skills_exporter.py +13 -6
  15. waldiez/exporting/skills/utils.py +92 -6
  16. waldiez/models/__init__.py +6 -0
  17. waldiez/models/agents/__init__.py +14 -0
  18. waldiez/models/agents/agent/__init__.py +2 -1
  19. waldiez/models/agents/agent/agent.py +71 -11
  20. waldiez/models/agents/agent/agent_type.py +11 -0
  21. waldiez/models/agents/agents.py +11 -1
  22. waldiez/models/agents/captain_agent/__init__.py +15 -0
  23. waldiez/models/agents/captain_agent/captain_agent.py +45 -0
  24. waldiez/models/agents/captain_agent/captain_agent_data.py +62 -0
  25. waldiez/models/agents/captain_agent/captain_agent_lib_entry.py +38 -0
  26. waldiez/models/agents/extra_requirements.py +88 -0
  27. waldiez/models/agents/group_manager/speakers.py +3 -0
  28. waldiez/models/agents/rag_user/retrieve_config.py +3 -0
  29. waldiez/models/agents/reasoning/reasoning_agent_reason_config.py +1 -0
  30. waldiez/models/agents/swarm_agent/after_work.py +13 -11
  31. waldiez/models/agents/swarm_agent/on_condition.py +3 -2
  32. waldiez/models/agents/swarm_agent/on_condition_available.py +1 -0
  33. waldiez/models/agents/swarm_agent/swarm_agent_data.py +3 -3
  34. waldiez/models/agents/swarm_agent/update_system_message.py +1 -0
  35. waldiez/models/chat/chat_message.py +1 -0
  36. waldiez/models/chat/chat_summary.py +1 -0
  37. waldiez/models/common/__init__.py +4 -0
  38. waldiez/models/common/ag2_version.py +30 -0
  39. waldiez/models/common/base.py +1 -1
  40. waldiez/models/common/date_utils.py +2 -0
  41. waldiez/models/common/dict_utils.py +2 -0
  42. waldiez/models/common/method_utils.py +98 -0
  43. waldiez/models/flow/__init__.py +2 -0
  44. waldiez/models/flow/utils.py +61 -1
  45. waldiez/models/model/__init__.py +2 -0
  46. waldiez/models/model/extra_requirements.py +57 -0
  47. waldiez/models/model/model.py +5 -2
  48. waldiez/models/model/model_data.py +3 -1
  49. waldiez/models/skill/__init__.py +4 -0
  50. waldiez/models/skill/extra_requirements.py +39 -0
  51. waldiez/models/skill/skill.py +157 -13
  52. waldiez/models/skill/skill_data.py +14 -0
  53. waldiez/models/skill/skill_type.py +8 -0
  54. waldiez/models/waldiez.py +47 -76
  55. waldiez/runner.py +19 -7
  56. waldiez/running/environment.py +30 -1
  57. waldiez/running/running.py +0 -6
  58. waldiez/utils/pysqlite3_checker.py +18 -5
  59. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/METADATA +42 -30
  60. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/RECORD +64 -55
  61. waldiez/exporting/agent/utils/agent_class_name.py +0 -36
  62. waldiez/exporting/agent/utils/agent_imports.py +0 -55
  63. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/WHEEL +0 -0
  64. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/entry_points.txt +0 -0
  65. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/licenses/LICENSE +0 -0
  66. {waldiez-0.3.11.dist-info → waldiez-0.4.0.dist-info}/licenses/NOTICE.md +0 -0
@@ -36,21 +36,6 @@ COMMON_AUTOGEN_IMPORTS = [
36
36
  ]
37
37
 
38
38
 
39
- def get_standard_imports() -> str:
40
- """Get the standard imports.
41
-
42
- Returns
43
- -------
44
- str
45
- The standard imports.
46
- """
47
- builtin_imports = BUILTIN_IMPORTS.copy()
48
- imports_string = "\n".join(builtin_imports) + "\n"
49
- typing_imports = "from typing import " + ", ".join(TYPING_IMPORTS)
50
- imports_string += typing_imports
51
- return imports_string
52
-
53
-
54
39
  def sort_imports(
55
40
  all_imports: List[Tuple[str, ImportPosition]],
56
41
  ) -> Tuple[List[str], List[str], List[str], List[str], bool]:
@@ -66,10 +51,10 @@ def sort_imports(
66
51
  Tuple[List[str], List[str], List[str], List[str], bool]
67
52
  The sorted imports and a flag if we got `import autogen`.
68
53
  """
69
- builtin_imports = []
70
- third_party_imports = []
71
- local_imports = []
72
- autogen_imports = COMMON_AUTOGEN_IMPORTS.copy()
54
+ builtin_imports: List[str] = []
55
+ third_party_imports: List[str] = []
56
+ local_imports: List[str] = []
57
+ autogen_imports: List[str] = COMMON_AUTOGEN_IMPORTS.copy()
73
58
  got_import_autogen = False
74
59
  for import_string, position in all_imports:
75
60
  if "import autogen" in import_string:
@@ -85,11 +70,22 @@ def sort_imports(
85
70
  elif position == ImportPosition.LOCAL:
86
71
  local_imports.append(import_string)
87
72
  autogen_imports = list(set(autogen_imports))
73
+ third_party_imports = ensure_np_import(third_party_imports)
74
+ sorted_builtins = sorted(
75
+ [imp for imp in builtin_imports if imp.startswith("import ")]
76
+ ) + sorted([imp for imp in builtin_imports if imp.startswith("from ")])
77
+ sorted_third_party = sorted(
78
+ [imp for imp in third_party_imports if imp.startswith("import ")]
79
+ ) + sorted([imp for imp in third_party_imports if imp.startswith("from ")])
80
+ sorted_locals = sorted(
81
+ [imp for imp in local_imports if imp.startswith("import ")]
82
+ ) + sorted([imp for imp in local_imports if imp.startswith("from ")])
83
+
88
84
  return (
89
- sorted(builtin_imports),
85
+ sorted_builtins,
90
86
  sorted(autogen_imports),
91
- sorted(third_party_imports),
92
- sorted(local_imports),
87
+ sorted_third_party,
88
+ sorted_locals,
93
89
  got_import_autogen,
94
90
  )
95
91
 
@@ -150,6 +146,27 @@ def get_the_imports_string(
150
146
  return final_string.replace("\n\n\n", "\n\n") # avoid too many newlines
151
147
 
152
148
 
149
+ def ensure_np_import(third_party_imports: List[str]) -> List[str]:
150
+ """Ensure numpy is imported.
151
+
152
+ Parameters
153
+ ----------
154
+ third_party_imports : List[str]
155
+ The third party imports.
156
+
157
+ Returns
158
+ -------
159
+ List[str]
160
+ The third party imports with numpy.
161
+ """
162
+ if (
163
+ not third_party_imports
164
+ or "import numpy as np" not in third_party_imports
165
+ ):
166
+ third_party_imports.append("import numpy as np")
167
+ return third_party_imports
168
+
169
+
153
170
  def gather_imports(
154
171
  model_imports: Optional[List[Tuple[str, ImportPosition]]],
155
172
  skill_imports: Optional[List[Tuple[str, ImportPosition]]],
@@ -174,13 +191,14 @@ def gather_imports(
174
191
  Tuple[str, ImportPosition]
175
192
  The gathered imports.
176
193
  """
177
- imports_string = get_standard_imports()
178
- all_imports: List[Tuple[str, ImportPosition]] = [
179
- (
180
- imports_string,
181
- ImportPosition.BUILTINS,
194
+ all_imports: List[Tuple[str, ImportPosition]] = []
195
+ for import_statement in BUILTIN_IMPORTS:
196
+ all_imports.append(
197
+ (
198
+ import_statement,
199
+ ImportPosition.BUILTINS,
200
+ )
182
201
  )
183
- ]
184
202
  if model_imports:
185
203
  all_imports.extend(model_imports)
186
204
  if skill_imports:
@@ -189,4 +207,21 @@ def gather_imports(
189
207
  all_imports.extend(chat_imports)
190
208
  if agent_imports:
191
209
  all_imports.extend(agent_imports)
192
- return list(set(all_imports))
210
+ # let's try to avoid this:
211
+ # from typing import Annotated
212
+ # from typing import Annotated, Any, Callable, Dict, ...Union
213
+ all_typing_imports = TYPING_IMPORTS.copy()
214
+ final_imports: List[Tuple[str, ImportPosition]] = []
215
+ for import_statement, import_position in all_imports:
216
+ if import_statement.startswith("from typing"):
217
+ to_import = import_statement.split("import")[1].strip()
218
+ if to_import:
219
+ all_typing_imports.append(to_import)
220
+ else:
221
+ final_imports.append((import_statement, import_position))
222
+ unique_typing_imports = list(set(all_typing_imports))
223
+ one_typing_import = "from typing import " + ", ".join(
224
+ sorted(unique_typing_imports)
225
+ )
226
+ final_imports.insert(1, (one_typing_import, ImportPosition.BUILTINS))
227
+ return list(set(final_imports))
@@ -89,12 +89,18 @@ class SkillsExporter(BaseExporter, ExporterMixin):
89
89
  Tuple[str, int]
90
90
  The exported imports and the position of the imports.
91
91
  """
92
- if not self.skill_imports:
93
- return []
94
92
  imports: List[Tuple[str, ImportPosition]] = []
95
- for skill_import in self.skill_imports:
96
- if (skill_import, ImportPosition.LOCAL) not in imports:
97
- imports.append((skill_import, ImportPosition.LOCAL))
93
+ if not self.skill_imports:
94
+ return imports
95
+ # standard imports
96
+ for import_statement in self.skill_imports[0]:
97
+ imports.append((import_statement, ImportPosition.BUILTINS))
98
+ # third party imports
99
+ for import_statement in self.skill_imports[1]:
100
+ imports.append((import_statement, ImportPosition.THIRD_PARTY))
101
+ # secrets/local imports
102
+ for import_statement in self.skill_imports[2]:
103
+ imports.append((import_statement, ImportPosition.LOCAL))
98
104
  return imports
99
105
 
100
106
  def get_before_export(
@@ -156,11 +162,12 @@ class SkillsExporter(BaseExporter, ExporterMixin):
156
162
  the before export strings, the after export strings,
157
163
  and the environment variables.
158
164
  """
165
+ content = self.generate()
159
166
  imports = self.get_imports()
160
167
  after_export = self.get_after_export()
161
168
  environment_variables = self.get_environment_variables()
162
169
  result: ExporterReturnType = {
163
- "content": self.generate(),
170
+ "content": content,
164
171
  "imports": imports,
165
172
  "before_export": None,
166
173
  "after_export": after_export,
@@ -94,6 +94,7 @@ def _write_skill_secrets(
94
94
  return
95
95
  if not isinstance(output_dir, Path):
96
96
  output_dir = Path(output_dir)
97
+ output_dir.mkdir(parents=True, exist_ok=True)
97
98
  secrets_file = output_dir / f"{flow_name}_{skill_name}_secrets.py"
98
99
  first_line = f'"""Secrets for the skill: {skill_name}."""' + "\n"
99
100
  with secrets_file.open("w", encoding="utf-8", newline="\n") as f:
@@ -108,7 +109,7 @@ def export_skills(
108
109
  skills: List[WaldiezSkill],
109
110
  skill_names: Dict[str, str],
110
111
  output_dir: Optional[Union[str, Path]] = None,
111
- ) -> Tuple[List[str], List[Tuple[str, str]], str]:
112
+ ) -> Tuple[Tuple[List[str], List[str], List[str]], List[Tuple[str, str]], str]:
112
113
  """Get the skills' contents and secrets.
113
114
 
114
115
  If `output_dir` is provided, the contents are saved to that directory.
@@ -126,19 +127,26 @@ def export_skills(
126
127
 
127
128
  Returns
128
129
  -------
129
- Tuple[Set[str], Set[Tuple[str, str]], str]
130
+ Tuple[Tuple[List[str], List[str], List[str]], List[Tuple[str, str]], str]
130
131
  - The skill imports to use in the main file.
131
132
  - The skill secrets to set as environment variables.
132
133
  - The skills contents.
133
134
  """
134
- skill_imports: List[str] = []
135
+ skill_imports: Tuple[List[str], List[str], List[str]] = ([], [], [])
135
136
  skill_secrets: List[Tuple[str, str]] = []
136
137
  skill_contents: str = ""
137
138
  # if the skill.is_shared,
138
139
  # its contents must be first (before the other skills)
139
140
  shared_skill_contents = ""
140
141
  for skill in skills:
141
- skill_imports.append(get_skill_imports(flow_name, skill))
142
+ standard_skill_imports, third_party_skill_imports = skill.get_imports()
143
+ if standard_skill_imports:
144
+ skill_imports[0].extend(standard_skill_imports)
145
+ if third_party_skill_imports:
146
+ skill_imports[1].extend(third_party_skill_imports)
147
+ secrets_import = get_skill_secrets_import(flow_name, skill)
148
+ if secrets_import:
149
+ skill_imports[2].append(secrets_import)
142
150
  for key, value in skill.secrets.items():
143
151
  skill_secrets.append((key, value))
144
152
  _write_skill_secrets(
@@ -153,8 +161,14 @@ def export_skills(
153
161
  if skill.is_shared:
154
162
  shared_skill_contents += skill_content + "\n\n"
155
163
  else:
164
+ if skill.is_interop:
165
+ skill_content += _add_interop_extras(
166
+ skill=skill, skill_names=skill_names
167
+ )
156
168
  skill_contents += skill_content + "\n\n"
157
169
  skill_contents = shared_skill_contents + skill_contents
170
+ # remove dupes from imports if any and sort them
171
+ skill_imports = _sort_imports(skill_imports)
158
172
  return (
159
173
  skill_imports,
160
174
  skill_secrets,
@@ -162,8 +176,77 @@ def export_skills(
162
176
  )
163
177
 
164
178
 
165
- def get_skill_imports(flow_name: str, skill: WaldiezSkill) -> str:
166
- """Get the skill imports string.
179
+ def _add_interop_extras(
180
+ skill: WaldiezSkill,
181
+ skill_names: Dict[str, str],
182
+ ) -> str:
183
+ """Add the interop conversion.
184
+
185
+ Parameters
186
+ ----------
187
+ skill : WaldiezSkill
188
+ The skill
189
+ skill_names : Dict[str, str]
190
+ The skill names.
191
+
192
+ Returns
193
+ -------
194
+ str
195
+ The extra content to convert the tool.
196
+ """
197
+ skill_name = skill_names[skill.id]
198
+ interop_instance = f"ag2_{skill_name}_interop = Interoperability()" + "\n"
199
+ extra_content = (
200
+ f"ag2_{skill_name} = "
201
+ f"ag2_{skill_name}_interop.convert_tool("
202
+ f"tool={skill_name}, "
203
+ f'type="{skill.skill_type}")'
204
+ )
205
+ return "\n" + interop_instance + extra_content
206
+
207
+
208
+ def _sort_imports(
209
+ skill_imports: Tuple[List[str], List[str], List[str]],
210
+ ) -> Tuple[List[str], List[str], List[str]]:
211
+ """Sort the imports.
212
+
213
+ Parameters
214
+ ----------
215
+ skill_imports : Tuple[List[str], List[str], List[str]]
216
+ The skill imports.
217
+
218
+ Returns
219
+ -------
220
+ Tuple[List[str], List[str], List[str]]
221
+ The sorted skill imports.
222
+ """
223
+
224
+ # "from x import y" and "import z"
225
+ # the "import a" should be first (and sorted)
226
+ # then the "from b import c" (and sorted)
227
+ standard_lib_imports = skill_imports[0]
228
+ third_party_imports = skill_imports[1]
229
+ secrets_imports = skill_imports[2]
230
+
231
+ sorted_standard_lib_imports = sorted(
232
+ [imp for imp in standard_lib_imports if imp.startswith("import ")]
233
+ ) + sorted([imp for imp in standard_lib_imports if imp.startswith("from ")])
234
+
235
+ sorted_third_party_imports = sorted(
236
+ [imp for imp in third_party_imports if imp.startswith("import ")]
237
+ ) + sorted([imp for imp in third_party_imports if imp.startswith("from ")])
238
+
239
+ sorted_secrets_imports = sorted(secrets_imports)
240
+
241
+ return (
242
+ sorted_standard_lib_imports,
243
+ sorted_third_party_imports,
244
+ sorted_secrets_imports,
245
+ )
246
+
247
+
248
+ def get_skill_secrets_import(flow_name: str, skill: WaldiezSkill) -> str:
249
+ """Get the skill secrets import string.
167
250
 
168
251
  Parameters
169
252
  ----------
@@ -263,6 +346,9 @@ def get_agent_skill_registrations(
263
346
  skill for skill in all_skills if skill.id == linked_skill.id
264
347
  )
265
348
  skill_name = skill_names[linked_skill.id]
349
+ if waldiez_skill.is_interop:
350
+ # the name of the the converted to ag2 tool
351
+ skill_name = f"ag2_{skill_name}"
266
352
  skill_description = (
267
353
  waldiez_skill.description or f"Description of {skill_name}"
268
354
  )
@@ -23,6 +23,9 @@ from .agents import (
23
23
  WaldiezAgentType,
24
24
  WaldiezAssistant,
25
25
  WaldiezAssistantData,
26
+ WaldiezCaptainAgent,
27
+ WaldiezCaptainAgentData,
28
+ WaldiezCaptainAgentLibEntry,
26
29
  WaldiezGroupManager,
27
30
  WaldiezGroupManagerData,
28
31
  WaldiezGroupManagerSpeakers,
@@ -87,6 +90,9 @@ __all__ = [
87
90
  "WaldiezAgentType",
88
91
  "WaldiezAssistant",
89
92
  "WaldiezAssistantData",
93
+ "WaldiezCaptainAgent",
94
+ "WaldiezCaptainAgentData",
95
+ "WaldiezCaptainAgentLibEntry",
90
96
  "WaldiezChat",
91
97
  "WaldiezChatData",
92
98
  "WaldiezChatSummary",
@@ -18,6 +18,15 @@ from .agent import (
18
18
  )
19
19
  from .agents import WaldiezAgents
20
20
  from .assistant import WaldiezAssistant, WaldiezAssistantData
21
+ from .captain_agent import (
22
+ WaldiezCaptainAgent,
23
+ WaldiezCaptainAgentData,
24
+ WaldiezCaptainAgentLibEntry,
25
+ )
26
+ from .extra_requirements import (
27
+ get_captain_agent_extra_requirements,
28
+ get_retrievechat_extra_requirements,
29
+ )
21
30
  from .group_manager import (
22
31
  CUSTOM_SPEAKER_SELECTION,
23
32
  CUSTOM_SPEAKER_SELECTION_ARGS,
@@ -77,6 +86,8 @@ from .swarm_agent import (
77
86
  from .user_proxy import WaldiezUserProxy, WaldiezUserProxyData
78
87
 
79
88
  __all__ = [
89
+ "get_retrievechat_extra_requirements",
90
+ "get_captain_agent_extra_requirements",
80
91
  "IS_TERMINATION_MESSAGE",
81
92
  "IS_TERMINATION_MESSAGE_ARGS",
82
93
  "IS_TERMINATION_MESSAGE_TYPES",
@@ -113,6 +124,9 @@ __all__ = [
113
124
  "WaldiezAgentNestedChatMessage",
114
125
  "WaldiezAgentTeachability",
115
126
  "WaldiezAgentTerminationMessage",
127
+ "WaldiezCaptainAgent",
128
+ "WaldiezCaptainAgentData",
129
+ "WaldiezCaptainAgentLibEntry",
116
130
  "WaldiezGroupManager",
117
131
  "WaldiezGroupManagerData",
118
132
  "WaldiezGroupManagerSpeakers",
@@ -2,8 +2,9 @@
2
2
  # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
3
  """Base agent class to be inherited by all other agents."""
4
4
 
5
- from .agent import WaldiezAgent, WaldiezAgentType
5
+ from .agent import WaldiezAgent
6
6
  from .agent_data import WaldiezAgentData
7
+ from .agent_type import WaldiezAgentType
7
8
  from .code_execution import WaldiezAgentCodeExecutionConfig
8
9
  from .linked_skill import WaldiezAgentLinkedSkill
9
10
  from .nested_chat import WaldiezAgentNestedChat, WaldiezAgentNestedChatMessage
@@ -2,19 +2,16 @@
2
2
  # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
3
  """Base agent class to be inherited by all agents."""
4
4
 
5
- from typing import List
5
+ from typing import List, Set
6
6
 
7
7
  from pydantic import Field
8
8
  from typing_extensions import Annotated, Literal
9
9
 
10
10
  from ...common import WaldiezBase, now
11
11
  from .agent_data import WaldiezAgentData
12
+ from .agent_type import WaldiezAgentType
12
13
  from .code_execution import WaldiezAgentCodeExecutionConfig
13
14
 
14
- WaldiezAgentType = Literal[
15
- "user", "assistant", "manager", "rag_user", "swarm", "reasoning"
16
- ]
17
-
18
15
 
19
16
  class WaldiezAgent(WaldiezBase):
20
17
  """Waldiez Agent.
@@ -25,9 +22,7 @@ class WaldiezAgent(WaldiezBase):
25
22
  The ID of the agent.
26
23
  type : Literal["agent"]
27
24
  The type of the "node" in a graph: "agent"
28
- agent_type : Literal[
29
- "user", "assistant", "manager", "rag_user", "swarm", "reasoning"
30
- ]
25
+ agent_type : WaldiezAgentType
31
26
  The type of the agent
32
27
  name: str
33
28
  The name of the agent.
@@ -65,9 +60,7 @@ class WaldiezAgent(WaldiezBase):
65
60
  ),
66
61
  ]
67
62
  agent_type: Annotated[
68
- Literal[
69
- "user", "assistant", "manager", "rag_user", "swarm", "reasoning"
70
- ],
63
+ WaldiezAgentType,
71
64
  Field(
72
65
  ...,
73
66
  title="Agent type",
@@ -129,6 +122,73 @@ class WaldiezAgent(WaldiezBase):
129
122
  ),
130
123
  ]
131
124
 
125
+ @property
126
+ def ag2_class(self) -> str:
127
+ """Return the AG2 class of the agent."""
128
+ class_name = "ConversableAgent"
129
+ if self.data.is_multimodal:
130
+ return "MultimodalConversableAgent"
131
+ if self.agent_type == "assistant":
132
+ class_name = "AssistantAgent"
133
+ if self.agent_type == "user":
134
+ class_name = "UserProxyAgent"
135
+ if self.agent_type == "manager":
136
+ class_name = "GroupChatManager"
137
+ if self.agent_type == "rag_user":
138
+ class_name = "RetrieveUserProxyAgent"
139
+ if self.agent_type == "swarm":
140
+ class_name = "SwarmAgent"
141
+ if self.agent_type == "reasoning":
142
+ class_name = "ReasoningAgent"
143
+ if self.agent_type == "captain":
144
+ class_name = "CaptainAgent"
145
+ return class_name
146
+
147
+ @property
148
+ def ag2_imports(self) -> Set[str]:
149
+ """Return the AG2 imports of the agent."""
150
+ agent_class = self.ag2_class
151
+ imports = set(["import autogen"])
152
+ if agent_class == "AssistantAgent":
153
+ imports.add("from autogen import AssistantAgent")
154
+ elif agent_class == "UserProxyAgent":
155
+ imports.add("from autogen import UserProxyAgent")
156
+ elif agent_class == "GroupChatManager":
157
+ imports.add("from autogen import GroupChatManager")
158
+ elif agent_class == "RetrieveUserProxyAgent":
159
+ imports.add(
160
+ "from autogen.agentchat.contrib.retrieve_user_proxy_agent "
161
+ "import RetrieveUserProxyAgent"
162
+ )
163
+ elif agent_class == "MultimodalConversableAgent":
164
+ imports.add(
165
+ "from autogen.agentchat.contrib.multimodal_conversable_agent "
166
+ "import MultimodalConversableAgent"
167
+ )
168
+ elif agent_class == "SwarmAgent":
169
+ imports.add(
170
+ "from autogen import "
171
+ "register_hand_off, "
172
+ "AfterWork, "
173
+ "OnCondition, "
174
+ "UpdateSystemMessage, "
175
+ "AfterWorkOption, "
176
+ "SwarmResult"
177
+ )
178
+ elif agent_class == "ReasoningAgent":
179
+ imports.add(
180
+ "from autogen.agentchat.contrib.reasoning_agent "
181
+ "import ReasoningAgent, visualize_tree"
182
+ )
183
+ elif agent_class == "CaptainAgent":
184
+ imports.add(
185
+ "from autogen.agentchat.contrib.captainagent "
186
+ "import CaptainAgent"
187
+ )
188
+ else: # pragma: no cover
189
+ imports.add("import ConversableAgent")
190
+ return imports
191
+
132
192
  def validate_linked_skills(
133
193
  self, skill_ids: List[str], agent_ids: List[str]
134
194
  ) -> None:
@@ -0,0 +1,11 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Waldiez Agent types."""
4
+
5
+ from typing_extensions import Literal
6
+
7
+ # pylint: disable=line-too-long
8
+ # fmt: off
9
+ WaldiezAgentType = Literal["user", "assistant", "manager", "rag_user", "swarm", "reasoning", "captain"] # noqa: E501
10
+ """Possible types of a Waldiez Agent: user, assistant, manager, rag_user, swarm, reasoning, captain.""" # noqa: E501
11
+ # fmt: on
@@ -10,7 +10,8 @@ from typing_extensions import Annotated, Self
10
10
  from ..common import WaldiezBase
11
11
  from .agent import WaldiezAgent
12
12
  from .assistant import WaldiezAssistant
13
- from .group_manager.group_manager import WaldiezGroupManager
13
+ from .captain_agent import WaldiezCaptainAgent
14
+ from .group_manager import WaldiezGroupManager
14
15
  from .rag_user import WaldiezRagUser
15
16
  from .reasoning import WaldiezReasoningAgent
16
17
  from .swarm_agent import WaldiezSwarmAgent
@@ -80,6 +81,14 @@ class WaldiezAgents(WaldiezBase):
80
81
  default_factory=list,
81
82
  ),
82
83
  ]
84
+ captain_agents: Annotated[
85
+ List[WaldiezCaptainAgent],
86
+ Field(
87
+ title="Captain Agents.",
88
+ description="Captain agents",
89
+ default_factory=list,
90
+ ),
91
+ ]
83
92
 
84
93
  @property
85
94
  def members(self) -> Iterator[WaldiezAgent]:
@@ -96,6 +105,7 @@ class WaldiezAgents(WaldiezBase):
96
105
  yield from self.reasoning_agents
97
106
  yield from self.swarm_agents
98
107
  yield from self.managers
108
+ yield from self.captain_agents
99
109
 
100
110
  @model_validator(mode="after")
101
111
  def validate_agents(self) -> Self:
@@ -0,0 +1,15 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Captain agent model."""
4
+
5
+ from .captain_agent import WaldiezCaptainAgent
6
+ from .captain_agent_data import (
7
+ WaldiezCaptainAgentData,
8
+ )
9
+ from .captain_agent_lib_entry import WaldiezCaptainAgentLibEntry
10
+
11
+ __all__ = [
12
+ "WaldiezCaptainAgentData",
13
+ "WaldiezCaptainAgent",
14
+ "WaldiezCaptainAgentLibEntry",
15
+ ]
@@ -0,0 +1,45 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Waldiez captain agent model."""
4
+
5
+ from typing import Literal
6
+
7
+ from pydantic import Field
8
+ from typing_extensions import Annotated
9
+
10
+ from ..agent import WaldiezAgent
11
+ from .captain_agent_data import WaldiezCaptainAgentData
12
+
13
+
14
+ class WaldiezCaptainAgent(WaldiezAgent):
15
+ """Captain agent.
16
+
17
+ A `WaldiezAgent` with agent_type `captain` and
18
+ captain agent's related config for the agent.
19
+ Also see `WaldiezAgent`, `WaldiezCaptainData`, `WaldiezAgentData`
20
+
21
+ Attributes
22
+ ----------
23
+ agent_type : Literal["captain"]
24
+ The agent type: 'captain' for a captain agent
25
+ data : WaldiezCaptainAgentData
26
+ The captain agent's data.
27
+ """
28
+
29
+ agent_type: Annotated[
30
+ Literal["captain"],
31
+ Field(
32
+ "captain",
33
+ title="Agent type",
34
+ description="The agent type: 'captain' for a captain agent",
35
+ alias="agentType",
36
+ ),
37
+ ]
38
+ data: Annotated[
39
+ WaldiezCaptainAgentData,
40
+ Field(
41
+ title="Data",
42
+ description="The captain agent's data",
43
+ default_factory=WaldiezCaptainAgentData,
44
+ ),
45
+ ]
@@ -0,0 +1,62 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Waldiez captain agent data."""
4
+
5
+ from typing import List, Optional
6
+
7
+ from pydantic import Field
8
+ from typing_extensions import Annotated, Literal
9
+
10
+ from ..agent import WaldiezAgentData
11
+ from .captain_agent_lib_entry import WaldiezCaptainAgentLibEntry
12
+
13
+
14
+ class WaldiezCaptainAgentData(WaldiezAgentData):
15
+ """Captain agent data class.
16
+
17
+ The data for a captain agent.
18
+ Extends `WaldiezAgentData`.
19
+ Extra attributes:
20
+ - `agent_lib`: Optional list of agent lib entries
21
+ - `tool_lib`:
22
+ - `max_round`: The maximum number of rounds in a group chat
23
+ - `max_turns`: The maximum number of turns for a chat
24
+ See the parent's docs (`WaldiezAgentData`) for the rest of the properties.
25
+ """
26
+
27
+ agent_lib: Annotated[
28
+ List[WaldiezCaptainAgentLibEntry],
29
+ Field(
30
+ default_factory=list,
31
+ title="Agent lib",
32
+ description="The agent lib",
33
+ alias="agentLib",
34
+ ),
35
+ ] = []
36
+ tool_lib: Annotated[
37
+ Optional[Literal["default"]],
38
+ Field(
39
+ None,
40
+ title="Tool lib",
41
+ description="Whether to use the default tool lib",
42
+ alias="toolLib",
43
+ ),
44
+ ] = None
45
+ max_round: Annotated[
46
+ int,
47
+ Field(
48
+ 10,
49
+ title="Max round",
50
+ description="The maximum number of rounds in a group chat",
51
+ alias="maxRound",
52
+ ),
53
+ ] = 10
54
+ max_turns: Annotated[
55
+ int,
56
+ Field(
57
+ 5,
58
+ title="Max turns",
59
+ description="The maximum number of turns for a chat",
60
+ alias="maxTurns",
61
+ ),
62
+ ] = 5