waldiez 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of waldiez might be problematic. Click here for more details.

Files changed (94)
  1. waldiez/__init__.py +15 -0
  2. waldiez/__main__.py +6 -0
  3. waldiez/_version.py +3 -0
  4. waldiez/cli.py +162 -0
  5. waldiez/exporter.py +293 -0
  6. waldiez/exporting/__init__.py +14 -0
  7. waldiez/exporting/agents/__init__.py +5 -0
  8. waldiez/exporting/agents/agent.py +229 -0
  9. waldiez/exporting/agents/agent_skills.py +67 -0
  10. waldiez/exporting/agents/code_execution.py +67 -0
  11. waldiez/exporting/agents/group_manager.py +209 -0
  12. waldiez/exporting/agents/llm_config.py +53 -0
  13. waldiez/exporting/agents/rag_user/__init__.py +5 -0
  14. waldiez/exporting/agents/rag_user/chroma_utils.py +134 -0
  15. waldiez/exporting/agents/rag_user/mongo_utils.py +83 -0
  16. waldiez/exporting/agents/rag_user/pgvector_utils.py +93 -0
  17. waldiez/exporting/agents/rag_user/qdrant_utils.py +112 -0
  18. waldiez/exporting/agents/rag_user/rag_user.py +165 -0
  19. waldiez/exporting/agents/rag_user/vector_db.py +119 -0
  20. waldiez/exporting/agents/teachability.py +37 -0
  21. waldiez/exporting/agents/termination_message.py +45 -0
  22. waldiez/exporting/chats/__init__.py +14 -0
  23. waldiez/exporting/chats/chats.py +46 -0
  24. waldiez/exporting/chats/helpers.py +395 -0
  25. waldiez/exporting/chats/nested.py +264 -0
  26. waldiez/exporting/flow/__init__.py +5 -0
  27. waldiez/exporting/flow/def_main.py +37 -0
  28. waldiez/exporting/flow/flow.py +185 -0
  29. waldiez/exporting/models/__init__.py +193 -0
  30. waldiez/exporting/skills/__init__.py +128 -0
  31. waldiez/exporting/utils/__init__.py +34 -0
  32. waldiez/exporting/utils/comments.py +136 -0
  33. waldiez/exporting/utils/importing.py +267 -0
  34. waldiez/exporting/utils/logging_utils.py +203 -0
  35. waldiez/exporting/utils/method_utils.py +35 -0
  36. waldiez/exporting/utils/naming.py +127 -0
  37. waldiez/exporting/utils/object_string.py +81 -0
  38. waldiez/io_stream.py +181 -0
  39. waldiez/models/__init__.py +107 -0
  40. waldiez/models/agents/__init__.py +65 -0
  41. waldiez/models/agents/agent/__init__.py +21 -0
  42. waldiez/models/agents/agent/agent.py +190 -0
  43. waldiez/models/agents/agent/agent_data.py +162 -0
  44. waldiez/models/agents/agent/code_execution.py +71 -0
  45. waldiez/models/agents/agent/linked_skill.py +30 -0
  46. waldiez/models/agents/agent/nested_chat.py +73 -0
  47. waldiez/models/agents/agent/teachability.py +68 -0
  48. waldiez/models/agents/agent/termination_message.py +167 -0
  49. waldiez/models/agents/agents.py +129 -0
  50. waldiez/models/agents/assistant/__init__.py +6 -0
  51. waldiez/models/agents/assistant/assistant.py +41 -0
  52. waldiez/models/agents/assistant/assistant_data.py +29 -0
  53. waldiez/models/agents/group_manager/__init__.py +19 -0
  54. waldiez/models/agents/group_manager/group_manager.py +87 -0
  55. waldiez/models/agents/group_manager/group_manager_data.py +91 -0
  56. waldiez/models/agents/group_manager/speakers.py +211 -0
  57. waldiez/models/agents/rag_user/__init__.py +26 -0
  58. waldiez/models/agents/rag_user/rag_user.py +58 -0
  59. waldiez/models/agents/rag_user/rag_user_data.py +32 -0
  60. waldiez/models/agents/rag_user/retrieve_config.py +592 -0
  61. waldiez/models/agents/rag_user/vector_db_config.py +162 -0
  62. waldiez/models/agents/user_proxy/__init__.py +6 -0
  63. waldiez/models/agents/user_proxy/user_proxy.py +41 -0
  64. waldiez/models/agents/user_proxy/user_proxy_data.py +30 -0
  65. waldiez/models/chat/__init__.py +22 -0
  66. waldiez/models/chat/chat.py +129 -0
  67. waldiez/models/chat/chat_data.py +326 -0
  68. waldiez/models/chat/chat_message.py +304 -0
  69. waldiez/models/chat/chat_nested.py +160 -0
  70. waldiez/models/chat/chat_summary.py +110 -0
  71. waldiez/models/common/__init__.py +38 -0
  72. waldiez/models/common/base.py +63 -0
  73. waldiez/models/common/method_utils.py +165 -0
  74. waldiez/models/flow/__init__.py +9 -0
  75. waldiez/models/flow/flow.py +302 -0
  76. waldiez/models/flow/flow_data.py +87 -0
  77. waldiez/models/model/__init__.py +11 -0
  78. waldiez/models/model/model.py +169 -0
  79. waldiez/models/model/model_data.py +86 -0
  80. waldiez/models/skill/__init__.py +9 -0
  81. waldiez/models/skill/skill.py +129 -0
  82. waldiez/models/skill/skill_data.py +37 -0
  83. waldiez/models/waldiez.py +301 -0
  84. waldiez/py.typed +0 -0
  85. waldiez/runner.py +304 -0
  86. waldiez/stream/__init__.py +7 -0
  87. waldiez/stream/consumer.py +139 -0
  88. waldiez/stream/provider.py +339 -0
  89. waldiez/stream/server.py +412 -0
  90. waldiez-0.1.0.dist-info/METADATA +181 -0
  91. waldiez-0.1.0.dist-info/RECORD +94 -0
  92. waldiez-0.1.0.dist-info/WHEEL +4 -0
  93. waldiez-0.1.0.dist-info/entry_points.txt +2 -0
  94. waldiez-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,134 @@
1
+ """Get chroma db related imports and content."""
2
+
3
+ from pathlib import Path
4
+ from typing import Set, Tuple
5
+
6
+ from waldiez.models import WaldiezRagUser
7
+
8
+
9
+ def _get_chroma_client_string(agent: WaldiezRagUser) -> Tuple[str, str]:
10
+ """Get the ChromaVectorDB client string.
11
+
12
+ Parameters
13
+ ----------
14
+ agent : WaldiezRagUser
15
+ The agent.
16
+ agent_name : str
17
+ The agent's name.
18
+
19
+ Returns
20
+ -------
21
+ Tuple[str, str]
22
+ The 'client' and what to import.
23
+ """
24
+ # TODO: also check `connection_url` (chroma in client-server mode)
25
+ to_import = "chromadb"
26
+ client_str = "chromadb."
27
+ if (
28
+ agent.retrieve_config.db_config.use_local_storage
29
+ and agent.retrieve_config.db_config.local_storage_path is not None
30
+ ):
31
+ local_path = Path(agent.retrieve_config.db_config.local_storage_path)
32
+ client_str += f'PersistentClient(path="{local_path}")'
33
+ else:
34
+ client_str += "Client()"
35
+ return client_str, to_import
36
+
37
+
38
+ def _get_chroma_embedding_function_string(
39
+ agent: WaldiezRagUser, agent_name: str
40
+ ) -> Tuple[str, str, str]:
41
+ """Get the ChromaVectorDB embedding function string.
42
+
43
+ Parameters
44
+ ----------
45
+ agent : WaldiezRagUser
46
+ The agent.
47
+ agent_name : str
48
+ The agent's name.
49
+
50
+ Returns
51
+ -------
52
+ Tuple[str, str, str]
53
+ The 'embedding_function', the import and the custom embedding function.
54
+ """
55
+ to_import = ""
56
+ embedding_function_arg = ""
57
+ embedding_function_body = ""
58
+ vector_db_model = agent.retrieve_config.db_config.model
59
+ if not agent.retrieve_config.use_custom_embedding:
60
+ to_import = (
61
+ "from chromadb.utils.embedding_functions "
62
+ "import SentenceTransformerEmbeddingFunction"
63
+ )
64
+ embedding_function_arg = "SentenceTransformerEmbeddingFunction("
65
+ embedding_function_arg += f'model_name="{vector_db_model}")'
66
+ else:
67
+ embedding_function_arg = f"custom_embedding_function_{agent_name}"
68
+ embedding_function_body = (
69
+ f"\ndef custom_embedding_function_{agent_name}():\n"
70
+ f"{agent.retrieve_config.embedding_function_string}\n"
71
+ )
72
+
73
+ return embedding_function_arg, to_import, embedding_function_body
74
+
75
+
76
def get_chroma_db_args(
    agent: WaldiezRagUser, agent_name: str
) -> Tuple[str, Set[str], str, str]:
    """Get the kwargs to use for ChromaVectorDB.

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    agent_name : str
        The agent's name.

    Returns
    -------
    Tuple[str, Set[str], str, str]
        - The 'kwargs' string.
        - What to import.
        - The custom embedding function.
        - Any additional content to be used before the `kwargs` string.
    """
    client_str, to_import_client = _get_chroma_client_string(agent)
    embedding_function_arg, to_import_embedding, embedding_function_body = (
        _get_chroma_embedding_function_string(agent, agent_name)
    )
    to_import = {to_import_client}
    if to_import_embedding:
        to_import.add(to_import_embedding)
    kwarg_string = (
        f" client={client_str},\n"
        f" embedding_function={embedding_function_arg},\n"
    )
    # The RAG example:
    # https://microsoft.github.io/autogen/docs/\
    # notebooks/agentchat_groupchat_RAG
    # raises `InvalidCollectionException`: Collection groupchat does not exist.
    # https://github.com/chroma-core/chroma/issues/861
    # https://github.com/microsoft/autogen/issues/3551#issuecomment-2366930994
    # manually initializing the collection before running the flow,
    # might be a workaround.
    content_before = ""
    collection_name = agent.retrieve_config.collection_name
    get_or_create = agent.retrieve_config.get_or_create
    if collection_name:
        content_before = f"{agent_name}_client = {client_str}\n"
        if get_or_create:
            content_before += (
                f"{agent_name}_client.get_or_create_collection("
                f'"{collection_name}")\n'
            )
        else:
            # A missing collection is reported as `InvalidCollectionException`
            # by the chromadb versions referenced above (older releases used
            # `ValueError`); it is not guaranteed to be a `ValueError`, so the
            # generated code catches `Exception` to fall back to creating it.
            content_before += (
                "try:\n"
                f' {agent_name}_client.get_collection("{collection_name}")\n'
                "except Exception:\n"
                f" {agent_name}_client.create_collection("
                f'"{collection_name}")\n'
            )
    return kwarg_string, to_import, embedding_function_body, content_before
@@ -0,0 +1,83 @@
1
+ """Get mongodb related content and imports."""
2
+
3
+ from typing import Set, Tuple
4
+
5
+ from waldiez.models import WaldiezRagUser
6
+
7
+
8
+ def _get_mongodb_embedding_function_string(
9
+ agent: WaldiezRagUser, agent_name: str
10
+ ) -> Tuple[str, str, str]:
11
+ """Get the MongoDBAtlasVectorDB embedding function string.
12
+
13
+ Parameters
14
+ ----------
15
+ agent : WaldiezRagUser
16
+ The agent.
17
+ agent_name : str
18
+ The agent's name.
19
+
20
+ Returns
21
+ -------
22
+ Tuple[str, str, str]
23
+ The 'embedding_function', the import and the custom_embedding_function.
24
+ """
25
+ to_import = ""
26
+ embedding_function_arg = ""
27
+ embedding_function_body = ""
28
+ if not agent.retrieve_config.use_custom_embedding:
29
+ to_import = "from sentence_transformers import SentenceTransformer"
30
+ embedding_function_arg = (
31
+ "SentenceTransformer("
32
+ f'"{agent.retrieve_config.db_config.model}"'
33
+ ").encode"
34
+ )
35
+ else:
36
+ embedding_function_arg = f"custom_embedding_function_{agent_name}"
37
+ embedding_function_body = (
38
+ f"\ndef custom_embedding_function_{agent_name}():\n"
39
+ f"{agent.retrieve_config.embedding_function_string}\n"
40
+ )
41
+ return embedding_function_arg, to_import, embedding_function_body
42
+
43
+
44
def get_mongodb_db_args(
    agent: WaldiezRagUser, agent_name: str
) -> Tuple[str, Set[str], str]:
    """Assemble the MongoDBAtlasVectorDB keyword arguments.

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    agent_name : str
        The agent's name.

    Returns
    -------
    Tuple[str, Set[str], str]
        The kwargs string (one indented `name=value,` line per argument),
        the required imports and the custom embedding function definition.
    """
    embedding_arg, embedding_import, embedding_body = (
        _get_mongodb_embedding_function_string(agent, agent_name)
    )
    imports: Set[str] = {embedding_import} if embedding_import else set()
    indent = " " * 12
    db_config = agent.retrieve_config.db_config
    kwarg_lines = [
        f'{indent}connection_string="{db_config.connection_url}",\n',
        f"{indent}embedding_function={embedding_arg},\n",
    ]
    # Optional waits are only emitted when explicitly configured.
    for option in ("wait_until_document_ready", "wait_until_index_ready"):
        option_value = getattr(db_config, option)
        if option_value is not None:
            kwarg_lines.append(f"{indent}{option}={option_value},\n")
    return "".join(kwarg_lines), imports, embedding_body
@@ -0,0 +1,93 @@
1
+ """Get pgvector related content and imports."""
2
+
3
+ from typing import Set, Tuple
4
+
5
+ from waldiez.models import WaldiezRagUser
6
+
7
+
8
+ def _get_pgvector_client_string(agent: WaldiezRagUser) -> Tuple[str, str]:
9
+ """Get the PGVectorDB client string.
10
+
11
+ Parameters
12
+ ----------
13
+ agent : WaldiezRagUser
14
+ The agent.
15
+
16
+ Returns
17
+ -------
18
+ Tuple[str, str]
19
+ The 'client' and what to import.
20
+ """
21
+ to_import = "psycopg"
22
+ client_str = "psycopg."
23
+ connection_url = agent.retrieve_config.db_config.connection_url
24
+ client_str += f'connect("{connection_url}")'
25
+ return client_str, to_import
26
+
27
+
28
+ def _get_pgvector_embedding_function_string(
29
+ agent: WaldiezRagUser, agent_name: str
30
+ ) -> Tuple[str, str, str]:
31
+ """Get the PGVectorDB embedding function string.
32
+
33
+ Parameters
34
+ ----------
35
+ agent : WaldiezRagUser
36
+ The agent.
37
+ agent_name : str
38
+ The agent's name.
39
+
40
+ Returns
41
+ -------
42
+ Tuple[str, str, str]
43
+ The 'embedding_function', the import and the custom_embedding_function.
44
+ """
45
+ to_import = ""
46
+ embedding_function_arg = ""
47
+ embedding_function_body = ""
48
+ if agent.retrieve_config.use_custom_embedding:
49
+ embedding_function_arg = f"custom_embedding_function_{agent_name}"
50
+ embedding_function_body = (
51
+ f"\ndef custom_embedding_function_{agent_name}():\n"
52
+ f"{agent.retrieve_config.embedding_function_string}\n"
53
+ )
54
+ else:
55
+ to_import = "from sentence_transformers import SentenceTransformer"
56
+ embedding_function_arg = "SentenceTransformer("
57
+ embedding_function_arg += (
58
+ f'"{agent.retrieve_config.db_config.model}").encode'
59
+ )
60
+ return embedding_function_arg, to_import, embedding_function_body
61
+
62
+
63
def get_pgvector_db_args(
    agent: WaldiezRagUser, agent_name: str
) -> Tuple[str, Set[str], str]:
    """Assemble the PGVectorDB keyword arguments.

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    agent_name : str
        The agent's name.

    Returns
    -------
    Tuple[str, Set[str], str]
        The kwargs string, the required imports and the custom embedding
        function definition (empty when the default embedding is used).
    """
    client_arg, client_import = _get_pgvector_client_string(agent)
    embedding_arg, embedding_import, embedding_body = (
        _get_pgvector_embedding_function_string(agent, agent_name)
    )
    imports = {client_import}
    if embedding_import:
        imports.add(embedding_import)
    kwarg_str = "".join(
        (
            f" client={client_arg},\n",
            f" embedding_function={embedding_arg},\n",
        )
    )
    return kwarg_str, imports, embedding_body
@@ -0,0 +1,112 @@
1
+ """Get qdrant db related imports and content."""
2
+
3
+ from pathlib import Path
4
+ from typing import Set, Tuple
5
+
6
+ from waldiez.models import WaldiezRagUser
7
+
8
+
9
+ def _get_qdrant_client_string(agent: WaldiezRagUser) -> Tuple[str, str]:
10
+ """Get the QdrantVectorDB client string.
11
+
12
+ Parameters
13
+ ----------
14
+ agent : WaldiezRagUser
15
+ The agent.
16
+ agent_name : str
17
+ The agent's name.
18
+
19
+ Returns
20
+ -------
21
+ Tuple[str, str, str]
22
+ The 'client' argument, and the module to import.
23
+ """
24
+ to_import: str = "from qdrant_client import QdrantClient"
25
+ client_str = "QdrantClient("
26
+ if agent.retrieve_config.db_config.use_memory:
27
+ client_str += 'location=":memory:")'
28
+ elif (
29
+ agent.retrieve_config.db_config.use_local_storage
30
+ and agent.retrieve_config.db_config.local_storage_path
31
+ ):
32
+ local_path = Path(agent.retrieve_config.db_config.local_storage_path)
33
+ client_str += f'location="{local_path}")'
34
+ elif agent.retrieve_config.db_config.connection_url:
35
+ client_str += (
36
+ f'location="{agent.retrieve_config.db_config.connection_url}")'
37
+ )
38
+ else:
39
+ # fallback to memory
40
+ client_str += 'location=":memory:")'
41
+ return client_str, to_import
42
+
43
+
44
+ def _get_qdrant_embedding_function_string(
45
+ agent: WaldiezRagUser, agent_name: str
46
+ ) -> Tuple[str, str, str]:
47
+ """Get the QdrantVectorDB embedding function string.
48
+
49
+ Parameters
50
+ ----------
51
+ agent : WaldiezRagUser
52
+ The agent.
53
+ agent_name : str
54
+ The agent's name.
55
+
56
+ Returns
57
+ -------
58
+ Tuple[str, str, str]
59
+ The 'embedding_function', the module to import
60
+ and the custom_embedding_function if used.
61
+ """
62
+ to_import = ""
63
+ embedding_function_arg = ""
64
+ embedding_function_body = ""
65
+ vector_db_model = agent.retrieve_config.db_config.model
66
+ if not agent.retrieve_config.use_custom_embedding:
67
+ to_import = (
68
+ "from autogen.agentchat.contrib.vectordb.qdrant "
69
+ "import FastEmbedEmbeddingFunction"
70
+ )
71
+ embedding_function_arg = "FastEmbedEmbeddingFunction("
72
+ embedding_function_arg += f'model_name="{vector_db_model}")'
73
+ else:
74
+ embedding_function_arg = f"custom_embedding_function_{agent_name}"
75
+ embedding_function_body = (
76
+ f"\ndef custom_embedding_function_{agent_name}():\n"
77
+ f"{agent.retrieve_config.embedding_function_string}\n"
78
+ )
79
+ return embedding_function_arg, to_import, embedding_function_body
80
+
81
+
82
def get_qdrant_db_args(
    agent: WaldiezRagUser, agent_name: str
) -> Tuple[str, Set[str], str]:
    """Assemble the QdrantVectorDB keyword arguments.

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    agent_name : str
        The agent's name.

    Returns
    -------
    Tuple[str, Set[str], str]
        The kwargs string, the required imports and the custom embedding
        function definition (empty when the default embedding is used).
    """
    client_arg, client_import = _get_qdrant_client_string(agent)
    embedding_arg, embedding_import, embedding_body = (
        _get_qdrant_embedding_function_string(agent, agent_name)
    )
    imports = {client_import}
    if embedding_import:
        imports.add(embedding_import)
    kwargs_lines = [
        f" client={client_arg},\n",
        f" embedding_function={embedding_arg},\n",
    ]
    return "".join(kwargs_lines), imports, embedding_body
@@ -0,0 +1,165 @@
1
+ """RAG User related exporting utils."""
2
+
3
+ from typing import Dict, Set, Tuple
4
+
5
+ from waldiez.models import (
6
+ WaldiezAgent,
7
+ WaldiezRagUser,
8
+ WaldiezRagUserModels,
9
+ WaldiezRagUserRetrieveConfig,
10
+ )
11
+
12
+ from ...utils import get_object_string
13
+ from .vector_db import get_rag_user_vector_db_string
14
+
15
+
16
def get_rag_user_retrieve_config_str(
    agent: WaldiezRagUser,
    agent_name: str,
    model_names: Dict[str, str],
) -> Tuple[str, str, Set[str]]:
    """Build the `retrieve_config` value for a RAG user agent.

    The result is used as e.g.::

        user_agent = RetrieveUserProxyAgent(
            ...other common/agent args,
            retrieve_config={what_this_returns})

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    agent_name : str
        The agent's name.
    model_names : Dict[str, str]
        A mapping from model id to model name.

    Returns
    -------
    Tuple[str, str, Set[str]]
        The content before the args, the args and the imports.
    """
    retrieve_config = agent.retrieve_config
    before_the_args, vector_db_arg, db_imports = get_rag_user_vector_db_string(
        agent=agent,
        agent_name=agent_name,
    )
    imports: Set[str] = set(db_imports)
    args_dict = _get_args_dict(agent, retrieve_config, model_names)
    # (enable flag, retrieve_config key / function prefix, body attribute)
    custom_functions = (
        (
            "use_custom_token_count",
            "custom_token_count_function",
            "token_count_function_string",
        ),
        (
            "use_custom_text_split",
            "custom_text_split_function",
            "text_split_function_string",
        ),
    )
    for flag_attr, arg_key, body_attr in custom_functions:
        if not getattr(retrieve_config, flag_attr):
            continue
        function_name = f"{arg_key}_{agent_name}"
        before_the_args += (
            f"\ndef {function_name}():\n"
            f"{getattr(retrieve_config, body_attr)}"
            "\n\n"
        )
        # NOTE(review): stored as a plain string; confirm get_object_string
        # emits it unquoted so the config references the function itself.
        args_dict[arg_key] = function_name
    # Splice the vector_db entry in right before the dict's closing line.
    *body_lines, closing_line = get_object_string(args_dict).split("\n")
    args_content = (
        "\n".join(body_lines)
        + f',\n "vector_db": {vector_db_arg},\n'
        + closing_line
    )
    return before_the_args, args_content, imports
73
+
74
+
75
def get_rag_user_extras(
    agent: WaldiezAgent,
    agent_name: str,
    model_names: Dict[str, str],
) -> Tuple[str, str, Set[str]]:
    """Collect the RAG-specific pieces needed to export an agent.

    Non-RAG agents get empty results.

    Parameters
    ----------
    agent : WaldiezAgent
        The agent.
    agent_name : str
        The agent's name.
    model_names : Dict[str, str]
        A mapping from model id to model name.

    Returns
    -------
    Tuple[str, str, Set[str]]
        The content before the agent, the retrieve arg and the db imports.
    """
    is_rag_user = agent.agent_type == "rag_user" and isinstance(
        agent, WaldiezRagUser
    )
    if not is_rag_user:
        return "", "", set()
    content_before, config_str, db_imports = get_rag_user_retrieve_config_str(
        agent=agent, agent_name=agent_name, model_names=model_names
    )
    retrieve_arg = (
        f"\n retrieve_config={config_str}," if config_str else config_str
    )
    return content_before or "", retrieve_arg, db_imports
110
+
111
+
112
+ def _get_model_arg(
113
+ agent: WaldiezRagUser,
114
+ retrieve_config: WaldiezRagUserRetrieveConfig,
115
+ model_names: Dict[str, str],
116
+ ) -> str: # pragma: no cover
117
+ agent_models = agent.data.model_ids
118
+ if agent_models:
119
+ first_model = agent_models[0]
120
+ first_model_name = model_names[first_model]
121
+ new_model_name = f"{first_model_name}"
122
+ return f"{new_model_name}"
123
+ if retrieve_config.model in model_names:
124
+ selected_model = model_names[retrieve_config.model]
125
+ new_model_name = f"{selected_model}"
126
+ return f"{new_model_name}"
127
+ return WaldiezRagUserModels[retrieve_config.vector_db]
128
+
129
+
130
def _get_args_dict(
    agent: WaldiezRagUser,
    retrieve_config: WaldiezRagUserRetrieveConfig,
    model_names: Dict[str, str],
) -> Dict[str, str]:
    """Build the entries of the exported `retrieve_config` dict.

    Optional settings are only included when they are not ``None``, so the
    generated config leaves them out instead of passing explicit ``None``.
    Every other listed setting is always included.

    Parameters
    ----------
    agent : WaldiezRagUser
        The agent.
    retrieve_config : WaldiezRagUserRetrieveConfig
        The agent's retrieve config.
    model_names : Dict[str, str]
        A mapping from model id to model name.

    Returns
    -------
    Dict[str, str]
        The retrieve config entries (values keep their original types).
    """
    model_arg = _get_model_arg(agent, retrieve_config, model_names)
    args_dict = {
        "task": retrieve_config.task,
        "model": model_arg,
    }
    optional_args = [
        "chunk_token_size",
        "context_max_tokens",
        "customized_prompt",
        "customized_answer_prefix",
        "docs_path",
    ]
    for arg in optional_args:
        arg_value = getattr(retrieve_config, arg)
        # Previously this was followed by an unconditional assignment that
        # re-added None values, defeating the check.
        if arg_value is not None:
            args_dict[arg] = arg_value
    non_optional_args = [
        "new_docs",
        "update_context",
        "get_or_create",
        "overwrite",
        "recursive",
        "chunk_mode",
        "must_break_at_empty_line",
        "collection_name",
        "distance_threshold",
    ]
    for arg in non_optional_args:
        args_dict[arg] = getattr(retrieve_config, arg)
    return args_dict