waldiez 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez might be problematic. Click here for more details.
- waldiez/__init__.py +2 -0
- waldiez/__main__.py +2 -0
- waldiez/_version.py +3 -1
- waldiez/cli.py +13 -3
- waldiez/cli_extras.py +4 -3
- waldiez/conflict_checker.py +4 -3
- waldiez/exporter.py +28 -105
- waldiez/exporting/__init__.py +8 -9
- waldiez/exporting/agent/__init__.py +7 -0
- waldiez/exporting/agent/agent_exporter.py +279 -0
- waldiez/exporting/agent/utils/__init__.py +23 -0
- waldiez/exporting/agent/utils/agent_class_name.py +34 -0
- waldiez/exporting/agent/utils/agent_imports.py +50 -0
- waldiez/exporting/{agents → agent/utils}/code_execution.py +9 -11
- waldiez/exporting/{agents → agent/utils}/group_manager.py +47 -35
- waldiez/exporting/{agents → agent/utils}/rag_user/__init__.py +2 -0
- waldiez/exporting/{agents → agent/utils}/rag_user/chroma_utils.py +22 -17
- waldiez/exporting/{agents → agent/utils}/rag_user/mongo_utils.py +14 -10
- waldiez/exporting/{agents → agent/utils}/rag_user/pgvector_utils.py +12 -8
- waldiez/exporting/{agents → agent/utils}/rag_user/qdrant_utils.py +11 -8
- waldiez/exporting/{agents → agent/utils}/rag_user/rag_user.py +78 -55
- waldiez/exporting/{agents → agent/utils}/rag_user/vector_db.py +10 -8
- waldiez/exporting/agent/utils/swarm_agent.py +463 -0
- waldiez/exporting/{agents → agent/utils}/teachability.py +10 -6
- waldiez/exporting/{agents → agent/utils}/termination_message.py +7 -8
- waldiez/exporting/base/__init__.py +25 -0
- waldiez/exporting/base/agent_position.py +75 -0
- waldiez/exporting/base/base_exporter.py +118 -0
- waldiez/exporting/base/export_position.py +48 -0
- waldiez/exporting/base/import_position.py +23 -0
- waldiez/exporting/base/mixin.py +134 -0
- waldiez/exporting/base/utils/__init__.py +18 -0
- waldiez/exporting/{utils → base/utils}/comments.py +12 -55
- waldiez/exporting/{utils → base/utils}/naming.py +14 -4
- waldiez/exporting/base/utils/path_check.py +68 -0
- waldiez/exporting/{utils/object_string.py → base/utils/to_string.py} +21 -20
- waldiez/exporting/chats/__init__.py +5 -12
- waldiez/exporting/chats/chats_exporter.py +240 -0
- waldiez/exporting/chats/utils/__init__.py +15 -0
- waldiez/exporting/chats/utils/common.py +81 -0
- waldiez/exporting/chats/{nested.py → utils/nested.py} +125 -86
- waldiez/exporting/chats/utils/sequential.py +244 -0
- waldiez/exporting/chats/utils/single_chat.py +313 -0
- waldiez/exporting/chats/utils/swarm.py +207 -0
- waldiez/exporting/flow/__init__.py +5 -3
- waldiez/exporting/flow/flow_exporter.py +503 -0
- waldiez/exporting/flow/utils/__init__.py +47 -0
- waldiez/exporting/flow/utils/agent_utils.py +204 -0
- waldiez/exporting/flow/utils/chat_utils.py +71 -0
- waldiez/exporting/flow/utils/def_main.py +62 -0
- waldiez/exporting/flow/utils/flow_content.py +112 -0
- waldiez/exporting/flow/utils/flow_names.py +115 -0
- waldiez/exporting/flow/utils/importing_utils.py +182 -0
- waldiez/exporting/{utils → flow/utils}/logging_utils.py +34 -31
- waldiez/exporting/models/__init__.py +7 -242
- waldiez/exporting/models/models_exporter.py +192 -0
- waldiez/exporting/models/utils.py +166 -0
- waldiez/exporting/skills/__init__.py +7 -161
- waldiez/exporting/skills/skills_exporter.py +169 -0
- waldiez/exporting/skills/utils.py +281 -0
- waldiez/models/__init__.py +25 -7
- waldiez/models/agents/__init__.py +70 -0
- waldiez/models/agents/agent/__init__.py +11 -1
- waldiez/models/agents/agent/agent.py +9 -4
- waldiez/models/agents/agent/agent_data.py +3 -1
- waldiez/models/agents/agent/code_execution.py +2 -0
- waldiez/models/agents/agent/linked_skill.py +2 -0
- waldiez/models/agents/agent/nested_chat.py +2 -0
- waldiez/models/agents/agent/teachability.py +2 -0
- waldiez/models/agents/agent/termination_message.py +49 -13
- waldiez/models/agents/agents.py +15 -3
- waldiez/models/agents/assistant/__init__.py +2 -0
- waldiez/models/agents/assistant/assistant.py +2 -0
- waldiez/models/agents/assistant/assistant_data.py +2 -0
- waldiez/models/agents/group_manager/__init__.py +9 -1
- waldiez/models/agents/group_manager/group_manager.py +2 -0
- waldiez/models/agents/group_manager/group_manager_data.py +2 -0
- waldiez/models/agents/group_manager/speakers.py +49 -13
- waldiez/models/agents/rag_user/__init__.py +21 -4
- waldiez/models/agents/rag_user/rag_user.py +3 -1
- waldiez/models/agents/rag_user/rag_user_data.py +2 -0
- waldiez/models/agents/rag_user/retrieve_config.py +268 -17
- waldiez/models/agents/rag_user/vector_db_config.py +5 -3
- waldiez/models/agents/swarm_agent/__init__.py +49 -0
- waldiez/models/agents/swarm_agent/after_work.py +178 -0
- waldiez/models/agents/swarm_agent/on_condition.py +103 -0
- waldiez/models/agents/swarm_agent/on_condition_available.py +140 -0
- waldiez/models/agents/swarm_agent/on_condition_target.py +40 -0
- waldiez/models/agents/swarm_agent/swarm_agent.py +107 -0
- waldiez/models/agents/swarm_agent/swarm_agent_data.py +125 -0
- waldiez/models/agents/swarm_agent/update_system_message.py +144 -0
- waldiez/models/agents/user_proxy/__init__.py +2 -0
- waldiez/models/agents/user_proxy/user_proxy.py +2 -0
- waldiez/models/agents/user_proxy/user_proxy_data.py +2 -0
- waldiez/models/chat/__init__.py +21 -3
- waldiez/models/chat/chat.py +241 -7
- waldiez/models/chat/chat_data.py +192 -48
- waldiez/models/chat/chat_message.py +153 -144
- waldiez/models/chat/chat_nested.py +33 -53
- waldiez/models/chat/chat_summary.py +2 -0
- waldiez/models/common/__init__.py +6 -6
- waldiez/models/common/base.py +4 -1
- waldiez/models/common/method_utils.py +163 -83
- waldiez/models/flow/__init__.py +2 -0
- waldiez/models/flow/flow.py +176 -40
- waldiez/models/flow/flow_data.py +63 -2
- waldiez/models/flow/utils.py +172 -0
- waldiez/models/model/__init__.py +2 -0
- waldiez/models/model/model.py +30 -9
- waldiez/models/model/model_data.py +3 -1
- waldiez/models/skill/__init__.py +4 -1
- waldiez/models/skill/skill.py +30 -2
- waldiez/models/skill/skill_data.py +2 -0
- waldiez/models/waldiez.py +28 -4
- waldiez/runner.py +142 -228
- waldiez/running/__init__.py +33 -0
- waldiez/running/environment.py +83 -0
- waldiez/running/gen_seq_diagram.py +185 -0
- waldiez/running/running.py +300 -0
- {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/METADATA +35 -28
- waldiez-0.3.1.dist-info/RECORD +125 -0
- waldiez-0.3.1.dist-info/licenses/LICENSE +201 -0
- waldiez/exporting/agents/__init__.py +0 -5
- waldiez/exporting/agents/agent.py +0 -236
- waldiez/exporting/agents/agent_skills.py +0 -67
- waldiez/exporting/agents/llm_config.py +0 -53
- waldiez/exporting/chats/chats.py +0 -46
- waldiez/exporting/chats/helpers.py +0 -420
- waldiez/exporting/flow/def_main.py +0 -32
- waldiez/exporting/flow/flow.py +0 -189
- waldiez/exporting/utils/__init__.py +0 -36
- waldiez/exporting/utils/importing.py +0 -265
- waldiez/exporting/utils/method_utils.py +0 -35
- waldiez/exporting/utils/path_check.py +0 -51
- waldiez-0.2.2.dist-info/RECORD +0 -92
- waldiez-0.2.2.dist-info/licenses/LICENSE +0 -21
- {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/WHEEL +0 -0
- {waldiez-0.2.2.dist-info → waldiez-0.3.1.dist-info}/entry_points.txt +0 -0
waldiez/models/chat/chat_data.py
CHANGED
|
@@ -1,13 +1,22 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Chat data model."""
|
|
2
4
|
|
|
3
|
-
from typing import Any, Dict, Optional, Union
|
|
5
|
+
from typing import Any, Dict, List, Optional, Union
|
|
4
6
|
|
|
5
|
-
from pydantic import
|
|
6
|
-
from pydantic.alias_generators import to_camel
|
|
7
|
+
from pydantic import Field, field_validator, model_validator
|
|
7
8
|
from typing_extensions import Annotated, Self
|
|
8
9
|
|
|
9
|
-
from ..
|
|
10
|
-
|
|
10
|
+
from ..agents.swarm_agent import (
|
|
11
|
+
WaldiezSwarmAfterWork,
|
|
12
|
+
WaldiezSwarmOnConditionAvailable,
|
|
13
|
+
)
|
|
14
|
+
from ..common import WaldiezBase, check_function
|
|
15
|
+
from .chat_message import (
|
|
16
|
+
CALLABLE_MESSAGE,
|
|
17
|
+
CALLABLE_MESSAGE_ARGS,
|
|
18
|
+
WaldiezChatMessage,
|
|
19
|
+
)
|
|
11
20
|
from .chat_nested import WaldiezChatNested
|
|
12
21
|
from .chat_summary import WaldiezChatSummary
|
|
13
22
|
|
|
@@ -47,6 +56,10 @@ class WaldiezChatData(WaldiezBase):
|
|
|
47
56
|
The real source of the chat (overrides the source).
|
|
48
57
|
real_target : Optional[str]
|
|
49
58
|
The real target of the chat (overrides the target).
|
|
59
|
+
max_rounds : int
|
|
60
|
+
Maximum number of conversation rounds (swarm).
|
|
61
|
+
after_work : Optional[WaldiezSwarmAfterWork]
|
|
62
|
+
The work to do after the chat (swarm).
|
|
50
63
|
|
|
51
64
|
Functions
|
|
52
65
|
---------
|
|
@@ -60,13 +73,6 @@ class WaldiezChatData(WaldiezBase):
|
|
|
60
73
|
Get the chat arguments to use in autogen.
|
|
61
74
|
"""
|
|
62
75
|
|
|
63
|
-
model_config = ConfigDict(
|
|
64
|
-
extra="forbid",
|
|
65
|
-
alias_generator=to_camel,
|
|
66
|
-
populate_by_name=True,
|
|
67
|
-
frozen=False,
|
|
68
|
-
)
|
|
69
|
-
|
|
70
76
|
name: Annotated[
|
|
71
77
|
str, Field(..., title="Name", description="The name of the chat.")
|
|
72
78
|
]
|
|
@@ -153,6 +159,14 @@ class WaldiezChatData(WaldiezBase):
|
|
|
153
159
|
description="The maximum number of turns for the chat.",
|
|
154
160
|
),
|
|
155
161
|
]
|
|
162
|
+
prerequisites: Annotated[
|
|
163
|
+
List[str],
|
|
164
|
+
Field(
|
|
165
|
+
title="Prerequisites",
|
|
166
|
+
description="The prerequisites (chat ids) for the chat (if async).",
|
|
167
|
+
default_factory=list,
|
|
168
|
+
),
|
|
169
|
+
]
|
|
156
170
|
silent: Annotated[
|
|
157
171
|
Optional[bool],
|
|
158
172
|
Field(
|
|
@@ -179,14 +193,90 @@ class WaldiezChatData(WaldiezBase):
|
|
|
179
193
|
description="The real target of the chat (overrides the target).",
|
|
180
194
|
),
|
|
181
195
|
]
|
|
196
|
+
max_rounds: Annotated[
|
|
197
|
+
int,
|
|
198
|
+
Field(
|
|
199
|
+
20,
|
|
200
|
+
title="Max Rounds",
|
|
201
|
+
description="Maximum number of conversation rounds.(swarm)",
|
|
202
|
+
),
|
|
203
|
+
] = 20
|
|
204
|
+
after_work: Annotated[
|
|
205
|
+
Optional[WaldiezSwarmAfterWork],
|
|
206
|
+
Field(
|
|
207
|
+
None,
|
|
208
|
+
alias="afterWork",
|
|
209
|
+
title="After Work",
|
|
210
|
+
description="The work to do after the chat (swarm).",
|
|
211
|
+
),
|
|
212
|
+
] = None
|
|
213
|
+
context_variables: Annotated[
|
|
214
|
+
Optional[Dict[str, Any]],
|
|
215
|
+
Field(
|
|
216
|
+
None,
|
|
217
|
+
alias="contextVariables",
|
|
218
|
+
title="Context Variables",
|
|
219
|
+
description="The context variables to use in the chat.",
|
|
220
|
+
),
|
|
221
|
+
] = None
|
|
222
|
+
available: Annotated[
|
|
223
|
+
WaldiezSwarmOnConditionAvailable,
|
|
224
|
+
Field(
|
|
225
|
+
default_factory=WaldiezSwarmOnConditionAvailable,
|
|
226
|
+
title="Available",
|
|
227
|
+
description="The available condition for the chat.",
|
|
228
|
+
),
|
|
229
|
+
]
|
|
182
230
|
|
|
183
231
|
_message_content: Optional[str] = None
|
|
232
|
+
_chat_id: int = 0
|
|
233
|
+
_prerequisites: List[int] = []
|
|
184
234
|
|
|
185
235
|
@property
|
|
186
236
|
def message_content(self) -> Optional[str]:
|
|
187
237
|
"""Get the message content."""
|
|
188
238
|
return self._message_content
|
|
189
239
|
|
|
240
|
+
def get_chat_id(self) -> int:
|
|
241
|
+
"""Get the chat id.
|
|
242
|
+
|
|
243
|
+
Returns
|
|
244
|
+
-------
|
|
245
|
+
int
|
|
246
|
+
The chat id.
|
|
247
|
+
"""
|
|
248
|
+
return self._chat_id
|
|
249
|
+
|
|
250
|
+
def set_chat_id(self, value: int) -> None:
|
|
251
|
+
"""Set the chat id.
|
|
252
|
+
|
|
253
|
+
Parameters
|
|
254
|
+
----------
|
|
255
|
+
value : int
|
|
256
|
+
The chat id.
|
|
257
|
+
"""
|
|
258
|
+
self._chat_id = value
|
|
259
|
+
|
|
260
|
+
def get_prerequisites(self) -> List[int]:
|
|
261
|
+
"""Get the chat prerequisites.
|
|
262
|
+
|
|
263
|
+
Returns
|
|
264
|
+
-------
|
|
265
|
+
List[int]
|
|
266
|
+
The chat prerequisites (if async).
|
|
267
|
+
"""
|
|
268
|
+
return self._prerequisites
|
|
269
|
+
|
|
270
|
+
def set_prerequisites(self, value: List[int]) -> None:
|
|
271
|
+
"""Set the chat prerequisites.
|
|
272
|
+
|
|
273
|
+
Parameters
|
|
274
|
+
----------
|
|
275
|
+
value : List[int]
|
|
276
|
+
The chat prerequisites to set.
|
|
277
|
+
"""
|
|
278
|
+
self._prerequisites = value
|
|
279
|
+
|
|
190
280
|
@model_validator(mode="after")
|
|
191
281
|
def validate_chat_data(self) -> Self:
|
|
192
282
|
"""Validate the chat data.
|
|
@@ -206,16 +296,17 @@ class WaldiezChatData(WaldiezBase):
|
|
|
206
296
|
self._message_content = None
|
|
207
297
|
elif self.message.type == "string":
|
|
208
298
|
self._message_content = self.message.content
|
|
299
|
+
elif self.message.type == "method":
|
|
300
|
+
valid, error_or_body = check_function(
|
|
301
|
+
self.message.content or "",
|
|
302
|
+
CALLABLE_MESSAGE,
|
|
303
|
+
CALLABLE_MESSAGE_ARGS,
|
|
304
|
+
)
|
|
305
|
+
if not valid:
|
|
306
|
+
raise ValueError(error_or_body)
|
|
307
|
+
self._message_content = error_or_body
|
|
209
308
|
else:
|
|
210
|
-
self._message_content =
|
|
211
|
-
value={
|
|
212
|
-
"type": self.message.type,
|
|
213
|
-
"content": self.message.content,
|
|
214
|
-
"use_carryover": self.message.use_carryover,
|
|
215
|
-
},
|
|
216
|
-
function_name="callable_message",
|
|
217
|
-
skip_definition=True,
|
|
218
|
-
).content
|
|
309
|
+
self._message_content = self.message.content
|
|
219
310
|
return self
|
|
220
311
|
|
|
221
312
|
@field_validator("message", mode="before")
|
|
@@ -242,28 +333,47 @@ class WaldiezChatData(WaldiezBase):
|
|
|
242
333
|
return WaldiezChatMessage(
|
|
243
334
|
type="none", use_carryover=False, content=None, context={}
|
|
244
335
|
)
|
|
245
|
-
if isinstance(value, str):
|
|
336
|
+
if isinstance(value, (str, int, float, bool)):
|
|
246
337
|
return WaldiezChatMessage(
|
|
247
|
-
type="string",
|
|
338
|
+
type="string",
|
|
339
|
+
use_carryover=False,
|
|
340
|
+
content=str(value),
|
|
341
|
+
context={},
|
|
248
342
|
)
|
|
249
343
|
if isinstance(value, dict):
|
|
250
|
-
return
|
|
251
|
-
value, function_name="callable_message"
|
|
252
|
-
)
|
|
344
|
+
return WaldiezChatMessage.model_validate(value)
|
|
253
345
|
if isinstance(value, WaldiezChatMessage):
|
|
254
|
-
return
|
|
255
|
-
value={
|
|
256
|
-
"type": value.type,
|
|
257
|
-
"use_carryover": value.use_carryover,
|
|
258
|
-
"content": value.content,
|
|
259
|
-
"context": value.context,
|
|
260
|
-
},
|
|
261
|
-
function_name="callable_message",
|
|
262
|
-
)
|
|
346
|
+
return value
|
|
263
347
|
return WaldiezChatMessage(
|
|
264
348
|
type="none", use_carryover=False, content=None, context={}
|
|
265
349
|
)
|
|
266
350
|
|
|
351
|
+
@field_validator("context_variables", mode="after")
|
|
352
|
+
@classmethod
|
|
353
|
+
def validate_context_variables(cls, value: Any) -> Optional[Dict[str, Any]]:
|
|
354
|
+
"""Validate the context variables.
|
|
355
|
+
|
|
356
|
+
Parameters
|
|
357
|
+
----------
|
|
358
|
+
value : Any
|
|
359
|
+
The context variables value.
|
|
360
|
+
|
|
361
|
+
Returns
|
|
362
|
+
-------
|
|
363
|
+
Optional[Dict[str, Any]]
|
|
364
|
+
The validated context variables value.
|
|
365
|
+
|
|
366
|
+
Raises
|
|
367
|
+
------
|
|
368
|
+
ValueError
|
|
369
|
+
If the validation fails.
|
|
370
|
+
"""
|
|
371
|
+
if value is None:
|
|
372
|
+
return None
|
|
373
|
+
if not isinstance(value, dict):
|
|
374
|
+
raise ValueError("Context variables must be a dictionary.")
|
|
375
|
+
return get_context_dict(value)
|
|
376
|
+
|
|
267
377
|
@property
|
|
268
378
|
def summary_args(self) -> Optional[Dict[str, Any]]:
|
|
269
379
|
"""Get the summary args."""
|
|
@@ -289,23 +399,19 @@ class WaldiezChatData(WaldiezBase):
|
|
|
289
399
|
"""
|
|
290
400
|
extra_args: Dict[str, Any] = {}
|
|
291
401
|
if isinstance(self.message, WaldiezChatMessage):
|
|
292
|
-
|
|
293
|
-
if str(value).lower() in ("none", "null"):
|
|
294
|
-
extra_args[key] = None
|
|
295
|
-
elif str(value).isdigit():
|
|
296
|
-
extra_args[key] = int(value)
|
|
297
|
-
elif str(value).replace(".", "").isdigit():
|
|
298
|
-
try:
|
|
299
|
-
extra_args[key] = float(value)
|
|
300
|
-
except ValueError: # pragma: no cover
|
|
301
|
-
extra_args[key] = value
|
|
302
|
-
else:
|
|
303
|
-
extra_args[key] = value
|
|
402
|
+
extra_args.update(get_context_dict(self.message.context))
|
|
304
403
|
return extra_args
|
|
305
404
|
|
|
306
|
-
def get_chat_args(self) -> Dict[str, Any]:
|
|
405
|
+
def get_chat_args(self, for_queue: bool) -> Dict[str, Any]:
|
|
307
406
|
"""Get the chat arguments to use in autogen.
|
|
308
407
|
|
|
408
|
+
Without the 'message' key.
|
|
409
|
+
|
|
410
|
+
Parameters
|
|
411
|
+
----------
|
|
412
|
+
for_queue : bool
|
|
413
|
+
Whether to get the arguments for a chat queue.
|
|
414
|
+
|
|
309
415
|
Returns
|
|
310
416
|
-------
|
|
311
417
|
Dict[str, Any]
|
|
@@ -323,4 +429,42 @@ class WaldiezChatData(WaldiezBase):
|
|
|
323
429
|
if isinstance(self.silent, bool):
|
|
324
430
|
args["silent"] = self.silent
|
|
325
431
|
args.update(self._get_context_args())
|
|
432
|
+
if for_queue:
|
|
433
|
+
args["chat_id"] = self._chat_id
|
|
434
|
+
if self._prerequisites:
|
|
435
|
+
args["prerequisites"] = self._prerequisites
|
|
326
436
|
return args
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def get_context_dict(context: Dict[str, Any]) -> Dict[str, Any]:
|
|
440
|
+
"""Get the context dictionary.
|
|
441
|
+
|
|
442
|
+
Try to determine the type of the context variables.
|
|
443
|
+
|
|
444
|
+
Parameters
|
|
445
|
+
----------
|
|
446
|
+
context : Dict[str, Any]
|
|
447
|
+
The context variables.
|
|
448
|
+
|
|
449
|
+
Returns
|
|
450
|
+
-------
|
|
451
|
+
Dict[str, Any]
|
|
452
|
+
The context variables with the detected types.
|
|
453
|
+
"""
|
|
454
|
+
new_dict: Dict[str, Any] = {}
|
|
455
|
+
for key, value in context.items():
|
|
456
|
+
value_lower = str(value).lower()
|
|
457
|
+
if value_lower in ("none", "null"):
|
|
458
|
+
new_dict[key] = None
|
|
459
|
+
elif value_lower in ("true", "false"):
|
|
460
|
+
new_dict[key] = value.lower() == "true"
|
|
461
|
+
elif str(value).isdigit():
|
|
462
|
+
new_dict[key] = int(value)
|
|
463
|
+
elif str(value).replace(".", "").isdigit():
|
|
464
|
+
try:
|
|
465
|
+
new_dict[key] = float(value)
|
|
466
|
+
except ValueError: # pragma: no cover
|
|
467
|
+
new_dict[key] = value
|
|
468
|
+
else:
|
|
469
|
+
new_dict[key] = value
|
|
470
|
+
return new_dict
|
|
@@ -1,16 +1,29 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
1
3
|
"""Waldiez Message Model."""
|
|
2
4
|
|
|
3
|
-
from typing import Any, Dict,
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
4
6
|
|
|
5
|
-
from pydantic import Field
|
|
6
|
-
from typing_extensions import Annotated, Literal
|
|
7
|
+
from pydantic import Field, model_validator
|
|
8
|
+
from typing_extensions import Annotated, Literal, Self
|
|
7
9
|
|
|
8
|
-
from ..common import WaldiezBase,
|
|
10
|
+
from ..common import WaldiezBase, check_function
|
|
9
11
|
|
|
10
12
|
WaldiezChatMessageType = Literal[
|
|
11
13
|
"string", "method", "rag_message_generator", "none"
|
|
12
14
|
]
|
|
13
15
|
|
|
16
|
+
CALLABLE_MESSAGE = "callable_message"
|
|
17
|
+
CALLABLE_MESSAGE_ARGS = ["sender", "recipient", "context"]
|
|
18
|
+
CALLABLE_MESSAGE_TYPES = (
|
|
19
|
+
["ConversableAgent", "ConversableAgent", "Dict[str, Any]"],
|
|
20
|
+
"Union[Dict[str, Any], str]",
|
|
21
|
+
)
|
|
22
|
+
CALLABLE_MESSAGE_RAG_WITH_CARRYOVER_TYPES = (
|
|
23
|
+
["RetrieveUserProxyAgent", "ConversableAgent", "Dict[str, Any]"],
|
|
24
|
+
"Union[Dict[str, Any], str]",
|
|
25
|
+
)
|
|
26
|
+
|
|
14
27
|
|
|
15
28
|
class WaldiezChatMessage(WaldiezBase):
|
|
16
29
|
"""
|
|
@@ -84,148 +97,130 @@ class WaldiezChatMessage(WaldiezBase):
|
|
|
84
97
|
),
|
|
85
98
|
]
|
|
86
99
|
|
|
100
|
+
_content_body: Optional[str] = None
|
|
87
101
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
],
|
|
93
|
-
function_name: WaldiezMethodName,
|
|
94
|
-
skip_definition: bool = False,
|
|
95
|
-
) -> WaldiezChatMessage:
|
|
96
|
-
"""Validate a message dict.
|
|
102
|
+
@property
|
|
103
|
+
def content_body(self) -> Optional[str]:
|
|
104
|
+
"""Get the content body."""
|
|
105
|
+
return self._content_body
|
|
97
106
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
107
|
+
@model_validator(mode="after")
|
|
108
|
+
def validate_context_vars(self) -> Self:
|
|
109
|
+
"""Try to detect bools nulls and numbers from the context values.
|
|
101
110
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
111
|
+
Returns
|
|
112
|
+
-------
|
|
113
|
+
WaldiezChatMessage
|
|
114
|
+
The validated instance.
|
|
115
|
+
"""
|
|
116
|
+
for key, value in self.context.items():
|
|
117
|
+
if isinstance(value, str):
|
|
118
|
+
if value.lower() == "true":
|
|
119
|
+
self.context[key] = True
|
|
120
|
+
elif value.lower() == "false":
|
|
121
|
+
self.context[key] = False
|
|
122
|
+
elif value.lower() in ["null", "none"]:
|
|
123
|
+
self.context[key] = None
|
|
124
|
+
else:
|
|
125
|
+
self.context[key] = self._number_or_string(value)
|
|
126
|
+
return self
|
|
110
127
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
128
|
+
@staticmethod
|
|
129
|
+
def _number_or_string(value: Any) -> Any:
|
|
130
|
+
try:
|
|
131
|
+
int_value = int(value)
|
|
132
|
+
if str(int_value) == value:
|
|
133
|
+
return int_value
|
|
134
|
+
except ValueError:
|
|
135
|
+
try:
|
|
136
|
+
float_value = float(value)
|
|
137
|
+
if str(float_value) == value:
|
|
138
|
+
return float_value
|
|
139
|
+
except ValueError:
|
|
140
|
+
pass
|
|
141
|
+
return value
|
|
115
142
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
content=content
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
143
|
+
@model_validator(mode="after")
|
|
144
|
+
def validate_content(self) -> Self:
|
|
145
|
+
"""Validate the content (if not a method).
|
|
146
|
+
|
|
147
|
+
Returns
|
|
148
|
+
-------
|
|
149
|
+
WaldiezChatMessage
|
|
150
|
+
The validated instance.
|
|
151
|
+
|
|
152
|
+
Raises
|
|
153
|
+
------
|
|
154
|
+
ValueError
|
|
155
|
+
If the content is invalid.
|
|
156
|
+
"""
|
|
157
|
+
content: Optional[str] = None
|
|
158
|
+
if self.type == "none":
|
|
159
|
+
content = "None"
|
|
160
|
+
if self.type == "method":
|
|
161
|
+
if not self.content:
|
|
162
|
+
raise ValueError(
|
|
163
|
+
"The message content is required for the method type"
|
|
164
|
+
)
|
|
165
|
+
content = self.content
|
|
166
|
+
if self.type == "string":
|
|
167
|
+
if not self.content:
|
|
168
|
+
self.content = ""
|
|
169
|
+
if self.use_carryover:
|
|
170
|
+
content = get_last_carryover_method_content(
|
|
171
|
+
text_content=self.content,
|
|
172
|
+
)
|
|
173
|
+
content = self.content
|
|
174
|
+
if self.type == "rag_message_generator":
|
|
175
|
+
if self.use_carryover:
|
|
176
|
+
content = get_last_carryover_method_content(
|
|
177
|
+
text_content=self.content or "",
|
|
178
|
+
)
|
|
179
|
+
else:
|
|
180
|
+
content = RAG_METHOD_WITH_CARRYOVER_BODY
|
|
181
|
+
self.content = RAG_METHOD_WITH_CARRYOVER
|
|
182
|
+
self._content_body = content
|
|
183
|
+
return self
|
|
184
|
+
|
|
185
|
+
def validate_method(
|
|
186
|
+
self,
|
|
187
|
+
function_name: str,
|
|
188
|
+
function_args: List[str],
|
|
189
|
+
) -> str:
|
|
190
|
+
"""Validate a method.
|
|
191
|
+
|
|
192
|
+
Parameters
|
|
193
|
+
----------
|
|
194
|
+
function_name : str
|
|
195
|
+
The method name.
|
|
196
|
+
function_args : List[str]
|
|
197
|
+
The expected method arguments.
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
str
|
|
202
|
+
The validated method body.
|
|
203
|
+
|
|
204
|
+
Raises
|
|
205
|
+
------
|
|
206
|
+
ValueError
|
|
207
|
+
If the validation fails.
|
|
208
|
+
"""
|
|
209
|
+
if not self.content:
|
|
150
210
|
raise ValueError(
|
|
151
211
|
"The message content is required for the method type"
|
|
152
212
|
)
|
|
153
|
-
|
|
154
|
-
content,
|
|
213
|
+
is_valid, error_or_body = check_function(
|
|
214
|
+
code_string=self.content,
|
|
215
|
+
function_name=function_name,
|
|
216
|
+
function_args=function_args,
|
|
155
217
|
)
|
|
156
|
-
if not
|
|
157
|
-
raise ValueError(
|
|
158
|
-
|
|
159
|
-
return WaldiezChatMessage(
|
|
160
|
-
type="method",
|
|
161
|
-
use_carryover=use_carryover,
|
|
162
|
-
content=message_content,
|
|
163
|
-
context=context,
|
|
164
|
-
)
|
|
165
|
-
if message_type == "rag_message_generator":
|
|
166
|
-
if use_carryover:
|
|
167
|
-
return WaldiezChatMessage(
|
|
168
|
-
type="method",
|
|
169
|
-
use_carryover=True,
|
|
170
|
-
content=RAG_METHOD_WITH_CARRYOVER,
|
|
171
|
-
context=context,
|
|
172
|
-
)
|
|
173
|
-
return WaldiezChatMessage(
|
|
174
|
-
type="rag_message_generator",
|
|
175
|
-
use_carryover=use_carryover,
|
|
176
|
-
content=None,
|
|
177
|
-
context=context,
|
|
178
|
-
)
|
|
179
|
-
raise ValueError("Invalid message type") # pragma: no cover
|
|
180
|
-
|
|
218
|
+
if not is_valid:
|
|
219
|
+
raise ValueError(error_or_body)
|
|
220
|
+
return error_or_body
|
|
181
221
|
|
|
182
|
-
def _get_message_args_from_dict(
|
|
183
|
-
value: Dict[
|
|
184
|
-
Literal["type", "use_carryover", "content", "context"],
|
|
185
|
-
Union[Optional[str], Optional[bool], Optional[Dict[str, Any]]],
|
|
186
|
-
],
|
|
187
|
-
) -> Tuple[str, bool, Optional[str], Dict[str, Any]]:
|
|
188
|
-
"""Get the message args from a dict.
|
|
189
222
|
|
|
190
|
-
|
|
191
|
-
----------
|
|
192
|
-
value : dict
|
|
193
|
-
The message dict.
|
|
194
|
-
|
|
195
|
-
Returns
|
|
196
|
-
-------
|
|
197
|
-
tuple
|
|
198
|
-
The message type, content, and context.
|
|
199
|
-
|
|
200
|
-
Raises
|
|
201
|
-
------
|
|
202
|
-
ValueError
|
|
203
|
-
If the message type is invalid.
|
|
204
|
-
"""
|
|
205
|
-
message_type = value.get("type")
|
|
206
|
-
if not isinstance(message_type, str) or message_type not in (
|
|
207
|
-
"string",
|
|
208
|
-
"method",
|
|
209
|
-
"rag_message_generator",
|
|
210
|
-
"none",
|
|
211
|
-
):
|
|
212
|
-
raise ValueError("Invalid message type")
|
|
213
|
-
use_carryover = value.get("use_carryover", False)
|
|
214
|
-
if not isinstance(use_carryover, bool):
|
|
215
|
-
use_carryover = False
|
|
216
|
-
content = value.get("content", "")
|
|
217
|
-
if not isinstance(content, str):
|
|
218
|
-
content = ""
|
|
219
|
-
context: Dict[str, Any] = {}
|
|
220
|
-
context_value = value.get("context")
|
|
221
|
-
if isinstance(context_value, dict):
|
|
222
|
-
context = context_value
|
|
223
|
-
if not isinstance(context, dict): # pragma: no cover
|
|
224
|
-
context = {}
|
|
225
|
-
return message_type, use_carryover, content, context
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
def _get_last_carryover_method_content(text_content: str) -> str:
|
|
223
|
+
def get_last_carryover_method_content(text_content: str) -> str:
|
|
229
224
|
"""Get the last carryover method content.
|
|
230
225
|
|
|
231
226
|
Parameters
|
|
@@ -238,8 +233,6 @@ def _get_last_carryover_method_content(text_content: str) -> str:
|
|
|
238
233
|
The last carryover method content.
|
|
239
234
|
"""
|
|
240
235
|
method_content = '''
|
|
241
|
-
def callable_message(sender, recipient, context):
|
|
242
|
-
# type: (ConversableAgent, ConversableAgent, dict) -> Union[dict, str]
|
|
243
236
|
"""Get the message to send using the last carryover.
|
|
244
237
|
|
|
245
238
|
Parameters
|
|
@@ -248,17 +241,22 @@ def callable_message(sender, recipient, context):
|
|
|
248
241
|
The source agent.
|
|
249
242
|
recipient : ConversableAgent
|
|
250
243
|
The target agent.
|
|
251
|
-
context :
|
|
244
|
+
context : Dict[str, Any]
|
|
252
245
|
The context.
|
|
253
246
|
|
|
254
247
|
Returns
|
|
255
248
|
-------
|
|
256
|
-
Union[
|
|
249
|
+
Union[Dict[str, Any], str]
|
|
257
250
|
The message to send using the last carryover.
|
|
258
251
|
"""
|
|
259
252
|
carryover = context.get("carryover", "")
|
|
260
253
|
if isinstance(carryover, list):
|
|
261
254
|
carryover = carryover[-1]
|
|
255
|
+
if not isinstance(carryover, str):
|
|
256
|
+
if isinstance(carryover, list):
|
|
257
|
+
carryover = carryover[-1]
|
|
258
|
+
elif isinstance(carryover, dict):
|
|
259
|
+
carryover = carryover.get("content", "")
|
|
262
260
|
if not isinstance(carryover, str):
|
|
263
261
|
carryover = ""'''
|
|
264
262
|
if text_content:
|
|
@@ -273,9 +271,7 @@ def callable_message(sender, recipient, context):
|
|
|
273
271
|
return method_content
|
|
274
272
|
|
|
275
273
|
|
|
276
|
-
|
|
277
|
-
def callable_message(sender, recipient, context):
|
|
278
|
-
# type: (RetrieveUserProxyAgent, ConversableAgent, dict) -> Union[dict, str]
|
|
274
|
+
RAG_METHOD_WITH_CARRYOVER_BODY = '''
|
|
279
275
|
"""Get the message using the RAG message generator method.
|
|
280
276
|
|
|
281
277
|
Parameters
|
|
@@ -284,17 +280,22 @@ def callable_message(sender, recipient, context):
|
|
|
284
280
|
The source agent.
|
|
285
281
|
recipient : ConversableAgent
|
|
286
282
|
The target agent.
|
|
287
|
-
context :
|
|
283
|
+
context : Dict[str, Any]
|
|
288
284
|
The context.
|
|
289
285
|
|
|
290
286
|
Returns
|
|
291
287
|
-------
|
|
292
|
-
Union[
|
|
288
|
+
Union[Dict[str, Any], str]
|
|
293
289
|
The message to send using the last carryover.
|
|
294
290
|
"""
|
|
295
291
|
carryover = context.get("carryover", "")
|
|
296
292
|
if isinstance(carryover, list):
|
|
297
293
|
carryover = carryover[-1]
|
|
294
|
+
if not isinstance(carryover, str):
|
|
295
|
+
if isinstance(carryover, list):
|
|
296
|
+
carryover = carryover[-1]
|
|
297
|
+
elif isinstance(carryover, dict):
|
|
298
|
+
carryover = carryover.get("content", "")
|
|
298
299
|
if not isinstance(carryover, str):
|
|
299
300
|
carryover = ""
|
|
300
301
|
message = sender.message_generator(sender, recipient, context)
|
|
@@ -302,3 +303,11 @@ def callable_message(sender, recipient, context):
|
|
|
302
303
|
message += carryover
|
|
303
304
|
return message
|
|
304
305
|
'''
|
|
306
|
+
RAG_METHOD_WITH_CARRYOVER = (
|
|
307
|
+
"def callable_message(\n"
|
|
308
|
+
" sender: RetrieveUserProxyAgent,\n"
|
|
309
|
+
" recipient: ConversableAgent,\n"
|
|
310
|
+
" context: Dict[str, Any],\n"
|
|
311
|
+
") -> Union[Dict[str, Any], str]:"
|
|
312
|
+
f"{RAG_METHOD_WITH_CARRYOVER_BODY}"
|
|
313
|
+
)
|