waldiez 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez might be problematic. Click here for more details.
- waldiez/_version.py +1 -1
- waldiez/cli.py +5 -27
- waldiez/exporter.py +0 -13
- waldiez/exporting/agent/exporter.py +38 -0
- waldiez/exporting/agent/extras/__init__.py +2 -0
- waldiez/exporting/agent/extras/doc_agent_extras.py +366 -0
- waldiez/exporting/agent/extras/group_member_extras.py +3 -2
- waldiez/exporting/agent/processor.py +113 -15
- waldiez/exporting/chats/processor.py +2 -21
- waldiez/exporting/chats/utils/common.py +66 -1
- waldiez/exporting/chats/utils/group.py +6 -3
- waldiez/exporting/chats/utils/nested.py +1 -1
- waldiez/exporting/chats/utils/sequential.py +25 -9
- waldiez/exporting/chats/utils/single.py +8 -6
- waldiez/exporting/core/context.py +0 -12
- waldiez/exporting/core/extras/agent_extras/standard_extras.py +3 -1
- waldiez/exporting/core/extras/base.py +20 -17
- waldiez/exporting/core/extras/path_resolver.py +39 -41
- waldiez/exporting/core/extras/serializer.py +16 -1
- waldiez/exporting/core/protocols.py +17 -0
- waldiez/exporting/core/types.py +6 -9
- waldiez/exporting/flow/execution_generator.py +56 -21
- waldiez/exporting/flow/exporter.py +1 -4
- waldiez/exporting/flow/factory.py +0 -9
- waldiez/exporting/flow/file_generator.py +6 -0
- waldiez/exporting/flow/orchestrator.py +27 -21
- waldiez/exporting/flow/utils/__init__.py +0 -2
- waldiez/exporting/flow/utils/common.py +15 -96
- waldiez/exporting/flow/utils/importing.py +4 -0
- waldiez/io/mqtt.py +33 -14
- waldiez/io/redis.py +18 -13
- waldiez/io/structured.py +9 -4
- waldiez/io/utils.py +32 -0
- waldiez/io/ws.py +8 -2
- waldiez/models/__init__.py +6 -0
- waldiez/models/agents/__init__.py +8 -0
- waldiez/models/agents/agent/agent.py +136 -38
- waldiez/models/agents/agent/agent_type.py +3 -2
- waldiez/models/agents/agents.py +10 -0
- waldiez/models/agents/doc_agent/__init__.py +13 -0
- waldiez/models/agents/doc_agent/doc_agent.py +126 -0
- waldiez/models/agents/doc_agent/doc_agent_data.py +149 -0
- waldiez/models/agents/doc_agent/rag_query_engine.py +127 -0
- waldiez/models/chat/chat_message.py +1 -1
- waldiez/models/flow/flow.py +13 -2
- waldiez/models/model/__init__.py +2 -2
- waldiez/models/model/_aws.py +75 -0
- waldiez/models/model/_llm.py +516 -0
- waldiez/models/model/_price.py +30 -0
- waldiez/models/model/model.py +45 -2
- waldiez/models/model/model_data.py +2 -83
- waldiez/models/tool/predefined/_duckduckgo.py +123 -0
- waldiez/models/tool/predefined/_google.py +31 -9
- waldiez/models/tool/predefined/_perplexity.py +161 -0
- waldiez/models/tool/predefined/_searxng.py +152 -0
- waldiez/models/tool/predefined/_tavily.py +46 -9
- waldiez/models/tool/predefined/_wikipedia.py +26 -6
- waldiez/models/tool/predefined/_youtube.py +36 -8
- waldiez/models/tool/predefined/registry.py +6 -0
- waldiez/models/waldiez.py +12 -0
- waldiez/runner.py +184 -382
- waldiez/running/__init__.py +2 -4
- waldiez/running/base_runner.py +136 -118
- waldiez/running/environment.py +61 -17
- waldiez/running/post_run.py +70 -14
- waldiez/running/pre_run.py +42 -0
- waldiez/running/protocol.py +42 -48
- waldiez/running/run_results.py +5 -5
- waldiez/running/standard_runner.py +429 -0
- waldiez/running/timeline_processor.py +1166 -0
- waldiez/utils/version.py +12 -1
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/METADATA +61 -63
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/RECORD +77 -66
- waldiez/running/import_runner.py +0 -424
- waldiez/running/subprocess_runner.py +0 -100
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/WHEEL +0 -0
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/entry_points.txt +0 -0
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {waldiez-0.5.2.dist-info → waldiez-0.5.4.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0.
|
|
2
2
|
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
# pylint: disable=too-many-public-methods
|
|
3
4
|
"""Base agent class to be inherited by all agents."""
|
|
4
5
|
|
|
5
6
|
import warnings
|
|
@@ -150,6 +151,28 @@ class WaldiezAgent(WaldiezBase):
|
|
|
150
151
|
)
|
|
151
152
|
_checked_handoffs: Annotated[bool, Field(init=False, default=False)] = False
|
|
152
153
|
|
|
154
|
+
@property
def args_to_skip(self) -> list[str]:
    """Return the arguments to omit when generating the agent string.

    Returns
    -------
    list[str]
        The argument names to skip for a doc agent, or an empty list
        for any other agent.
    """
    # Guard first: only doc agents have arguments to leave out.
    if not self.is_doc_agent:
        return []
    doc_agent_skipped = [
        "description",
        "human_input_mode",
        "max_consecutive_auto_reply",
        "default_auto_reply",
        "code_execution_config",
        "is_termination_msg",
        "functions",
        "update_agent_state_before_reply",
    ]
    return doc_agent_skipped
|
|
175
|
+
|
|
153
176
|
@property
|
|
154
177
|
def handoffs(self) -> list[WaldiezHandoff]:
|
|
155
178
|
"""Get the handoffs for this agent.
|
|
@@ -266,6 +289,17 @@ class WaldiezAgent(WaldiezBase):
|
|
|
266
289
|
"rag_user_proxy",
|
|
267
290
|
)
|
|
268
291
|
|
|
292
|
+
@property
def is_assistant(self) -> bool:
    """Tell whether this agent's type is ``"assistant"``.

    Returns
    -------
    bool
        True when ``agent_type`` equals ``"assistant"``, False otherwise.
    """
    current_type = self.agent_type
    return current_type == "assistant"
|
|
302
|
+
|
|
269
303
|
@property
|
|
270
304
|
def is_rag_user(self) -> bool:
|
|
271
305
|
"""Check if the agent is a RAG user.
|
|
@@ -277,6 +311,17 @@ class WaldiezAgent(WaldiezBase):
|
|
|
277
311
|
"""
|
|
278
312
|
return self.agent_type in ("rag_user", "rag_user_proxy")
|
|
279
313
|
|
|
314
|
+
@property
def is_doc_agent(self) -> bool:
    """Tell whether this agent's type is ``"doc_agent"``.

    Returns
    -------
    bool
        True when ``agent_type`` equals ``"doc_agent"``, False otherwise.
    """
    current_type = self.agent_type
    return current_type == "doc_agent"
|
|
324
|
+
|
|
280
325
|
@property
|
|
281
326
|
def is_group_manager(self) -> bool:
|
|
282
327
|
"""Check if the agent is a group manager.
|
|
@@ -293,16 +338,9 @@ class WaldiezAgent(WaldiezBase):
|
|
|
293
338
|
"""Return the AG2 class of the agent."""
|
|
294
339
|
class_name = "ConversableAgent"
|
|
295
340
|
if self.is_group_member:
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
)
|
|
299
|
-
class_name = "MultimodalConversableAgent"
|
|
300
|
-
return class_name
|
|
301
|
-
if self.agent_type == "assistant":
|
|
302
|
-
if getattr(self.data, "is_multimodal", False) is True:
|
|
303
|
-
class_name = "MultimodalConversableAgent"
|
|
304
|
-
else:
|
|
305
|
-
class_name = "AssistantAgent"
|
|
341
|
+
return self.get_group_member_class_name()
|
|
342
|
+
if self.is_assistant: # pragma: no branch
|
|
343
|
+
return self.get_assistant_class_name()
|
|
306
344
|
if self.is_user:
|
|
307
345
|
class_name = "UserProxyAgent"
|
|
308
346
|
if self.is_rag_user:
|
|
@@ -313,40 +351,100 @@ class WaldiezAgent(WaldiezBase):
|
|
|
313
351
|
class_name = "CaptainAgent"
|
|
314
352
|
if self.is_group_manager:
|
|
315
353
|
class_name = "GroupChatManager"
|
|
354
|
+
if self.is_doc_agent:
|
|
355
|
+
class_name = "DocAgent"
|
|
356
|
+
return class_name # pragma: no cover
|
|
357
|
+
|
|
358
|
+
def get_group_member_class_name(self) -> str:
    """Resolve the AG2 class name to use for a group member agent.

    The checks run in a fixed order and the first match wins:
    multimodal data, captain, reasoning, doc agent, then the
    ``ConversableAgent`` fallback.

    Returns
    -------
    str
        The class name for the group member agent.
    """
    # Only an explicit ``True`` counts; other truthy values do not.
    multimodal = getattr(self.data, "is_multimodal", False)
    if multimodal is True:  # pragma: no branch
        resolved_name = "MultimodalConversableAgent"
    elif self.is_captain:  # pragma: no branch
        resolved_name = "CaptainAgent"
    elif self.is_reasoning:  # pragma: no branch
        resolved_name = "ReasoningAgent"
    elif self.is_doc_agent:  # pragma: no branch
        resolved_name = "DocAgent"
    else:  # pragma: no cover
        resolved_name = "ConversableAgent"
    return resolved_name
|
|
377
|
+
|
|
378
|
+
def get_assistant_class_name(self) -> str:
    """Resolve the AG2 class name to use for an assistant agent.

    Returns
    -------
    str
        ``"MultimodalConversableAgent"`` when the agent data has
        ``is_multimodal`` set to exactly ``True``, otherwise
        ``"AssistantAgent"``.
    """
    multimodal = getattr(self.data, "is_multimodal", False)
    if multimodal is True:
        return "MultimodalConversableAgent"
    return "AssistantAgent"  # pragma: no cover
|
|
317
391
|
|
|
318
392
|
@property
def ag2_imports(self) -> set[str]:
    """Collect the import statements needed for the agent's AG2 class.

    Returns
    -------
    set[str]
        ``"import autogen"`` plus the import statements specific to
        the class named by ``ag2_class``.

    Raises
    ------
    ValueError
        If ``ag2_class`` names a class with no imports defined here.
    """
    # Lookup table: AG2 class name -> the import lines it requires.
    imports_by_class: dict[str, tuple[str, ...]] = {
        "AssistantAgent": ("from autogen import AssistantAgent",),
        "UserProxyAgent": ("from autogen import UserProxyAgent",),
        "RetrieveUserProxyAgent": (
            "from autogen.agentchat.contrib.retrieve_user_proxy_agent "
            "import RetrieveUserProxyAgent",
        ),
        "MultimodalConversableAgent": (
            "from "
            "autogen.agentchat.contrib.multimodal_conversable_agent "
            "import MultimodalConversableAgent",
        ),
        "ReasoningAgent": (
            "from autogen.agents.experimental import ReasoningAgent",
        ),
        "CaptainAgent": (
            "from autogen.agentchat.contrib.captainagent "
            "import CaptainAgent",
        ),
        "GroupChatManager": (
            "from autogen import GroupChat",
            "from autogen.agentchat import GroupChatManager",
            "from autogen.agentchat.group import ContextVariables",
        ),
        "DocAgent": ("from autogen.agents.experimental import DocAgent",),
        "ConversableAgent": ("from autogen import ConversableAgent",),
    }
    agent_class = self.ag2_class
    class_imports = imports_by_class.get(agent_class)
    if class_imports is None:  # pragma: no cover
        raise ValueError(
            f"Unknown agent class: {agent_class}. "
            "Please implement the imports for this class."
        )
    return {"import autogen", *class_imports}
|
|
351
449
|
|
|
352
450
|
def validate_linked_tools(
|
|
@@ -6,14 +6,15 @@ from typing_extensions import Literal
|
|
|
6
6
|
|
|
7
7
|
# pylint: disable=line-too-long
|
|
8
8
|
# fmt: off
|
|
9
|
-
WaldiezAgentType = Literal["user_proxy", "assistant", "group_manager", "manager", "rag_user", "swarm", "reasoning", "captain", "user", "rag_user_proxy"] # noqa: E501
|
|
9
|
+
WaldiezAgentType = Literal["user_proxy", "assistant", "group_manager", "manager", "rag_user", "swarm", "reasoning", "captain", "user", "rag_user_proxy", "doc_agent"] # noqa: E501
|
|
10
10
|
"""Possible types of a Waldiez Agent:
|
|
11
11
|
- user_proxy,
|
|
12
12
|
- assistant,
|
|
13
13
|
- group_manager,
|
|
14
|
-
- rag_user_proxy,
|
|
14
|
+
- rag_user_proxy (deprecated: use doc_agent),
|
|
15
15
|
- reasoning,
|
|
16
16
|
- captain,
|
|
17
|
+
- doc_agent,
|
|
17
18
|
- swarm (deprecated: do not use it),
|
|
18
19
|
- user (deprecated: use user_proxy)
|
|
19
20
|
- rag_user (deprecated: use rag_user_proxy)
|
waldiez/models/agents/agents.py
CHANGED
|
@@ -11,6 +11,7 @@ from ..common import WaldiezBase
|
|
|
11
11
|
from .agent import WaldiezAgent
|
|
12
12
|
from .assistant import WaldiezAssistant
|
|
13
13
|
from .captain import WaldiezCaptainAgent
|
|
14
|
+
from .doc_agent import WaldiezDocAgent
|
|
14
15
|
from .group_manager import WaldiezGroupManager
|
|
15
16
|
from .rag_user_proxy import WaldiezRagUserProxy
|
|
16
17
|
from .reasoning import WaldiezReasoningAgent
|
|
@@ -84,6 +85,14 @@ class WaldiezAgents(WaldiezBase):
|
|
|
84
85
|
default_factory=list,
|
|
85
86
|
),
|
|
86
87
|
] = []
|
|
88
|
+
docAgents: Annotated[
|
|
89
|
+
list[WaldiezDocAgent],
|
|
90
|
+
Field(
|
|
91
|
+
title="Document Agents.",
|
|
92
|
+
description="The Document agents in the flow.",
|
|
93
|
+
default_factory=list,
|
|
94
|
+
),
|
|
95
|
+
] = []
|
|
87
96
|
|
|
88
97
|
@property
|
|
89
98
|
def members(self) -> Iterator[WaldiezAgent]:
|
|
@@ -100,6 +109,7 @@ class WaldiezAgents(WaldiezBase):
|
|
|
100
109
|
yield from self.reasoningAgents
|
|
101
110
|
yield from self.captainAgents
|
|
102
111
|
yield from self.groupManagerAgents
|
|
112
|
+
yield from self.docAgents
|
|
103
113
|
|
|
104
114
|
@model_validator(mode="after")
|
|
105
115
|
def validate_agents(self) -> Self:
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
"""Document agent model."""
|
|
4
|
+
|
|
5
|
+
from .doc_agent import WaldiezDocAgent
|
|
6
|
+
from .doc_agent_data import WaldiezDocAgentData
|
|
7
|
+
from .rag_query_engine import WaldiezDocAgentQueryEngine
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"WaldiezDocAgent",
|
|
11
|
+
"WaldiezDocAgentData",
|
|
12
|
+
"WaldiezDocAgentQueryEngine",
|
|
13
|
+
]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
"""Document agent model."""
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from typing_extensions import Annotated
|
|
9
|
+
|
|
10
|
+
from waldiez.models.agents.doc_agent.rag_query_engine import (
|
|
11
|
+
WaldiezDocAgentQueryEngine,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from ...model import WaldiezModel
|
|
15
|
+
from ..agent import WaldiezAgent
|
|
16
|
+
from .doc_agent_data import WaldiezDocAgentData
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class WaldiezDocAgent(WaldiezAgent):
    """Document agent class.

    The agent for handling document-related tasks.
    Extends `WaldiezAgent`.

    Most accessors simply delegate to the agent's `data`
    (a `WaldiezDocAgentData` instance).
    """

    agent_type: Annotated[  # pyright: ignore
        Literal["doc_agent"],
        Field(
            default="doc_agent",
            title="Agent type",
            description="The agent type: 'doc_agent' for a document agent",
            alias="agentType",
        ),
    ] = "doc_agent"
    data: Annotated[  # pyright: ignore
        WaldiezDocAgentData,
        Field(
            title="Data",
            description="The document agent's data",
            default_factory=WaldiezDocAgentData,  # pyright: ignore
        ),
    ]

    @property
    def reset_collection(self) -> bool:
        """Get whether to reset the collection for the document agent.

        Returns
        -------
        bool
            Whether to reset the collection for the document agent.
        """
        return self.data.reset_collection

    def get_collection_name(self) -> str:
        """Get the collection name for the document agent.

        Returns
        -------
        str
            The collection name for the document agent.
        """
        return self.data.get_collection_name()

    def get_query_engine(self) -> WaldiezDocAgentQueryEngine:
        """Get the query engine for the document agent.

        Returns
        -------
        WaldiezDocAgentQueryEngine
            The query engine for the document agent.
        """
        return self.data.get_query_engine()

    def get_db_path(self) -> str:
        """Get the database path for the query engine.

        Returns
        -------
        str
            The database path for the query engine.
        """
        return self.data.get_db_path()

    def get_parsed_docs_path(self) -> str:
        """Get the parsed documents path for the document agent.

        Returns
        -------
        str
            The parsed documents path for the document agent.
        """
        return self.data.get_parsed_docs_path()

    def get_llm_requirements(
        self,
        ag2_version: str,
        all_models: list[WaldiezModel],
    ) -> set[str]:
        """Get the LLM requirements for the document agent.

        The base requirements are always part of the result. With no
        linked models, the OpenAI llama-index integration is added.
        Otherwise the requirements of the first linked model that is
        found in `all_models` are merged in.

        Parameters
        ----------
        ag2_version : str
            The version of AG2 to use for the requirements.
        all_models : list[WaldiezModel]
            All the models in the flow.

        Returns
        -------
        set[str]
            The set of LLM requirements for the document agent.
        """
        requirements = {
            "llama-index",
            "llama-index-core",
            f"ag2[rag]=={ag2_version}",
        }
        if not self.data.model_ids:
            requirements.add("llama-index-llms-openai")
            return requirements
        for model_id in self.data.model_ids:
            model = next((m for m in all_models if m.id == model_id), None)
            if model:
                # Merge instead of returning the model's result directly,
                # so the base requirements above are never dropped.
                # NOTE(review): only the first resolvable model is used,
                # as before - confirm this is intended for multi-model
                # doc agents.
                requirements.update(
                    model.get_llm_requirements(ag2_version=ag2_version)
                )
                return requirements
        # None of the linked model ids matched a model in the flow.
        return requirements
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
"""Document agent data model."""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, Union
|
|
7
|
+
|
|
8
|
+
from platformdirs import user_data_dir
|
|
9
|
+
from pydantic import Field, model_validator
|
|
10
|
+
from typing_extensions import Annotated, Self
|
|
11
|
+
|
|
12
|
+
from ..agent import WaldiezAgentData
|
|
13
|
+
from .rag_query_engine import WaldiezDocAgentQueryEngine
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class WaldiezDocAgentData(WaldiezAgentData):
    """Document agent data class.

    The data for a document agent.
    Extends `WaldiezAgentData`.
    Extra attributes:
    - `collection_name`: Optional string, the name of the collection.
    - `reset_collection`: Boolean, whether to reset the collection
      (defaults to False).
    - `parsed_docs_path`: Optional string or path, the path to the
      parsed documents.
    - `query_engine`: Optional `WaldiezDocAgentQueryEngine`,
      the query engine to use.
    """

    collection_name: Annotated[
        Optional[str],
        Field(
            title="Collection Name",
            description="The name of the collection for the document agent.",
            default=None,
            alias="collectionName",
        ),
    ] = None
    reset_collection: Annotated[
        bool,
        Field(
            title="Reset Collection",
            description=(
                "Whether to reset the collection for the document agent."
            ),
            default=False,
            alias="resetCollection",
        ),
    ] = False
    parsed_docs_path: Annotated[
        Optional[Union[str, Path]],
        Field(
            title="Parsed Documents Path",
            description=(
                "The path to the parsed documents for the document agent."
            ),
            default=None,
            alias="parsedDocsPath",
        ),
    ] = None
    query_engine: Annotated[
        Optional[WaldiezDocAgentQueryEngine],
        Field(
            title="Query Engine",
            description="The query engine to use for the document agent.",
            default=None,
            alias="queryEngine",
        ),
    ] = None

    @model_validator(mode="after")
    def validate_parsed_docs_path(self) -> Self:
        """Ensure the parsed documents path is set and is a directory.

        If not set, create a default path in the user data directory.

        Returns
        -------
        Self
            The instance with validated `parsed_docs_path`.
        """
        self._ensure_parsed_docs_path()
        return self

    def _ensure_parsed_docs_path(self) -> str:
        """Normalise `parsed_docs_path` and make sure the directory exists.

        An unset path falls back to a ``parsed_docs`` folder inside the
        waldiez user data directory. The result is stored back on the
        instance as an absolute path string.

        Returns
        -------
        str
            The absolute parsed documents path.
        """
        if self.parsed_docs_path:
            target = Path(self.parsed_docs_path)
        else:
            data_dir = user_data_dir(
                appname="waldiez",
                appauthor="waldiez",
            )
            target = Path(data_dir) / "parsed_docs"
        # resolve() always returns an absolute path (relative input is
        # anchored at the current working directory), so no separate
        # absoluteness check is needed.
        target = target.resolve()
        # exist_ok only tolerates an existing directory; an existing
        # regular file still raises FileExistsError.
        target.mkdir(parents=True, exist_ok=True)
        self.parsed_docs_path = str(target)
        return self.parsed_docs_path

    def get_query_engine(self) -> WaldiezDocAgentQueryEngine:
        """Get the query engine for the document agent.

        A default `WaldiezDocAgentQueryEngine` is created and stored on
        the instance when none is set.

        Returns
        -------
        WaldiezDocAgentQueryEngine
            The query engine for the document agent.
        """
        if not self.query_engine:
            self.query_engine = WaldiezDocAgentQueryEngine()
        return self.query_engine

    def get_db_path(self) -> str:
        """Get the database path for the query engine.

        Returns
        -------
        str
            The database path for the query engine.
        """
        return self.get_query_engine().get_db_path()

    def get_collection_name(self) -> str:
        """Get the collection name for the document agent.

        Returns
        -------
        str
            The configured collection name, or "docling-parsed-docs"
            when none is set.
        """
        return self.collection_name or "docling-parsed-docs"

    def get_parsed_docs_path(self) -> str:
        """Get the parsed documents path for the document agent.

        The directory is created if it does not exist yet.

        Returns
        -------
        str
            The absolute parsed documents path for the document agent.
        """
        return self._ensure_parsed_docs_path()
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
"""Document agent data model."""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, Union
|
|
7
|
+
|
|
8
|
+
from platformdirs import user_data_dir
|
|
9
|
+
from pydantic import Field, model_validator
|
|
10
|
+
from typing_extensions import Annotated, Literal, Self
|
|
11
|
+
|
|
12
|
+
from ...common import WaldiezBase
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WaldiezDocAgentQueryEngine(WaldiezBase):
    """RAG Query Engine class.

    The data for a RAG query engine.
    Attributes:
    - `type`: Optional query engine type, one of
      "VectorChromaQueryEngine", "VectorChromaCitationQueryEngine" or
      "InMemoryQueryEngine" (defaults to "VectorChromaQueryEngine").
    - `db_path`: Optional string or path, the database directory.
    - `enable_query_citations`: Boolean, whether to enable citations.
    - `citation_chunk_size`: Integer, the citation chunk size.
    """

    # Declared with Field(...) like the other attributes, so the
    # description ends up in the field metadata instead of being a bare
    # string that is ignored.
    type: Annotated[
        Optional[
            Literal[
                "VectorChromaQueryEngine",
                "VectorChromaCitationQueryEngine",
                "InMemoryQueryEngine",
            ]
        ],
        Field(
            title="Query Engine Type",
            description="RAG Query Engine type",
            default="VectorChromaQueryEngine",
        ),
    ] = "VectorChromaQueryEngine"
    db_path: Annotated[
        Optional[Union[str, Path]],
        Field(
            title="Database Path",
            description="The path to the database for the query engine.",
            default=None,
            alias="dbPath",
        ),
    ] = None
    enable_query_citations: Annotated[
        bool,
        Field(
            title="Enable Query Citations",
            description=(
                "Whether to enable query citations for the query engine."
            ),
            default=False,
            alias="enableQueryCitations",
        ),
    ] = False
    citation_chunk_size: Annotated[
        int,
        Field(
            title="Citation Chunk Size",
            description="The size of the citation chunks for the query engine.",
            default=512,
            alias="citationChunkSize",
        ),
    ] = 512

    @model_validator(mode="after")
    def validate_db_path(self) -> Self:
        """Validate the db_path field.

        Fall back to the default engine type when none is given and
        ensure the db_path is set and is a directory.

        Returns
        -------
        Self
            The instance of WaldiezDocAgentQueryEngine with validated db_path.
        """
        if not self.type:
            # An explicit None (or empty) type falls back to the default.
            self.type = "VectorChromaQueryEngine"
        self.db_path = ensure_db_path(self.db_path)
        return self

    def get_db_path(self) -> str:
        """Get the database path for the query engine.

        Returns
        -------
        str
            The database path for the query engine.
        """
        # db_path is normally set by the validator; the fallback covers
        # the case where it has been cleared afterwards.
        db_path = self.db_path or ensure_db_path(self.db_path)
        return str(db_path)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def ensure_db_path(db_path: str | Path | None) -> str:
|
|
90
|
+
"""Get the database path for the query engine.
|
|
91
|
+
|
|
92
|
+
Ensure the database path is set and is a directory.
|
|
93
|
+
|
|
94
|
+
Parameters
|
|
95
|
+
----------
|
|
96
|
+
db_path : Optional[str]
|
|
97
|
+
The database path to validate.
|
|
98
|
+
|
|
99
|
+
Returns
|
|
100
|
+
-------
|
|
101
|
+
str
|
|
102
|
+
The database path for the query engine.
|
|
103
|
+
|
|
104
|
+
Raises
|
|
105
|
+
------
|
|
106
|
+
ValueError
|
|
107
|
+
If the database path is not set or is not a directory.
|
|
108
|
+
"""
|
|
109
|
+
if not db_path:
|
|
110
|
+
data_dir = user_data_dir(
|
|
111
|
+
appname="waldiez",
|
|
112
|
+
appauthor="waldiez",
|
|
113
|
+
)
|
|
114
|
+
data_dir_path = Path(data_dir) / "rag"
|
|
115
|
+
data_dir_path.mkdir(parents=True, exist_ok=True)
|
|
116
|
+
db_path_to_resolve = data_dir_path / "chroma"
|
|
117
|
+
db_path_to_resolve.mkdir(parents=True, exist_ok=True)
|
|
118
|
+
return str(db_path_to_resolve)
|
|
119
|
+
|
|
120
|
+
resolved = Path(db_path).resolve()
|
|
121
|
+
if not resolved.is_absolute():
|
|
122
|
+
resolved = (Path.cwd() / db_path).resolve()
|
|
123
|
+
if not resolved.exists():
|
|
124
|
+
resolved.mkdir(parents=True, exist_ok=True)
|
|
125
|
+
if not resolved.is_dir():
|
|
126
|
+
raise ValueError(f"The path {resolved} is not a directory.")
|
|
127
|
+
return str(resolved)
|