waldiez 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of waldiez might be problematic. Click here for more details.

Files changed (76) hide show
  1. waldiez/_version.py +1 -1
  2. waldiez/cli.py +3 -27
  3. waldiez/exporter.py +0 -13
  4. waldiez/exporting/agent/exporter.py +38 -0
  5. waldiez/exporting/agent/extras/__init__.py +2 -0
  6. waldiez/exporting/agent/extras/doc_agent_extras.py +366 -0
  7. waldiez/exporting/agent/extras/group_member_extras.py +3 -2
  8. waldiez/exporting/agent/processor.py +113 -15
  9. waldiez/exporting/chats/processor.py +2 -21
  10. waldiez/exporting/chats/utils/common.py +66 -1
  11. waldiez/exporting/chats/utils/group.py +6 -3
  12. waldiez/exporting/chats/utils/nested.py +1 -1
  13. waldiez/exporting/chats/utils/sequential.py +25 -9
  14. waldiez/exporting/chats/utils/single.py +8 -6
  15. waldiez/exporting/core/context.py +0 -12
  16. waldiez/exporting/core/extras/agent_extras/standard_extras.py +3 -1
  17. waldiez/exporting/core/extras/base.py +20 -17
  18. waldiez/exporting/core/extras/path_resolver.py +39 -41
  19. waldiez/exporting/core/extras/serializer.py +16 -1
  20. waldiez/exporting/core/protocols.py +17 -0
  21. waldiez/exporting/core/types.py +6 -9
  22. waldiez/exporting/flow/execution_generator.py +56 -21
  23. waldiez/exporting/flow/exporter.py +1 -4
  24. waldiez/exporting/flow/factory.py +0 -9
  25. waldiez/exporting/flow/file_generator.py +6 -0
  26. waldiez/exporting/flow/orchestrator.py +27 -21
  27. waldiez/exporting/flow/utils/__init__.py +0 -2
  28. waldiez/exporting/flow/utils/common.py +15 -96
  29. waldiez/exporting/flow/utils/importing.py +4 -0
  30. waldiez/io/mqtt.py +33 -14
  31. waldiez/io/redis.py +18 -13
  32. waldiez/io/structured.py +9 -4
  33. waldiez/io/utils.py +32 -0
  34. waldiez/io/ws.py +8 -2
  35. waldiez/models/__init__.py +6 -0
  36. waldiez/models/agents/__init__.py +8 -0
  37. waldiez/models/agents/agent/agent.py +136 -38
  38. waldiez/models/agents/agent/agent_type.py +3 -2
  39. waldiez/models/agents/agents.py +10 -0
  40. waldiez/models/agents/doc_agent/__init__.py +13 -0
  41. waldiez/models/agents/doc_agent/doc_agent.py +126 -0
  42. waldiez/models/agents/doc_agent/doc_agent_data.py +149 -0
  43. waldiez/models/agents/doc_agent/rag_query_engine.py +127 -0
  44. waldiez/models/flow/flow.py +13 -2
  45. waldiez/models/model/__init__.py +2 -2
  46. waldiez/models/model/_aws.py +75 -0
  47. waldiez/models/model/_llm.py +516 -0
  48. waldiez/models/model/_price.py +30 -0
  49. waldiez/models/model/model.py +45 -2
  50. waldiez/models/model/model_data.py +2 -83
  51. waldiez/models/tool/predefined/_duckduckgo.py +123 -0
  52. waldiez/models/tool/predefined/_google.py +31 -9
  53. waldiez/models/tool/predefined/_perplexity.py +161 -0
  54. waldiez/models/tool/predefined/_searxng.py +152 -0
  55. waldiez/models/tool/predefined/_tavily.py +46 -9
  56. waldiez/models/tool/predefined/_wikipedia.py +26 -6
  57. waldiez/models/tool/predefined/_youtube.py +36 -8
  58. waldiez/models/tool/predefined/registry.py +6 -0
  59. waldiez/models/waldiez.py +12 -0
  60. waldiez/runner.py +177 -408
  61. waldiez/running/__init__.py +2 -4
  62. waldiez/running/base_runner.py +100 -112
  63. waldiez/running/environment.py +29 -4
  64. waldiez/running/post_run.py +0 -1
  65. waldiez/running/protocol.py +36 -48
  66. waldiez/running/run_results.py +5 -5
  67. waldiez/running/standard_runner.py +429 -0
  68. waldiez/running/timeline_processor.py +0 -82
  69. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/METADATA +58 -61
  70. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/RECORD +74 -64
  71. waldiez/running/import_runner.py +0 -437
  72. waldiez/running/subprocess_runner.py +0 -104
  73. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/WHEEL +0 -0
  74. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/entry_points.txt +0 -0
  75. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/licenses/LICENSE +0 -0
  76. {waldiez-0.5.3.dist-info → waldiez-0.5.4.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,5 +1,6 @@
1
1
  # SPDX-License-Identifier: Apache-2.0.
2
2
  # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ # pylint: disable=too-many-public-methods
3
4
  """Base agent class to be inherited by all agents."""
4
5
 
5
6
  import warnings
@@ -150,6 +151,28 @@ class WaldiezAgent(WaldiezBase):
150
151
  )
151
152
  _checked_handoffs: Annotated[bool, Field(init=False, default=False)] = False
152
153
 
154
@property
def args_to_skip(self) -> list[str]:
    """Get the list of arguments to skip when generating the agent string.

    Returns
    -------
    list[str]
        A fixed list of argument names when the agent is a doc agent,
        an empty list for every other agent.
    """
    # Only doc agents omit arguments; everything else keeps all of them.
    if not self.is_doc_agent:
        return []
    return [
        "description",
        "human_input_mode",
        "max_consecutive_auto_reply",
        "default_auto_reply",
        "code_execution_config",
        "is_termination_msg",
        "functions",
        "update_agent_state_before_reply",
    ]
175
+
153
176
  @property
154
177
  def handoffs(self) -> list[WaldiezHandoff]:
155
178
  """Get the handoffs for this agent.
@@ -266,6 +289,17 @@ class WaldiezAgent(WaldiezBase):
266
289
  "rag_user_proxy",
267
290
  )
268
291
 
292
@property
def is_assistant(self) -> bool:
    """Check if the agent is an assistant.

    Returns
    -------
    bool
        True if `agent_type` equals ``"assistant"``, False otherwise.
    """
    return self.agent_type == "assistant"
302
+
269
303
  @property
270
304
  def is_rag_user(self) -> bool:
271
305
  """Check if the agent is a RAG user.
@@ -277,6 +311,17 @@ class WaldiezAgent(WaldiezBase):
277
311
  """
278
312
  return self.agent_type in ("rag_user", "rag_user_proxy")
279
313
 
314
@property
def is_doc_agent(self) -> bool:
    """Check if the agent is a doc agent.

    Returns
    -------
    bool
        True if `agent_type` equals ``"doc_agent"``, False otherwise.
    """
    return self.agent_type == "doc_agent"
324
+
280
325
  @property
281
326
  def is_group_manager(self) -> bool:
282
327
  """Check if the agent is a group manager.
@@ -293,16 +338,9 @@ class WaldiezAgent(WaldiezBase):
293
338
  """Return the AG2 class of the agent."""
294
339
  class_name = "ConversableAgent"
295
340
  if self.is_group_member:
296
- if (
297
- getattr(self.data, "is_multimodal", False) is True
298
- ): # pragma: no branch
299
- class_name = "MultimodalConversableAgent"
300
- return class_name
301
- if self.agent_type == "assistant":
302
- if getattr(self.data, "is_multimodal", False) is True:
303
- class_name = "MultimodalConversableAgent"
304
- else:
305
- class_name = "AssistantAgent"
341
+ return self.get_group_member_class_name()
342
+ if self.is_assistant: # pragma: no branch
343
+ return self.get_assistant_class_name()
306
344
  if self.is_user:
307
345
  class_name = "UserProxyAgent"
308
346
  if self.is_rag_user:
@@ -313,40 +351,100 @@ class WaldiezAgent(WaldiezBase):
313
351
  class_name = "CaptainAgent"
314
352
  if self.is_group_manager:
315
353
  class_name = "GroupChatManager"
354
+ if self.is_doc_agent:
355
+ class_name = "DocAgent"
356
+ return class_name # pragma: no cover
357
+
358
def get_group_member_class_name(self) -> str:
    """Get the class name for a group member agent.

    The multimodal flag on the agent's data is checked first, then the
    captain, reasoning and doc-agent flags, in that order.

    Returns
    -------
    str
        The AG2 class name for the group member agent;
        "ConversableAgent" when no flag matches.
    """
    # Strict identity check: only a real `True` selects the multimodal class.
    if getattr(self.data, "is_multimodal", False) is True:
        return "MultimodalConversableAgent"
    # (flag attribute, class name) pairs, looked up lazily and in order.
    specialised_classes = (
        ("is_captain", "CaptainAgent"),
        ("is_reasoning", "ReasoningAgent"),
        ("is_doc_agent", "DocAgent"),
    )
    for flag_name, class_name in specialised_classes:
        if getattr(self, flag_name):
            return class_name
    return "ConversableAgent"  # pragma: no cover
377
+
378
def get_assistant_class_name(self) -> str:
    """Get the class name for an assistant agent.

    Returns
    -------
    str
        "MultimodalConversableAgent" when the agent's data has
        `is_multimodal` set to `True`, "AssistantAgent" otherwise.
    """
    # Strict identity check: a missing or merely truthy flag does not count.
    if getattr(self.data, "is_multimodal", False) is True:
        return "MultimodalConversableAgent"
    return "AssistantAgent"  # pragma: no cover
317
391
 
318
392
@property
def ag2_imports(self) -> set[str]:
    """Return the AG2 imports of the agent.

    Returns
    -------
    set[str]
        "import autogen" plus the import statements required for the
        agent's AG2 class.

    Raises
    ------
    ValueError
        If the agent class is unknown and no imports are defined.
    """
    # AG2 class name -> import statements needed on top of "import autogen".
    imports_by_class: dict[str, tuple[str, ...]] = {
        "AssistantAgent": ("from autogen import AssistantAgent",),
        "UserProxyAgent": ("from autogen import UserProxyAgent",),
        "RetrieveUserProxyAgent": (
            "from autogen.agentchat.contrib.retrieve_user_proxy_agent "
            "import RetrieveUserProxyAgent",
        ),
        "MultimodalConversableAgent": (
            "from "
            "autogen.agentchat.contrib.multimodal_conversable_agent "
            "import MultimodalConversableAgent",
        ),
        "ReasoningAgent": (
            "from autogen.agents.experimental import ReasoningAgent",
        ),
        "CaptainAgent": (
            "from autogen.agentchat.contrib.captainagent "
            "import CaptainAgent",
        ),
        "GroupChatManager": (
            "from autogen import GroupChat",
            "from autogen.agentchat import GroupChatManager",
            "from autogen.agentchat.group import ContextVariables",
        ),
        "DocAgent": ("from autogen.agents.experimental import DocAgent",),
        "ConversableAgent": ("from autogen import ConversableAgent",),
    }
    agent_class = self.ag2_class
    class_imports = imports_by_class.get(agent_class)
    if class_imports is None:  # pragma: no cover
        raise ValueError(
            f"Unknown agent class: {agent_class}. "
            "Please implement the imports for this class."
        )
    return {"import autogen", *class_imports}
351
449
 
352
450
  def validate_linked_tools(
@@ -6,14 +6,15 @@ from typing_extensions import Literal
6
6
 
7
7
  # pylint: disable=line-too-long
8
8
  # fmt: off
9
- WaldiezAgentType = Literal["user_proxy", "assistant", "group_manager", "manager", "rag_user", "swarm", "reasoning", "captain", "user", "rag_user_proxy"] # noqa: E501
9
+ WaldiezAgentType = Literal["user_proxy", "assistant", "group_manager", "manager", "rag_user", "swarm", "reasoning", "captain", "user", "rag_user_proxy", "doc_agent"] # noqa: E501
10
10
  """Possible types of a Waldiez Agent:
11
11
  - user_proxy,
12
12
  - assistant,
13
13
  - group_manager,
14
- - rag_user_proxy,
14
+ - rag_user_proxy (deprecated: use doc_agent),
15
15
  - reasoning,
16
16
  - captain,
17
+ - doc_agent,
17
18
  - swarm (deprecated: do not use it),
18
19
  - user (deprecated: use user_proxy)
19
20
  - rag_user (deprecated: use rag_user_proxy)
@@ -11,6 +11,7 @@ from ..common import WaldiezBase
11
11
  from .agent import WaldiezAgent
12
12
  from .assistant import WaldiezAssistant
13
13
  from .captain import WaldiezCaptainAgent
14
+ from .doc_agent import WaldiezDocAgent
14
15
  from .group_manager import WaldiezGroupManager
15
16
  from .rag_user_proxy import WaldiezRagUserProxy
16
17
  from .reasoning import WaldiezReasoningAgent
@@ -84,6 +85,14 @@ class WaldiezAgents(WaldiezBase):
84
85
  default_factory=list,
85
86
  ),
86
87
  ] = []
88
+ docAgents: Annotated[
89
+ list[WaldiezDocAgent],
90
+ Field(
91
+ title="Document Agents.",
92
+ description="The Document agents in the flow.",
93
+ default_factory=list,
94
+ ),
95
+ ] = []
87
96
 
88
97
  @property
89
98
  def members(self) -> Iterator[WaldiezAgent]:
@@ -100,6 +109,7 @@ class WaldiezAgents(WaldiezBase):
100
109
  yield from self.reasoningAgents
101
110
  yield from self.captainAgents
102
111
  yield from self.groupManagerAgents
112
+ yield from self.docAgents
103
113
 
104
114
  @model_validator(mode="after")
105
115
  def validate_agents(self) -> Self:
@@ -0,0 +1,13 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Document agent model."""
4
+
5
+ from .doc_agent import WaldiezDocAgent
6
+ from .doc_agent_data import WaldiezDocAgentData
7
+ from .rag_query_engine import WaldiezDocAgentQueryEngine
8
+
9
+ __all__ = [
10
+ "WaldiezDocAgent",
11
+ "WaldiezDocAgentData",
12
+ "WaldiezDocAgentQueryEngine",
13
+ ]
@@ -0,0 +1,126 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Document agent model."""
4
+
5
+ from typing import Literal
6
+
7
+ from pydantic import Field
8
+ from typing_extensions import Annotated
9
+
10
+ from waldiez.models.agents.doc_agent.rag_query_engine import (
11
+ WaldiezDocAgentQueryEngine,
12
+ )
13
+
14
+ from ...model import WaldiezModel
15
+ from ..agent import WaldiezAgent
16
+ from .doc_agent_data import WaldiezDocAgentData
17
+
18
+
19
class WaldiezDocAgent(WaldiezAgent):
    """Document agent model.

    Extends `WaldiezAgent` with `agent_type` fixed to ``"doc_agent"`` and
    `data` typed as `WaldiezDocAgentData`. The accessor methods below
    forward to the matching members of `data`.
    """

    agent_type: Annotated[  # pyright: ignore
        Literal["doc_agent"],
        Field(
            default="doc_agent",
            title="Agent type",
            description="The agent type: 'doc_agent' for a document agent",
            alias="agentType",
        ),
    ] = "doc_agent"
    data: Annotated[  # pyright: ignore
        WaldiezDocAgentData,
        Field(
            title="Data",
            description="The document agent's data",
            default_factory=WaldiezDocAgentData,  # pyright: ignore
        ),
    ]

    @property
    def reset_collection(self) -> bool:
        """Tell whether the agent's collection should be reset.

        Returns
        -------
        bool
            The `reset_collection` flag of the agent's data.
        """
        return self.data.reset_collection

    def get_collection_name(self) -> str:
        """Return the collection name reported by the agent's data.

        Returns
        -------
        str
            The collection name for the document agent.
        """
        return self.data.get_collection_name()

    def get_query_engine(self) -> WaldiezDocAgentQueryEngine:
        """Return the query engine reported by the agent's data.

        Returns
        -------
        WaldiezDocAgentQueryEngine
            The query engine for the document agent.
        """
        return self.data.get_query_engine()

    def get_db_path(self) -> str:
        """Return the query engine's database path.

        Returns
        -------
        str
            The database path for the query engine.
        """
        return self.data.get_db_path()

    def get_parsed_docs_path(self) -> str:
        """Return the parsed documents path reported by the agent's data.

        Returns
        -------
        str
            The parsed documents path for the document agent.
        """
        return self.data.get_parsed_docs_path()

    def get_llm_requirements(
        self,
        ag2_version: str,
        all_models: list[WaldiezModel],
    ) -> set[str]:
        """Collect the LLM-related requirements of the document agent.

        Parameters
        ----------
        ag2_version : str
            The version of AG2 to pin in the requirements.
        all_models : list[WaldiezModel]
            All the models in the flow.

        Returns
        -------
        set[str]
            With no linked models: the base requirements plus the
            OpenAI llama-index integration. Otherwise: the requirements
            of the first linked model found in `all_models`, or the base
            requirements if none of the linked ids is found.
        """
        base_requirements = {
            "llama-index",
            "llama-index-core",
            f"ag2[rag]=={ag2_version}",
        }
        linked_ids = self.data.model_ids
        if not linked_ids:
            base_requirements.add("llama-index-llms-openai")
            return base_requirements
        for linked_id in linked_ids:
            matched = next(
                (
                    candidate
                    for candidate in all_models
                    if candidate.id == linked_id
                ),
                None,
            )
            if matched:
                # NOTE(review): the matched model's requirements are
                # returned as-is, without `base_requirements` - confirm
                # that WaldiezModel.get_llm_requirements includes them.
                return matched.get_llm_requirements(ag2_version=ag2_version)
        return base_requirements
@@ -0,0 +1,149 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Document agent data model."""
4
+
5
+ from pathlib import Path
6
+ from typing import Optional, Union
7
+
8
+ from platformdirs import user_data_dir
9
+ from pydantic import Field, model_validator
10
+ from typing_extensions import Annotated, Self
11
+
12
+ from ..agent import WaldiezAgentData
13
+ from .rag_query_engine import WaldiezDocAgentQueryEngine
14
+
15
+
16
class WaldiezDocAgentData(WaldiezAgentData):
    """Document agent data class.

    The data for a document agent.
    Extends `WaldiezAgentData`.
    Extra attributes:
    - `collection_name`: Optional string, the name of the collection.
    - `reset_collection`: Optional boolean, whether to reset the collection.
    - `parsed_docs_path`: Optional string, the path to the parsed documents.
    - `query_engine`: Optional `RAGQueryEngine`, the query engine to use.
    """

    collection_name: Annotated[
        Optional[str],
        Field(
            title="Collection Name",
            description="The name of the collection for the document agent.",
            default=None,
            alias="collectionName",
        ),
    ] = None
    reset_collection: Annotated[
        bool,
        Field(
            title="Reset Collection",
            description=(
                "Whether to reset the collection for the document agent."
            ),
            default=False,
            alias="resetCollection",
        ),
    ] = False
    parsed_docs_path: Annotated[
        Optional[Union[str, Path]],
        Field(
            title="Parsed Documents Path",
            description=(
                "The path to the parsed documents for the document agent."
            ),
            default=None,
            alias="parsedDocsPath",
        ),
    ] = None
    query_engine: Annotated[
        Optional[WaldiezDocAgentQueryEngine],
        Field(
            title="Query Engine",
            description="The query engine to use for the document agent.",
            default=None,
            alias="queryEngine",
        ),
    ] = None

    @model_validator(mode="after")
    def validate_parsed_docs_path(self) -> Self:
        """Ensure the parsed documents path is set and is a directory.

        If not set, a default path in the user data directory is used.
        The directory is created when missing and the stored value is
        replaced by its absolute, resolved form.

        Returns
        -------
        Self
            The instance with validated `parsed_docs_path`.
        """
        self.parsed_docs_path = _ensure_parsed_docs_path(self.parsed_docs_path)
        return self

    def get_query_engine(self) -> WaldiezDocAgentQueryEngine:
        """Get the query engine for the document agent.

        A default `WaldiezDocAgentQueryEngine` is created and stored
        when none is set.

        Returns
        -------
        WaldiezDocAgentQueryEngine
            The query engine for the document agent.
        """
        if not self.query_engine:
            self.query_engine = WaldiezDocAgentQueryEngine()
        return self.query_engine

    def get_db_path(self) -> str:
        """Get the database path for the query engine.

        Returns
        -------
        str
            The database path for the query engine.
        """
        return self.get_query_engine().get_db_path()

    def get_collection_name(self) -> str:
        """Get the collection name for the document agent.

        Returns
        -------
        str
            The configured collection name, or "docling-parsed-docs"
            when none is set.
        """
        return self.collection_name or "docling-parsed-docs"

    def get_parsed_docs_path(self) -> str:
        """Get the parsed documents path for the document agent.

        Applies the same normalization as `validate_parsed_docs_path`,
        so a value assigned after validation is still returned as an
        absolute path to an existing directory.

        Returns
        -------
        str
            The absolute parsed documents path for the document agent.
        """
        resolved_path = _ensure_parsed_docs_path(self.parsed_docs_path)
        self.parsed_docs_path = resolved_path
        return resolved_path


def _ensure_parsed_docs_path(parsed_docs_path: Optional[Union[str, Path]]) -> str:
    """Return an absolute parsed-docs directory path, creating it if needed.

    Parameters
    ----------
    parsed_docs_path : Optional[Union[str, Path]]
        The configured path; when empty, "parsed_docs" under the
        waldiez user data directory is used.

    Returns
    -------
    str
        The absolute, resolved path of the (now existing) directory.
    """
    if parsed_docs_path:
        target = Path(parsed_docs_path)
    else:
        data_dir = user_data_dir(
            appname="waldiez",
            appauthor="waldiez",
        )
        target = Path(data_dir) / "parsed_docs"
    # Path.resolve() already anchors relative paths at the current working
    # directory, so no separate is_absolute() handling is required.
    target = target.resolve()
    target.mkdir(parents=True, exist_ok=True)
    return str(target)
@@ -0,0 +1,127 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """RAG query engine model for the document agent."""
4
+
5
+ from pathlib import Path
6
+ from typing import Optional, Union
7
+
8
+ from platformdirs import user_data_dir
9
+ from pydantic import Field, model_validator
10
+ from typing_extensions import Annotated, Literal, Self
11
+
12
+ from ...common import WaldiezBase
13
+
14
+
15
class WaldiezDocAgentQueryEngine(WaldiezBase):
    """RAG query engine settings.

    Holds the engine type, the database path, and the citation options
    of a RAG query engine.
    """

    type: Annotated[
        Optional[
            Literal[
                "VectorChromaQueryEngine",
                "VectorChromaCitationQueryEngine",
                "InMemoryQueryEngine",
            ]
        ],
        "RAG Query Engine type",
    ] = "VectorChromaQueryEngine"
    db_path: Annotated[
        Optional[Union[str, Path]],
        Field(
            title="Database Path",
            description="The path to the database for the query engine.",
            default=None,
            alias="dbPath",
        ),
    ] = None
    enable_query_citations: Annotated[
        bool,
        Field(
            title="Enable Query Citations",
            description=(
                "Whether to enable query citations for the query engine."
            ),
            default=False,
            alias="enableQueryCitations",
        ),
    ] = False
    citation_chunk_size: Annotated[
        int,
        Field(
            title="Citation Chunk Size",
            description="The size of the citation chunks for the query engine.",
            default=512,
            alias="citationChunkSize",
        ),
    ] = 512

    @model_validator(mode="after")
    def validate_db_path(self) -> Self:
        """Fill in the default engine type and normalize `db_path`.

        An empty `type` falls back to "VectorChromaQueryEngine";
        `db_path` is replaced by the result of `ensure_db_path`.

        Returns
        -------
        Self
            The instance of WaldiezDocAgentQueryEngine with validated db_path.
        """
        engine_type_missing = not self.type
        if engine_type_missing:
            self.type = "VectorChromaQueryEngine"
        self.db_path = ensure_db_path(self.db_path)
        return self

    def get_db_path(self) -> str:
        """Get the database path for the query engine.

        Returns
        -------
        str
            The stored `db_path` when set, otherwise the path returned
            by `ensure_db_path`.
        """
        if self.db_path:
            return str(self.db_path)
        return str(ensure_db_path(self.db_path))
87
+
88
+
89
+ def ensure_db_path(db_path: str | Path | None) -> str:
90
+ """Get the database path for the query engine.
91
+
92
+ Ensure the database path is set and is a directory.
93
+
94
+ Parameters
95
+ ----------
96
+ db_path : Optional[str]
97
+ The database path to validate.
98
+
99
+ Returns
100
+ -------
101
+ str
102
+ The database path for the query engine.
103
+
104
+ Raises
105
+ ------
106
+ ValueError
107
+ If the database path is not set or is not a directory.
108
+ """
109
+ if not db_path:
110
+ data_dir = user_data_dir(
111
+ appname="waldiez",
112
+ appauthor="waldiez",
113
+ )
114
+ data_dir_path = Path(data_dir) / "rag"
115
+ data_dir_path.mkdir(parents=True, exist_ok=True)
116
+ db_path_to_resolve = data_dir_path / "chroma"
117
+ db_path_to_resolve.mkdir(parents=True, exist_ok=True)
118
+ return str(db_path_to_resolve)
119
+
120
+ resolved = Path(db_path).resolve()
121
+ if not resolved.is_absolute():
122
+ resolved = (Path.cwd() / db_path).resolve()
123
+ if not resolved.exists():
124
+ resolved.mkdir(parents=True, exist_ok=True)
125
+ if not resolved.is_dir():
126
+ raise ValueError(f"The path {resolved} is not a directory.")
127
+ return str(resolved)