waldiez-0.2.2-py3-none-any.whl → waldiez-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of waldiez might be problematic. Click here for more details.

Files changed (138)
  1. waldiez/__init__.py +2 -0
  2. waldiez/__main__.py +2 -0
  3. waldiez/_version.py +3 -1
  4. waldiez/cli.py +13 -3
  5. waldiez/cli_extras.py +4 -3
  6. waldiez/conflict_checker.py +4 -3
  7. waldiez/exporter.py +28 -105
  8. waldiez/exporting/__init__.py +8 -9
  9. waldiez/exporting/agent/__init__.py +7 -0
  10. waldiez/exporting/agent/agent_exporter.py +279 -0
  11. waldiez/exporting/agent/utils/__init__.py +23 -0
  12. waldiez/exporting/agent/utils/agent_class_name.py +34 -0
  13. waldiez/exporting/agent/utils/agent_imports.py +50 -0
  14. waldiez/exporting/{agents → agent/utils}/code_execution.py +9 -11
  15. waldiez/exporting/{agents → agent/utils}/group_manager.py +47 -35
  16. waldiez/exporting/{agents → agent/utils}/rag_user/__init__.py +2 -0
  17. waldiez/exporting/{agents → agent/utils}/rag_user/chroma_utils.py +22 -17
  18. waldiez/exporting/{agents → agent/utils}/rag_user/mongo_utils.py +14 -10
  19. waldiez/exporting/{agents → agent/utils}/rag_user/pgvector_utils.py +12 -8
  20. waldiez/exporting/{agents → agent/utils}/rag_user/qdrant_utils.py +11 -8
  21. waldiez/exporting/{agents → agent/utils}/rag_user/rag_user.py +78 -55
  22. waldiez/exporting/{agents → agent/utils}/rag_user/vector_db.py +10 -8
  23. waldiez/exporting/agent/utils/swarm_agent.py +463 -0
  24. waldiez/exporting/{agents → agent/utils}/teachability.py +10 -6
  25. waldiez/exporting/{agents → agent/utils}/termination_message.py +7 -8
  26. waldiez/exporting/base/__init__.py +25 -0
  27. waldiez/exporting/base/agent_position.py +75 -0
  28. waldiez/exporting/base/base_exporter.py +118 -0
  29. waldiez/exporting/base/export_position.py +48 -0
  30. waldiez/exporting/base/import_position.py +23 -0
  31. waldiez/exporting/base/mixin.py +134 -0
  32. waldiez/exporting/base/utils/__init__.py +18 -0
  33. waldiez/exporting/{utils → base/utils}/comments.py +12 -55
  34. waldiez/exporting/{utils → base/utils}/naming.py +14 -4
  35. waldiez/exporting/base/utils/path_check.py +68 -0
  36. waldiez/exporting/{utils/object_string.py → base/utils/to_string.py} +21 -20
  37. waldiez/exporting/chats/__init__.py +5 -12
  38. waldiez/exporting/chats/chats_exporter.py +240 -0
  39. waldiez/exporting/chats/utils/__init__.py +15 -0
  40. waldiez/exporting/chats/utils/common.py +81 -0
  41. waldiez/exporting/chats/{nested.py → utils/nested.py} +125 -86
  42. waldiez/exporting/chats/utils/sequential.py +244 -0
  43. waldiez/exporting/chats/utils/single_chat.py +313 -0
  44. waldiez/exporting/chats/utils/swarm.py +207 -0
  45. waldiez/exporting/flow/__init__.py +5 -3
  46. waldiez/exporting/flow/flow_exporter.py +503 -0
  47. waldiez/exporting/flow/utils/__init__.py +47 -0
  48. waldiez/exporting/flow/utils/agent_utils.py +204 -0
  49. waldiez/exporting/flow/utils/chat_utils.py +71 -0
  50. waldiez/exporting/flow/utils/def_main.py +62 -0
  51. waldiez/exporting/flow/utils/flow_content.py +112 -0
  52. waldiez/exporting/flow/utils/flow_names.py +115 -0
  53. waldiez/exporting/flow/utils/importing_utils.py +182 -0
  54. waldiez/exporting/{utils → flow/utils}/logging_utils.py +34 -31
  55. waldiez/exporting/models/__init__.py +7 -242
  56. waldiez/exporting/models/models_exporter.py +192 -0
  57. waldiez/exporting/models/utils.py +166 -0
  58. waldiez/exporting/skills/__init__.py +7 -161
  59. waldiez/exporting/skills/skills_exporter.py +169 -0
  60. waldiez/exporting/skills/utils.py +281 -0
  61. waldiez/models/__init__.py +25 -7
  62. waldiez/models/agents/__init__.py +70 -0
  63. waldiez/models/agents/agent/__init__.py +11 -1
  64. waldiez/models/agents/agent/agent.py +9 -4
  65. waldiez/models/agents/agent/agent_data.py +3 -1
  66. waldiez/models/agents/agent/code_execution.py +2 -0
  67. waldiez/models/agents/agent/linked_skill.py +2 -0
  68. waldiez/models/agents/agent/nested_chat.py +2 -0
  69. waldiez/models/agents/agent/teachability.py +2 -0
  70. waldiez/models/agents/agent/termination_message.py +49 -13
  71. waldiez/models/agents/agents.py +15 -3
  72. waldiez/models/agents/assistant/__init__.py +2 -0
  73. waldiez/models/agents/assistant/assistant.py +2 -0
  74. waldiez/models/agents/assistant/assistant_data.py +2 -0
  75. waldiez/models/agents/group_manager/__init__.py +9 -1
  76. waldiez/models/agents/group_manager/group_manager.py +2 -0
  77. waldiez/models/agents/group_manager/group_manager_data.py +2 -0
  78. waldiez/models/agents/group_manager/speakers.py +49 -13
  79. waldiez/models/agents/rag_user/__init__.py +21 -4
  80. waldiez/models/agents/rag_user/rag_user.py +3 -1
  81. waldiez/models/agents/rag_user/rag_user_data.py +2 -0
  82. waldiez/models/agents/rag_user/retrieve_config.py +268 -17
  83. waldiez/models/agents/rag_user/vector_db_config.py +5 -3
  84. waldiez/models/agents/swarm_agent/__init__.py +49 -0
  85. waldiez/models/agents/swarm_agent/after_work.py +178 -0
  86. waldiez/models/agents/swarm_agent/on_condition.py +103 -0
  87. waldiez/models/agents/swarm_agent/on_condition_available.py +140 -0
  88. waldiez/models/agents/swarm_agent/on_condition_target.py +40 -0
  89. waldiez/models/agents/swarm_agent/swarm_agent.py +107 -0
  90. waldiez/models/agents/swarm_agent/swarm_agent_data.py +125 -0
  91. waldiez/models/agents/swarm_agent/update_system_message.py +144 -0
  92. waldiez/models/agents/user_proxy/__init__.py +2 -0
  93. waldiez/models/agents/user_proxy/user_proxy.py +2 -0
  94. waldiez/models/agents/user_proxy/user_proxy_data.py +2 -0
  95. waldiez/models/chat/__init__.py +21 -3
  96. waldiez/models/chat/chat.py +241 -7
  97. waldiez/models/chat/chat_data.py +192 -48
  98. waldiez/models/chat/chat_message.py +153 -144
  99. waldiez/models/chat/chat_nested.py +33 -53
  100. waldiez/models/chat/chat_summary.py +2 -0
  101. waldiez/models/common/__init__.py +6 -6
  102. waldiez/models/common/base.py +4 -1
  103. waldiez/models/common/method_utils.py +163 -83
  104. waldiez/models/flow/__init__.py +2 -0
  105. waldiez/models/flow/flow.py +176 -40
  106. waldiez/models/flow/flow_data.py +63 -2
  107. waldiez/models/flow/utils.py +172 -0
  108. waldiez/models/model/__init__.py +2 -0
  109. waldiez/models/model/model.py +30 -9
  110. waldiez/models/model/model_data.py +3 -1
  111. waldiez/models/skill/__init__.py +4 -1
  112. waldiez/models/skill/skill.py +30 -2
  113. waldiez/models/skill/skill_data.py +2 -0
  114. waldiez/models/waldiez.py +28 -4
  115. waldiez/runner.py +142 -228
  116. waldiez/running/__init__.py +33 -0
  117. waldiez/running/environment.py +83 -0
  118. waldiez/running/gen_seq_diagram.py +185 -0
  119. waldiez/running/running.py +300 -0
  120. {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/METADATA +35 -28
  121. waldiez-0.3.1.dist-info/RECORD +125 -0
  122. waldiez-0.3.1.dist-info/licenses/LICENSE +201 -0
  123. waldiez/exporting/agents/__init__.py +0 -5
  124. waldiez/exporting/agents/agent.py +0 -236
  125. waldiez/exporting/agents/agent_skills.py +0 -67
  126. waldiez/exporting/agents/llm_config.py +0 -53
  127. waldiez/exporting/chats/chats.py +0 -46
  128. waldiez/exporting/chats/helpers.py +0 -420
  129. waldiez/exporting/flow/def_main.py +0 -32
  130. waldiez/exporting/flow/flow.py +0 -189
  131. waldiez/exporting/utils/__init__.py +0 -36
  132. waldiez/exporting/utils/importing.py +0 -265
  133. waldiez/exporting/utils/method_utils.py +0 -35
  134. waldiez/exporting/utils/path_check.py +0 -51
  135. waldiez-0.2.2.dist-info/RECORD +0 -92
  136. waldiez-0.2.2.dist-info/licenses/LICENSE +0 -21
  137. {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/WHEEL +0 -0
  138. {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,50 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Get the imports needed for the agent."""
4
+
5
+ from typing import Set
6
+
7
+
8
+ def get_agent_imports(agent_class: str) -> Set[str]:
9
+ """Get the imports needed for the agent.
10
+
11
+ Parameters
12
+ ----------
13
+ agent_class : str
14
+ The agent class name.
15
+
16
+ Returns
17
+ -------
18
+ Set[str]
19
+ The imports needed for the agent.
20
+ """
21
+ imports = set(["import autogen"])
22
+ if agent_class == "AssistantAgent":
23
+ imports.add("from autogen import AssistantAgent")
24
+ elif agent_class == "UserProxyAgent":
25
+ imports.add("from autogen import UserProxyAgent")
26
+ elif agent_class == "GroupChatManager":
27
+ imports.add("from autogen import GroupChatManager")
28
+ elif agent_class == "RetrieveUserProxyAgent":
29
+ imports.add(
30
+ "from autogen.agentchat.contrib.retrieve_user_proxy_agent "
31
+ "import RetrieveUserProxyAgent"
32
+ )
33
+ elif agent_class == "MultimodalConversableAgent":
34
+ imports.add(
35
+ "from autogen.agentchat.contrib.multimodal_conversable_agent "
36
+ "import MultimodalConversableAgent"
37
+ )
38
+ elif agent_class == "SwarmAgent":
39
+ imports.add(
40
+ "from autogen import "
41
+ "AFTER_WORK, "
42
+ "ON_CONDITION, "
43
+ "UPDATE_SYSTEM_MESSAGE, "
44
+ "AfterWorkOption, "
45
+ "SwarmAgent, "
46
+ "SwarmResult"
47
+ )
48
+ else:
49
+ imports.add("from autogen import ConversableAgent")
50
+ return imports
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """Code execution related functions for exporting agents."""
2
4
 
3
5
  from typing import Dict, Tuple
@@ -36,15 +38,15 @@ def get_agent_code_execution_config(
36
38
  if use_docker
37
39
  else "LocalCommandLineCodeExecutor"
38
40
  )
39
- executor_content = f"{agent_name}_executor = {executor_class_name}(\n"
41
+ executor_content = f"{agent_name}_executor = {executor_class_name}(" + "\n"
40
42
  if agent.data.code_execution_config.work_dir:
41
43
  wok_dir = agent.data.code_execution_config.work_dir.replace(
42
44
  '"', '\\"'
43
45
  ).replace("\n", "\\n")
44
- executor_content += f' work_dir="{wok_dir}",\n'
46
+ executor_content += f' work_dir="{wok_dir}",' + "\n"
45
47
  if agent.data.code_execution_config.timeout:
46
48
  executor_content += (
47
- f" timeout={agent.data.code_execution_config.timeout},\n"
49
+ f" timeout={agent.data.code_execution_config.timeout}," + "\n"
48
50
  )
49
51
  if use_docker is False and agent.data.code_execution_config.functions:
50
52
  function_names = []
@@ -53,15 +55,11 @@ def get_agent_code_execution_config(
53
55
  function_names.append(skill_name)
54
56
  if function_names:
55
57
  # pylint: disable=inconsistent-quotes
58
+ function_names_string = ", ".join(function_names)
56
59
  executor_content += (
57
- f" functions=[{', '.join(function_names)}],\n"
60
+ f" functions=[{function_names_string}]," + "\n"
58
61
  )
59
62
  executor_content += ")\n\n"
60
- # if (
61
- # executor_content
62
- # == f"{agent_name}_executor = {executor_class_name}(\n)\n\n"
63
- # ):
64
- # # empty executor?
65
- # return "", "False", ""
66
63
  executor_arg = f'{{"executor": {agent_name}_executor}}'
67
- return executor_content, executor_arg, executor_class_name
64
+ the_import = f"from autogen.coding import {executor_class_name}"
65
+ return executor_content, executor_arg, the_import
@@ -1,16 +1,17 @@
1
- """Export group manger and group chat to string."""
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Export group manager and group chat to string."""
2
4
 
3
- from typing import Dict, List, Optional, Tuple
5
+ from typing import Callable, Dict, List, Optional, Tuple
4
6
 
5
7
  from waldiez.models import WaldiezAgent, WaldiezGroupManager
6
8
 
7
- from ..utils import get_method_string, get_object_string
8
-
9
9
 
10
10
  def get_group_manager_extras(
11
11
  agent: WaldiezAgent,
12
12
  group_chat_members: List[WaldiezAgent],
13
13
  agent_names: Dict[str, str],
14
+ serializer: Callable[..., str],
14
15
  ) -> Tuple[str, str]:
15
16
  """Get the group manager extra string and custom selection method if any.
16
17
 
@@ -22,6 +23,8 @@ def get_group_manager_extras(
22
23
  The group members.
23
24
  agent_names : Dict[str, str]
24
25
  The agent names.
26
+ serializer : Callable[..., str]
27
+ The serializer function.
25
28
 
26
29
  Returns
27
30
  -------
@@ -34,12 +37,17 @@ def get_group_manager_extras(
34
37
  custom_speaker_selection: Optional[str] = None
35
38
  if agent.agent_type == "manager" and isinstance(agent, WaldiezGroupManager):
36
39
  group_chat_string, group_chat_name, custom_speaker_selection = (
37
- _get_group_manager_extras(agent, group_chat_members, agent_names)
40
+ _get_group_manager_extras(
41
+ agent=agent,
42
+ group_members=group_chat_members,
43
+ agent_names=agent_names,
44
+ serializer=serializer,
45
+ )
38
46
  )
39
47
  if group_chat_name:
40
- group_chat_arg = f"\n groupchat={group_chat_name},"
48
+ group_chat_arg = "\n" + f" groupchat={group_chat_name},"
41
49
  if custom_speaker_selection:
42
- before_agent_string += f"{custom_speaker_selection}\n\n"
50
+ before_agent_string += f"{custom_speaker_selection}" + "\n"
43
51
  if group_chat_string:
44
52
  before_agent_string += group_chat_string
45
53
  return before_agent_string, group_chat_arg
@@ -49,6 +57,7 @@ def _get_group_manager_extras(
49
57
  agent: WaldiezGroupManager,
50
58
  group_members: List[WaldiezAgent],
51
59
  agent_names: Dict[str, str],
60
+ serializer: Callable[..., str],
52
61
  ) -> Tuple[str, str, Optional[str]]:
53
62
  """Get the group manager extra string and custom selection method if any.
54
63
 
@@ -60,6 +69,8 @@ def _get_group_manager_extras(
60
69
  The group members.
61
70
  agent_names : Dict[str, str]
62
71
  The agent names.
72
+ serializer : Callable[..., str]
73
+ The serializer function.
63
74
 
64
75
  Returns
65
76
  -------
@@ -88,25 +99,19 @@ def _get_group_manager_extras(
88
99
  group_chat_string += f" max_round={agent.data.max_round}," + "\n"
89
100
  if agent.data.admin_name:
90
101
  group_chat_string += f' admin_name="{agent.data.admin_name}",' + "\n"
91
- extra_group_chat_string, method_name_and_content = (
92
- _get_group_chat_speakers_string(agent, agent_names)
102
+ extra_group_chat_string, custom_selection_method = (
103
+ _get_group_chat_speakers_string(agent, agent_names, serializer)
93
104
  )
94
- custom_selection_method: Optional[str] = None
95
105
  group_chat_string += extra_group_chat_string
96
106
  group_chat_string += ")\n\n"
97
- if method_name_and_content:
98
- method_name, method_content = method_name_and_content
99
- custom_selection_method = get_method_string(
100
- "custom_speaker_selection",
101
- method_name,
102
- method_content,
103
- )
104
107
  return group_chat_string, group_chat_name, custom_selection_method
105
108
 
106
109
 
107
110
  def _get_group_chat_speakers_string(
108
- agent: WaldiezGroupManager, agent_names: Dict[str, str]
109
- ) -> Tuple[str, Optional[Tuple[str, str]]]:
111
+ agent: WaldiezGroupManager,
112
+ agent_names: Dict[str, str],
113
+ serializer: Callable[..., str],
114
+ ) -> Tuple[str, Optional[str]]:
110
115
  """Get the group chat speakers string.
111
116
 
112
117
  Parameters
@@ -115,16 +120,18 @@ def _get_group_chat_speakers_string(
115
120
  The agent.
116
121
  agent_names : Dict[str, str]
117
122
  The agent names.
123
+ serializer : Callable[..., str]
124
+ The serializer function.
118
125
 
119
126
  Returns
120
127
  -------
121
128
  str
122
129
  The group chat speakers string.
123
- Optional[Tuple[str, str]]
124
- The custom selection method name and content if any.
130
+ Optional[str]
131
+ The custom custom for speaker selection if any.
125
132
  """
126
133
  speakers_string = ""
127
- method_name_and_content: Optional[Tuple[str, str]] = None
134
+ function_content: Optional[str] = None
128
135
  if agent.data.speakers.max_retries_for_selecting is not None:
129
136
  speakers_string += (
130
137
  " max_retries_for_selecting_speaker="
@@ -139,12 +146,14 @@ def _get_group_chat_speakers_string(
139
146
  )
140
147
  else:
141
148
  agent_name = agent_names[agent.id]
142
- method_name = f"custom_speaker_selection_method_{agent_name}"
143
- method_name_and_content = (
144
- method_name,
145
- agent.data.speakers.custom_method_string or "",
149
+ function_content, function_name = (
150
+ agent.data.speakers.get_custom_method_function(
151
+ name_suffix=agent_name
152
+ )
153
+ )
154
+ speakers_string += (
155
+ f" speaker_selection_method={function_name}," + "\n"
146
156
  )
147
- speakers_string += f" speaker_selection_method={method_name}," "\n"
148
157
  # selection_mode == "repeat":
149
158
  if agent.data.speakers.selection_mode == "repeat":
150
159
  speakers_string += _get_speakers_selection_repeat_string(
@@ -156,10 +165,12 @@ def _get_group_chat_speakers_string(
156
165
  and agent.data.speakers.allowed_or_disallowed_transitions
157
166
  ):
158
167
  speakers_string += _get_speakers_selection_transition_string(
159
- agent, agent_names
168
+ agent=agent,
169
+ agent_names=agent_names,
170
+ serializer=serializer,
160
171
  )
161
172
  speakers_string = speakers_string.replace('"None"', "None")
162
- return speakers_string, method_name_and_content
173
+ return speakers_string, function_content
163
174
 
164
175
 
165
176
  def _get_speakers_selection_repeat_string(
@@ -168,9 +179,8 @@ def _get_speakers_selection_repeat_string(
168
179
  speakers_string = ""
169
180
  if isinstance(agent.data.speakers.allow_repeat, bool):
170
181
  speakers_string += (
171
- " allow_repeat_speaker="
172
- f"{agent.data.speakers.allow_repeat},"
173
- "\n"
182
+ f" allow_repeat_speaker={agent.data.speakers.allow_repeat},"
183
+ + "\n"
174
184
  )
175
185
  elif isinstance(agent.data.speakers.allow_repeat, list):
176
186
  # get the names of the agents
@@ -178,12 +188,14 @@ def _get_speakers_selection_repeat_string(
178
188
  agent_names[agent_id]
179
189
  for agent_id in agent.data.speakers.allow_repeat
180
190
  )
181
- speakers_string += f" allow_repeat=[{allow_repeat}]," "\n"
191
+ speakers_string += f" allow_repeat=[{allow_repeat}]," + "\n"
182
192
  return speakers_string
183
193
 
184
194
 
185
195
  def _get_speakers_selection_transition_string(
186
- agent: WaldiezGroupManager, agent_names: Dict[str, str]
196
+ agent: WaldiezGroupManager,
197
+ agent_names: Dict[str, str],
198
+ serializer: Callable[..., str],
187
199
  ) -> str:
188
200
  speakers_string = ""
189
201
  allowed_or_disallowed_speaker_transitions = {}
@@ -194,7 +206,7 @@ def _get_speakers_selection_transition_string(
194
206
  allowed_or_disallowed_speaker_transitions[agent_names[agent_id]] = [
195
207
  agent_names[transition] for transition in transitions
196
208
  ]
197
- transitions_string = get_object_string(
209
+ transitions_string = serializer(
198
210
  allowed_or_disallowed_speaker_transitions, 1
199
211
  )
200
212
  transitions_string = transitions_string.replace('"', "").replace("'", "")
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """RAG User Agent related string generation."""
2
4
 
3
5
  from .rag_user import get_rag_user_extras, get_rag_user_retrieve_config_str
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """Get chroma db related imports and content."""
2
4
 
3
5
  from pathlib import Path
@@ -22,19 +24,21 @@ def _get_chroma_client_string(agent: WaldiezRagUser) -> Tuple[str, str]:
22
24
  The 'client' and what to import.
23
25
  """
24
26
  # TODO: also check `connection_url` (chroma in client-server mode)
25
- to_import = "chromadb"
27
+ to_import = "import chromadb"
26
28
  client_str = "chromadb."
27
29
  if (
28
30
  agent.retrieve_config.db_config.use_local_storage
29
31
  and agent.retrieve_config.db_config.local_storage_path is not None
30
32
  ):
31
- # on windows, we get:
33
+ # on windows, we might get:
32
34
  # SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes
33
35
  # in position 2-3: truncated \UXXXXXXXX escape
34
36
  local_path = Path(agent.retrieve_config.db_config.local_storage_path)
35
37
  client_str += (
36
- f'PersistentClient(path=r"{local_path}", '
37
- "settings=Settings(anonymized_telemetry=False))"
38
+ "PersistentClient(\n"
39
+ f' path=r"{local_path}",' + "\n"
40
+ " settings=Settings(anonymized_telemetry=False),\n"
41
+ ")"
38
42
  )
39
43
  else:
40
44
  client_str += "Client(Settings(anonymized_telemetry=False))"
@@ -60,7 +64,7 @@ def _get_chroma_embedding_function_string(
60
64
  """
61
65
  to_import = ""
62
66
  embedding_function_arg = ""
63
- embedding_function_body = ""
67
+ embedding_function_content = ""
64
68
  vector_db_model = agent.retrieve_config.db_config.model
65
69
  if not agent.retrieve_config.use_custom_embedding:
66
70
  to_import = (
@@ -70,13 +74,13 @@ def _get_chroma_embedding_function_string(
70
74
  embedding_function_arg = "SentenceTransformerEmbeddingFunction("
71
75
  embedding_function_arg += f'model_name="{vector_db_model}")'
72
76
  else:
73
- embedding_function_arg = f"custom_embedding_function_{agent_name}"
74
- embedding_function_body = (
75
- f"\ndef custom_embedding_function_{agent_name}():\n"
76
- f"{agent.retrieve_config.embedding_function_string}\n"
77
+ embedding_function_content, embedding_function_arg = (
78
+ agent.retrieve_config.get_custom_embedding_function(
79
+ name_suffix=agent_name
80
+ )
77
81
  )
78
-
79
- return embedding_function_arg, to_import, embedding_function_body
82
+ embedding_function_content = "\n" + embedding_function_content
83
+ return embedding_function_arg, to_import, embedding_function_content
80
84
 
81
85
 
82
86
  def get_chroma_db_args(
@@ -108,8 +112,8 @@ def get_chroma_db_args(
108
112
  if to_import_embedding:
109
113
  to_import.add(to_import_embedding)
110
114
  kwarg_string = (
111
- f" client={agent_name}_client,\n"
112
- f" embedding_function={embedding_function_arg},\n"
115
+ f" client={agent_name}_client," + "\n"
116
+ f" embedding_function={embedding_function_arg}," + "\n"
113
117
  )
114
118
  # The RAG example:
115
119
  # https://ag2ai.github.io/ag2/docs/notebooks/agentchat_groupchat_RAG/
@@ -118,21 +122,22 @@ def get_chroma_db_args(
118
122
  # https://github.com/microsoft/autogen/issues/3551#issuecomment-2366930994
119
123
  # manually initializing the collection before running the flow,
120
124
  # might be a workaround.
121
- content_before = f"{agent_name}_client = {client_str}\n"
125
+ content_before = f"{agent_name}_client = {client_str}" + "\n"
122
126
  collection_name = agent.retrieve_config.collection_name
123
127
  get_or_create = agent.retrieve_config.get_or_create
124
128
  if collection_name:
125
129
  if get_or_create:
126
130
  content_before += (
127
131
  f"{agent_name}_client.get_or_create_collection("
128
- f'"{collection_name}")\n'
132
+ f'"{collection_name}")' + "\n"
129
133
  )
130
134
  else:
131
135
  content_before += (
132
136
  "try:\n"
133
- f' {agent_name}_client.get_collection("{collection_name}")\n'
137
+ f' {agent_name}_client.get_collection("{collection_name}")'
138
+ + "\n"
134
139
  "except ValueError:\n"
135
140
  f" {agent_name}_client.create_collection("
136
- f'"{collection_name}")\n'
141
+ f'"{collection_name}")' + "\n"
137
142
  )
138
143
  return kwarg_string, to_import, embedding_function_body, content_before
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """Get mongodb related content and imports."""
2
4
 
3
5
  from typing import Set, Tuple
@@ -24,7 +26,7 @@ def _get_mongodb_embedding_function_string(
24
26
  """
25
27
  to_import = ""
26
28
  embedding_function_arg = ""
27
- embedding_function_body = ""
29
+ embedding_function_content = ""
28
30
  if not agent.retrieve_config.use_custom_embedding:
29
31
  to_import = "from sentence_transformers import SentenceTransformer"
30
32
  embedding_function_arg = (
@@ -33,12 +35,13 @@ def _get_mongodb_embedding_function_string(
33
35
  ").encode"
34
36
  )
35
37
  else:
36
- embedding_function_arg = f"custom_embedding_function_{agent_name}"
37
- embedding_function_body = (
38
- f"\ndef custom_embedding_function_{agent_name}():\n"
39
- f"{agent.retrieve_config.embedding_function_string}\n"
38
+ embedding_function_content, embedding_function_arg = (
39
+ agent.retrieve_config.get_custom_embedding_function(
40
+ name_suffix=agent_name
41
+ )
40
42
  )
41
- return embedding_function_arg, to_import, embedding_function_body
43
+ embedding_function_content = "\n" + embedding_function_content
44
+ return embedding_function_arg, to_import, embedding_function_content
42
45
 
43
46
 
44
47
  def get_mongodb_db_args(
@@ -67,17 +70,18 @@ def get_mongodb_db_args(
67
70
  tab = " " * 12
68
71
  db_config = agent.retrieve_config.db_config
69
72
  kwarg_string = (
70
- f'{tab}connection_string="{db_config.connection_url}",\n'
71
- f"{tab}embedding_function={embedding_function_arg},\n"
73
+ f'{tab}connection_string="{db_config.connection_url}",' + "\n"
74
+ f"{tab}embedding_function={embedding_function_arg}," + "\n"
72
75
  )
73
76
  wait_until_document_ready = db_config.wait_until_document_ready
74
77
  wait_until_index_ready = db_config.wait_until_index_ready
75
78
  if wait_until_document_ready is not None:
76
79
  kwarg_string += (
77
- f"{tab}wait_until_document_ready={wait_until_document_ready},\n"
80
+ f"{tab}wait_until_document_ready={wait_until_document_ready},"
81
+ + "\n"
78
82
  )
79
83
  if wait_until_index_ready is not None:
80
84
  kwarg_string += (
81
- f"{tab}wait_until_index_ready={wait_until_index_ready},\n"
85
+ f"{tab}wait_until_index_ready={wait_until_index_ready}," + "\n"
82
86
  )
83
87
  return kwarg_string, to_import, embedding_function_body
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """Get pgvector related content and imports."""
2
4
 
3
5
  from typing import Set, Tuple
@@ -18,7 +20,7 @@ def _get_pgvector_client_string(agent: WaldiezRagUser) -> Tuple[str, str]:
18
20
  Tuple[str, str]
19
21
  The 'client' and what to import.
20
22
  """
21
- to_import = "psycopg"
23
+ to_import = "import psycopg"
22
24
  client_str = "psycopg."
23
25
  connection_url = agent.retrieve_config.db_config.connection_url
24
26
  client_str += f'connect("{connection_url}")'
@@ -44,20 +46,22 @@ def _get_pgvector_embedding_function_string(
44
46
  """
45
47
  to_import = ""
46
48
  embedding_function_arg = ""
47
- embedding_function_body = ""
49
+ embedding_function_content = ""
48
50
  if agent.retrieve_config.use_custom_embedding:
49
51
  embedding_function_arg = f"custom_embedding_function_{agent_name}"
50
- embedding_function_body = (
51
- f"\ndef custom_embedding_function_{agent_name}():\n"
52
- f"{agent.retrieve_config.embedding_function_string}\n"
52
+ embedding_function_content, embedding_function_arg = (
53
+ agent.retrieve_config.get_custom_embedding_function(
54
+ name_suffix=agent_name
55
+ )
53
56
  )
57
+ embedding_function_content = "\n" + embedding_function_content
54
58
  else:
55
59
  to_import = "from sentence_transformers import SentenceTransformer"
56
60
  embedding_function_arg = "SentenceTransformer("
57
61
  embedding_function_arg += (
58
62
  f'"{agent.retrieve_config.db_config.model}").encode'
59
63
  )
60
- return embedding_function_arg, to_import, embedding_function_body
64
+ return embedding_function_arg, to_import, embedding_function_content
61
65
 
62
66
 
63
67
  def get_pgvector_db_args(
@@ -87,7 +91,7 @@ def get_pgvector_db_args(
87
91
  else {to_import_client}
88
92
  )
89
93
  kwarg_str = (
90
- f" client={client_str},\n"
91
- f" embedding_function={embedding_function_arg},\n"
94
+ f" client={client_str}," + "\n"
95
+ f" embedding_function={embedding_function_arg}," + "\n"
92
96
  )
93
97
  return kwarg_str, to_import, embedding_function_body
@@ -1,3 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
1
3
  """Get qdrant db related imports and content."""
2
4
 
3
5
  from pathlib import Path
@@ -61,7 +63,7 @@ def _get_qdrant_embedding_function_string(
61
63
  """
62
64
  to_import = ""
63
65
  embedding_function_arg = ""
64
- embedding_function_body = ""
66
+ embedding_function_content = ""
65
67
  vector_db_model = agent.retrieve_config.db_config.model
66
68
  if not agent.retrieve_config.use_custom_embedding:
67
69
  to_import = (
@@ -71,12 +73,13 @@ def _get_qdrant_embedding_function_string(
71
73
  embedding_function_arg = "FastEmbedEmbeddingFunction("
72
74
  embedding_function_arg += f'model_name="{vector_db_model}")'
73
75
  else:
74
- embedding_function_arg = f"custom_embedding_function_{agent_name}"
75
- embedding_function_body = (
76
- f"\ndef custom_embedding_function_{agent_name}():\n"
77
- f"{agent.retrieve_config.embedding_function_string}\n"
76
+ embedding_function_content, embedding_function_arg = (
77
+ agent.retrieve_config.get_custom_embedding_function(
78
+ name_suffix=agent_name
79
+ )
78
80
  )
79
- return embedding_function_arg, to_import, embedding_function_body
81
+ embedding_function_content = "\n" + embedding_function_content
82
+ return embedding_function_arg, to_import, embedding_function_content
80
83
 
81
84
 
82
85
  def get_qdrant_db_args(
@@ -106,7 +109,7 @@ def get_qdrant_db_args(
106
109
  else {to_import_client}
107
110
  )
108
111
  kwarg_string = (
109
- f" client={client_str},\n"
110
- f" embedding_function={embedding_function_arg},\n"
112
+ f" client={client_str}," + "\n"
113
+ f" embedding_function={embedding_function_arg}," + "\n"
111
114
  )
112
115
  return kwarg_string, to_import, embedding_function_body