waldiez 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez may be problematic; see the registry's advisory for this release for more details.
- waldiez/_version.py +1 -1
- waldiez/cli.py +1 -0
- waldiez/exporting/agent/exporter.py +6 -6
- waldiez/exporting/agent/extras/group_manager_agent_extas.py +6 -1
- waldiez/exporting/agent/extras/handoffs/after_work.py +1 -0
- waldiez/exporting/agent/extras/handoffs/available.py +1 -0
- waldiez/exporting/agent/extras/handoffs/handoff.py +1 -0
- waldiez/exporting/agent/extras/handoffs/target.py +1 -0
- waldiez/exporting/agent/termination.py +1 -0
- waldiez/exporting/core/constants.py +3 -1
- waldiez/exporting/core/extras/serializer.py +12 -10
- waldiez/exporting/core/types.py +1 -0
- waldiez/exporting/core/utils/llm_config.py +2 -2
- waldiez/exporting/flow/execution_generator.py +1 -0
- waldiez/exporting/flow/utils/common.py +1 -1
- waldiez/exporting/flow/utils/importing.py +1 -1
- waldiez/exporting/flow/utils/logging.py +3 -75
- waldiez/io/__init__.py +3 -1
- waldiez/io/_ws.py +2 -0
- waldiez/io/structured.py +81 -28
- waldiez/io/utils.py +16 -10
- waldiez/io/ws.py +2 -2
- waldiez/models/agents/agent/agent.py +2 -1
- waldiez/models/chat/chat.py +1 -0
- waldiez/models/chat/chat_data.py +0 -2
- waldiez/models/common/base.py +2 -0
- waldiez/models/common/handoff.py +2 -0
- waldiez/models/common/method_utils.py +2 -0
- waldiez/models/model/_llm.py +3 -0
- waldiez/models/tool/predefined/_email.py +3 -0
- waldiez/models/tool/predefined/_perplexity.py +1 -1
- waldiez/models/tool/predefined/_searxng.py +1 -1
- waldiez/models/tool/predefined/_wikipedia.py +1 -1
- waldiez/running/base_runner.py +81 -20
- waldiez/running/post_run.py +6 -0
- waldiez/running/pre_run.py +167 -45
- waldiez/running/standard_runner.py +5 -5
- waldiez/running/step_by_step/breakpoints_mixin.py +368 -44
- waldiez/running/step_by_step/command_handler.py +151 -0
- waldiez/running/step_by_step/events_processor.py +199 -0
- waldiez/running/step_by_step/step_by_step_models.py +358 -41
- waldiez/running/step_by_step/step_by_step_runner.py +358 -353
- waldiez/running/subprocess_runner/__base__.py +4 -7
- waldiez/running/subprocess_runner/_async_runner.py +1 -1
- waldiez/running/subprocess_runner/_sync_runner.py +5 -4
- waldiez/running/subprocess_runner/runner.py +9 -0
- waldiez/running/utils.py +116 -2
- waldiez/ws/__init__.py +8 -7
- waldiez/ws/_file_handler.py +0 -2
- waldiez/ws/_mock.py +74 -0
- waldiez/ws/cli.py +27 -3
- waldiez/ws/client_manager.py +45 -29
- waldiez/ws/models.py +18 -1
- waldiez/ws/reloader.py +23 -2
- waldiez/ws/server.py +47 -8
- waldiez/ws/utils.py +29 -4
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/METADATA +53 -44
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/RECORD +62 -59
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/WHEEL +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/entry_points.txt +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.0.dist-info}/licenses/NOTICE.md +0 -0
waldiez/_version.py
CHANGED
waldiez/cli.py
CHANGED
waldiez/exporting/agent/exporter.py
CHANGED
|
@@ -73,17 +73,17 @@ class AgentExporter(Exporter[StandardExtras]):
|
|
|
73
73
|
Whether the flow is async, by default False
|
|
74
74
|
for_notebook : bool, optional
|
|
75
75
|
Whether exporting for notebook, by default False
|
|
76
|
-
cache_seed :
|
|
76
|
+
cache_seed : int, optional
|
|
77
77
|
Cache seed if any, by default None
|
|
78
|
-
initial_chats :
|
|
78
|
+
initial_chats : list[WaldiezAgentConnection], optional
|
|
79
79
|
Initial chats for group managers, by default None
|
|
80
|
-
group_chat_members :
|
|
80
|
+
group_chat_members : list[WaldiezAgent], optional
|
|
81
81
|
Group chat members if group manager, by default None
|
|
82
|
-
arguments_resolver :
|
|
82
|
+
arguments_resolver : Callable, optional
|
|
83
83
|
Function to resolve additional arguments, by default None
|
|
84
|
-
output_dir :
|
|
84
|
+
output_dir : str | Path, optional
|
|
85
85
|
Output directory for generated files, by default None
|
|
86
|
-
context :
|
|
86
|
+
context : ExporterContext, optional
|
|
87
87
|
Exporter context with dependencies, by default None
|
|
88
88
|
**kwargs : Any
|
|
89
89
|
Additional keyword arguments.
|
|
@@ -252,8 +252,12 @@ class GroupManagerProcessor:
|
|
|
252
252
|
cache_seed=self.cache_seed,
|
|
253
253
|
as_dict=True,
|
|
254
254
|
)
|
|
255
|
+
manager_name = self.agent_names[self.agent.id]
|
|
255
256
|
pattern_lines.append(
|
|
256
|
-
|
|
257
|
+
" group_manager_args={\n"
|
|
258
|
+
f"{llm_config_arg}"
|
|
259
|
+
f' "name": "{manager_name}",\n'
|
|
260
|
+
" },"
|
|
257
261
|
)
|
|
258
262
|
|
|
259
263
|
# Add context variables if present
|
|
@@ -445,6 +449,7 @@ class GroupManagerProcessor:
|
|
|
445
449
|
|
|
446
450
|
return self._get_transition_target(self.agent.data.after_work)
|
|
447
451
|
|
|
452
|
+
# noinspection PyTypeHints
|
|
448
453
|
def _get_transition_target(
|
|
449
454
|
self, target: WaldiezTransitionTarget
|
|
450
455
|
) -> tuple[str, str]:
|
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
3
|
"""Constants for Waldiez exporting core."""
|
|
4
4
|
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
|
|
5
7
|
from .enums import (
|
|
6
8
|
AgentPosition,
|
|
7
9
|
ExportPosition,
|
|
@@ -10,7 +12,7 @@ from .enums import (
|
|
|
10
12
|
|
|
11
13
|
FILE_HEADER = (
|
|
12
14
|
"# SPDX-License-Identifier: Apache-2.0.\n"
|
|
13
|
-
"# Copyright (c) 2024 -
|
|
15
|
+
f"# Copyright (c) 2024 - {datetime.now().year} Waldiez and contributors."
|
|
14
16
|
)
|
|
15
17
|
DEFAULT_IMPORT_POSITION = ImportPosition.THIRD_PARTY
|
|
16
18
|
DEFAULT_EXPORT_POSITION = ExportPosition.AGENTS
|
|
@@ -10,7 +10,7 @@ from typing import Any, Optional
|
|
|
10
10
|
from ..protocols import Serializer
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
# pylint: disable=too-few-public-methods
|
|
13
|
+
# pylint: disable=too-few-public-methods,no-self-use
|
|
14
14
|
class DefaultSerializer(Serializer):
|
|
15
15
|
"""Default serializer for Waldiez items."""
|
|
16
16
|
|
|
@@ -37,7 +37,7 @@ class DefaultSerializer(Serializer):
|
|
|
37
37
|
def serialize_item(
|
|
38
38
|
item: Any,
|
|
39
39
|
tabs: int = 1,
|
|
40
|
-
|
|
40
|
+
visited: Optional[set[int]] = None,
|
|
41
41
|
) -> str:
|
|
42
42
|
"""Convert an item to a formatted string with given indentation.
|
|
43
43
|
|
|
@@ -47,6 +47,8 @@ def serialize_item(
|
|
|
47
47
|
The item to convert.
|
|
48
48
|
tabs : int, optional
|
|
49
49
|
The number of tabs, by default 1.
|
|
50
|
+
visited : set[int], optional
|
|
51
|
+
A set of visited IDs, by default None
|
|
50
52
|
|
|
51
53
|
Returns
|
|
52
54
|
-------
|
|
@@ -81,8 +83,8 @@ def serialize_item(
|
|
|
81
83
|
}
|
|
82
84
|
```
|
|
83
85
|
"""
|
|
84
|
-
if
|
|
85
|
-
|
|
86
|
+
if visited is None:
|
|
87
|
+
visited = set()
|
|
86
88
|
|
|
87
89
|
if callable(item):
|
|
88
90
|
return item.__name__
|
|
@@ -92,37 +94,37 @@ def serialize_item(
|
|
|
92
94
|
return _format_primitive(item)
|
|
93
95
|
|
|
94
96
|
# Handle circular references in containers
|
|
95
|
-
if isinstance(item, (dict, list, tuple, set)) and id(item) in
|
|
97
|
+
if isinstance(item, (dict, list, tuple, set)) and id(item) in visited:
|
|
96
98
|
return '"<circular reference>"'
|
|
97
99
|
|
|
98
100
|
next_indent = " " * 4 * (tabs + 1)
|
|
99
|
-
|
|
101
|
+
visited.add(id(item))
|
|
100
102
|
|
|
101
103
|
if isinstance(item, dict):
|
|
102
104
|
items: list[str] = []
|
|
103
105
|
for key, value in item.items():
|
|
104
106
|
key_str = f'{next_indent}"{key}"'
|
|
105
|
-
value_str = serialize_item(value, tabs + 1,
|
|
107
|
+
value_str = serialize_item(value, tabs + 1, visited)
|
|
106
108
|
items.append(f"{key_str}: {value_str}")
|
|
107
109
|
return _format_container(items, "{", "}", tabs)
|
|
108
110
|
|
|
109
111
|
if isinstance(item, list):
|
|
110
112
|
items = [
|
|
111
|
-
f"{next_indent}{serialize_item(sub_item, tabs + 1,
|
|
113
|
+
f"{next_indent}{serialize_item(sub_item, tabs + 1, visited)}"
|
|
112
114
|
for sub_item in item
|
|
113
115
|
]
|
|
114
116
|
return _format_container(items, "[", "]", tabs)
|
|
115
117
|
|
|
116
118
|
if isinstance(item, tuple):
|
|
117
119
|
items = [
|
|
118
|
-
f"{next_indent}{serialize_item(sub_item, tabs + 1,
|
|
120
|
+
f"{next_indent}{serialize_item(sub_item, tabs + 1, visited)}"
|
|
119
121
|
for sub_item in item
|
|
120
122
|
]
|
|
121
123
|
return _format_container(items, "(", ")", tabs)
|
|
122
124
|
|
|
123
125
|
if isinstance(item, set):
|
|
124
126
|
items = [
|
|
125
|
-
f"{next_indent}{serialize_item(sub_item, tabs + 1,
|
|
127
|
+
f"{next_indent}{serialize_item(sub_item, tabs + 1, visited)}"
|
|
126
128
|
for sub_item in item
|
|
127
129
|
]
|
|
128
130
|
return _format_container(items, "{", "}", tabs)
|
waldiez/exporting/core/types.py
CHANGED
|
@@ -314,6 +314,7 @@ class ExportConfig:
|
|
|
314
314
|
if for_notebook:
|
|
315
315
|
output_extension = "ipynb"
|
|
316
316
|
cache_seed = kwargs.pop("cache_seed", None)
|
|
317
|
+
# noinspection PyUnreachableCode
|
|
317
318
|
if cache_seed is not None and not isinstance(cache_seed, int):
|
|
318
319
|
cache_seed = None
|
|
319
320
|
return cls(
|
|
@@ -97,7 +97,7 @@ def _get_agent_llm_config_arg_as_dict(
|
|
|
97
97
|
) -> str:
|
|
98
98
|
tab = " " * tab_leng * tabs if tabs > 0 else ""
|
|
99
99
|
if not agent.data.model_ids:
|
|
100
|
-
return f'{tab}"llm_config": False' + "\n"
|
|
100
|
+
return f'{tab}"llm_config": False,' + "\n"
|
|
101
101
|
content = f'{tab}"llm_config": autogen.LLMConfig(' + "\n"
|
|
102
102
|
content += f"{tab} config_list=["
|
|
103
103
|
got_at_least_one_model = False
|
|
@@ -110,7 +110,7 @@ def _get_agent_llm_config_arg_as_dict(
|
|
|
110
110
|
content += "\n" + f"{tab} {model_name}_llm_config,"
|
|
111
111
|
got_at_least_one_model = True
|
|
112
112
|
if not got_at_least_one_model: # pragma: no cover
|
|
113
|
-
return f'{tab}"llm_config": False' + "\n"
|
|
113
|
+
return f'{tab}"llm_config": False,' + "\n"
|
|
114
114
|
content += "\n" + f"{tab} ]," + "\n"
|
|
115
115
|
content += f"{tab} cache_seed={cache_seed}," + "\n"
|
|
116
116
|
if temperature is not None:
|
|
@@ -125,6 +125,7 @@ class ExecutionGenerator:
|
|
|
125
125
|
flow_content += " result_dicts: list[dict[str, Any]] = []\n"
|
|
126
126
|
space = " "
|
|
127
127
|
if cache_seed is not None:
|
|
128
|
+
# noinspection SqlDialectInspection
|
|
128
129
|
flow_content += (
|
|
129
130
|
f" with Cache.disk(cache_seed={cache_seed}"
|
|
130
131
|
") as cache: # pyright: ignore\n"
|
|
@@ -7,6 +7,7 @@ from typing import Optional
|
|
|
7
7
|
from waldiez.exporting.core import ImportPosition
|
|
8
8
|
|
|
9
9
|
BUILTIN_IMPORTS = [
|
|
10
|
+
"import asyncio",
|
|
10
11
|
"import csv",
|
|
11
12
|
"import importlib",
|
|
12
13
|
"import json",
|
|
@@ -208,7 +209,6 @@ def get_the_imports_string(
|
|
|
208
209
|
final_string += "\n"
|
|
209
210
|
|
|
210
211
|
if is_async:
|
|
211
|
-
builtin_imports.insert(0, "import asyncio")
|
|
212
212
|
final_string += (
|
|
213
213
|
"\nimport aiofiles"
|
|
214
214
|
"\nimport aiosqlite"
|
|
@@ -82,47 +82,13 @@ def get_start_logging(is_async: bool, for_notebook: bool) -> str:
|
|
|
82
82
|
'''
|
|
83
83
|
|
|
84
84
|
|
|
85
|
-
# pylint: disable=differing-param-doc,differing-type-doc
|
|
86
|
-
# noinspection PyUnresolvedReferences
|
|
87
85
|
def get_sync_sqlite_out() -> str:
|
|
88
|
-
|
|
86
|
+
"""Get the sqlite to csv and json conversion code string.
|
|
89
87
|
|
|
90
88
|
Returns
|
|
91
89
|
-------
|
|
92
90
|
str
|
|
93
91
|
The sqlite to csv and json conversion code string.
|
|
94
|
-
|
|
95
|
-
Example
|
|
96
|
-
-------
|
|
97
|
-
```python
|
|
98
|
-
>>> get_sqlite_outputs()
|
|
99
|
-
def get_sqlite_out(dbname: str, table: str, csv_file: str) -> None:
|
|
100
|
-
\"\"\"Convert a sqlite table to csv and json files.
|
|
101
|
-
|
|
102
|
-
Parameters
|
|
103
|
-
----------
|
|
104
|
-
dbname : str
|
|
105
|
-
The sqlite database name.
|
|
106
|
-
table : str
|
|
107
|
-
The table name.
|
|
108
|
-
csv_file : str
|
|
109
|
-
The csv file name.
|
|
110
|
-
\"\"\"
|
|
111
|
-
conn = sqlite3.connect(dbname)
|
|
112
|
-
query = f"SELECT * FROM {table}" # nosec
|
|
113
|
-
cursor = conn.execute(query)
|
|
114
|
-
rows = cursor.fetchall()
|
|
115
|
-
column_names = [description[0] for description in cursor.description]
|
|
116
|
-
data = [dict(zip(column_names, row, strict=True)) for row in rows]
|
|
117
|
-
conn.close()
|
|
118
|
-
with open(csv_file, "w", newline="", encoding="utf-8") as file:
|
|
119
|
-
csv_writer = csv.DictWriter(file, fieldnames=column_names)
|
|
120
|
-
csv_writer.writeheader()
|
|
121
|
-
csv_writer.writerows(data)
|
|
122
|
-
json_file = csv_file.replace(".csv", ".json")
|
|
123
|
-
with open(json_file, "w", encoding="utf-8") as file:
|
|
124
|
-
json.dump(data, file, indent=4, ensure_ascii=False)
|
|
125
|
-
```
|
|
126
92
|
"""
|
|
127
93
|
content = "\n\n"
|
|
128
94
|
content += (
|
|
@@ -166,52 +132,14 @@ def get_sync_sqlite_out() -> str:
|
|
|
166
132
|
return content
|
|
167
133
|
|
|
168
134
|
|
|
169
|
-
# pylint: disable=
|
|
170
|
-
# noinspection PyUnresolvedReferences
|
|
135
|
+
# pylint: disable=line-too-long
|
|
171
136
|
def get_async_sqlite_out() -> str:
|
|
172
|
-
|
|
137
|
+
"""Get the sqlite to csv and json conversion code string.
|
|
173
138
|
|
|
174
139
|
Returns
|
|
175
140
|
-------
|
|
176
141
|
str
|
|
177
142
|
The sqlite to csv and json conversion code string.
|
|
178
|
-
|
|
179
|
-
Example
|
|
180
|
-
-------
|
|
181
|
-
```python
|
|
182
|
-
>>> get_sqlite_outputs()
|
|
183
|
-
async def get_sqlite_out(dbname: str, table: str, csv_file: str) -> None:
|
|
184
|
-
\"\"\"Convert a sqlite table to csv and json files.
|
|
185
|
-
|
|
186
|
-
Parameters
|
|
187
|
-
----------
|
|
188
|
-
dbname : str
|
|
189
|
-
The sqlite database name.
|
|
190
|
-
table : str
|
|
191
|
-
The table name.
|
|
192
|
-
csv_file : str
|
|
193
|
-
The csv file name.
|
|
194
|
-
\"\"\"
|
|
195
|
-
conn = await aiosqlite.connect(dbname)
|
|
196
|
-
query = f"SELECT * FROM {table}" # nosec
|
|
197
|
-
try:
|
|
198
|
-
cursor = await conn.execute(query)
|
|
199
|
-
except BaseException: # pylint: disable=broad-exception-caught
|
|
200
|
-
await conn.close()
|
|
201
|
-
return
|
|
202
|
-
rows = await cursor.fetchall()
|
|
203
|
-
column_names = [description[0] for description in cursor.description]
|
|
204
|
-
data = [dict(zip(column_names, row, strict=True)) for row in rows]
|
|
205
|
-
await cursor.close()
|
|
206
|
-
await conn.close()
|
|
207
|
-
async with aiofiles.open(csv_file, "w", newline="", encoding="utf-8") as file:
|
|
208
|
-
csv_writer = csv.DictWriter(file, fieldnames=column_names)
|
|
209
|
-
csv_writer.writeheader()
|
|
210
|
-
csv_writer.writerows(data)
|
|
211
|
-
json_file = csv_file.replace(".csv", ".json")
|
|
212
|
-
async with aiofiles.open(json_file, "w", encoding="utf-8") as file:
|
|
213
|
-
await file.write(json.dumps(data, indent=4, ensure_ascii=False)
|
|
214
|
-
```
|
|
215
143
|
"""
|
|
216
144
|
# fmt: off
|
|
217
145
|
content = "\n\n"
|
waldiez/io/__init__.py
CHANGED
|
@@ -30,7 +30,7 @@ from .models import (
|
|
|
30
30
|
VideoMediaContent,
|
|
31
31
|
)
|
|
32
32
|
from .structured import StructuredIOStream
|
|
33
|
-
from .utils import MediaType, MessageType
|
|
33
|
+
from .utils import DEBUG_INPUT_PROMPT, START_CHAT_PROMPT, MediaType, MessageType
|
|
34
34
|
|
|
35
35
|
try:
|
|
36
36
|
from .redis import RedisIOStream # type: ignore[no-redef,unused-ignore]
|
|
@@ -125,4 +125,6 @@ __all__ = [
|
|
|
125
125
|
"AudioContent",
|
|
126
126
|
"VideoMediaContent",
|
|
127
127
|
"VideoContent",
|
|
128
|
+
"DEBUG_INPUT_PROMPT",
|
|
129
|
+
"START_CHAT_PROMPT",
|
|
128
130
|
]
|
waldiez/io/_ws.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
3
|
# flake8: noqa: E501
|
|
4
4
|
# pylint: disable=line-too-long
|
|
5
|
+
# pyright: reportUnknownMemberType=false,reportUnknownParameterType=false
|
|
6
|
+
# pyright: reportUnknownVariableType=false,reportUnknownArgumentType=false
|
|
5
7
|
"""WebSocket IOStream implementation for AsyncIO."""
|
|
6
8
|
|
|
7
9
|
import asyncio
|
waldiez/io/structured.py
CHANGED
|
@@ -10,7 +10,6 @@ import threading
|
|
|
10
10
|
from getpass import getpass
|
|
11
11
|
from pathlib import Path
|
|
12
12
|
from typing import Any
|
|
13
|
-
from uuid import uuid4
|
|
14
13
|
|
|
15
14
|
from autogen.events import BaseEvent # type: ignore
|
|
16
15
|
from autogen.io import IOStream # type: ignore
|
|
@@ -23,6 +22,9 @@ from .models import (
|
|
|
23
22
|
UserResponse,
|
|
24
23
|
)
|
|
25
24
|
from .utils import (
|
|
25
|
+
DEBUG_INPUT_PROMPT,
|
|
26
|
+
START_CHAT_PROMPT,
|
|
27
|
+
MessageType,
|
|
26
28
|
gen_id,
|
|
27
29
|
get_image,
|
|
28
30
|
get_message_dump,
|
|
@@ -63,29 +65,43 @@ class StructuredIOStream(IOStream):
|
|
|
63
65
|
"""
|
|
64
66
|
sep = kwargs.get("sep", " ")
|
|
65
67
|
end = kwargs.get("end", "\n")
|
|
68
|
+
flush = kwargs.get("flush", True)
|
|
69
|
+
payload_type = kwargs.get("type", "print")
|
|
66
70
|
message = sep.join(map(str, args))
|
|
67
|
-
|
|
68
|
-
|
|
71
|
+
if len(args) == 1 and isinstance(args[0], dict):
|
|
72
|
+
message = args[0] # pyright: ignore
|
|
73
|
+
payload_type = message.get("type", payload_type) # pyright: ignore
|
|
74
|
+
is_dumped = True
|
|
75
|
+
else:
|
|
76
|
+
is_dumped, message = is_json_dumped(message)
|
|
77
|
+
if is_dumped:
|
|
69
78
|
# If the message is already JSON-dumped,
|
|
70
79
|
# let's try not to double dump it
|
|
71
|
-
message = json.loads(message)
|
|
72
80
|
payload: dict[str, Any] = {
|
|
73
|
-
"
|
|
74
|
-
"id": uuid4().hex,
|
|
81
|
+
"id": gen_id(),
|
|
75
82
|
"timestamp": now(),
|
|
76
|
-
"data": message,
|
|
83
|
+
# "data": message,
|
|
77
84
|
}
|
|
85
|
+
if isinstance(message, dict):
|
|
86
|
+
payload.update(message) # pyright: ignore
|
|
87
|
+
else:
|
|
88
|
+
payload["data"] = message
|
|
89
|
+
if "type" not in payload:
|
|
90
|
+
payload["type"] = payload_type
|
|
78
91
|
else:
|
|
79
|
-
message
|
|
80
|
-
payload =
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
92
|
+
print_message = PrintMessage(data=message) # pyright: ignore
|
|
93
|
+
payload = print_message.model_dump(mode="json", fallback=str)
|
|
94
|
+
payload["type"] = payload_type
|
|
95
|
+
dumped = json.dumps(payload, default=str, ensure_ascii=False) + end
|
|
96
|
+
if kwargs.get("file") and kwargs["file"] in [
|
|
97
|
+
sys.stderr,
|
|
98
|
+
sys.__stderr__,
|
|
99
|
+
sys.stdout,
|
|
100
|
+
sys.__stdout__,
|
|
101
|
+
]:
|
|
102
|
+
print(dumped, file=kwargs["file"], flush=flush)
|
|
103
|
+
else:
|
|
104
|
+
print(dumped, flush=flush)
|
|
89
105
|
|
|
90
106
|
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
|
91
107
|
"""Structured input from stdin.
|
|
@@ -102,14 +118,21 @@ class StructuredIOStream(IOStream):
|
|
|
102
118
|
str
|
|
103
119
|
The line read from the input stream.
|
|
104
120
|
"""
|
|
105
|
-
request_id =
|
|
121
|
+
request_id = gen_id()
|
|
106
122
|
prompt = prompt or ">"
|
|
107
123
|
if not prompt or prompt in [">", "> "]: # pragma: no cover
|
|
108
124
|
# if the prompt is just ">" or "> ",
|
|
109
125
|
# let's use a more descriptive one
|
|
110
|
-
prompt =
|
|
111
|
-
|
|
112
|
-
|
|
126
|
+
prompt = START_CHAT_PROMPT
|
|
127
|
+
input_type = "chat"
|
|
128
|
+
if prompt.strip() == DEBUG_INPUT_PROMPT.strip():
|
|
129
|
+
input_type = "debug"
|
|
130
|
+
self._send_input_request(
|
|
131
|
+
prompt,
|
|
132
|
+
request_id,
|
|
133
|
+
password,
|
|
134
|
+
input_type=input_type,
|
|
135
|
+
)
|
|
113
136
|
user_input_raw = self._read_user_input(prompt, password, request_id)
|
|
114
137
|
response = self._handle_user_input(user_input_raw, request_id)
|
|
115
138
|
user_response = response.to_string(
|
|
@@ -152,14 +175,21 @@ class StructuredIOStream(IOStream):
|
|
|
152
175
|
self,
|
|
153
176
|
prompt: str,
|
|
154
177
|
request_id: str,
|
|
155
|
-
password: bool,
|
|
178
|
+
password: bool = False,
|
|
179
|
+
input_type: str = "chat",
|
|
156
180
|
) -> None:
|
|
181
|
+
if input_type not in ("chat", "debug"):
|
|
182
|
+
input_type = "chat"
|
|
183
|
+
request_type = (
|
|
184
|
+
"debug_input_request" if input_type == "debug" else "input_request"
|
|
185
|
+
)
|
|
157
186
|
payload = UserInputRequest(
|
|
187
|
+
type=request_type, # type: ignore
|
|
158
188
|
request_id=request_id,
|
|
159
189
|
prompt=prompt,
|
|
160
190
|
password=password,
|
|
161
191
|
).model_dump(mode="json")
|
|
162
|
-
print(json.dumps(payload), flush=True)
|
|
192
|
+
print(json.dumps(payload, default=str), flush=True)
|
|
163
193
|
|
|
164
194
|
def _read_user_input(
|
|
165
195
|
self,
|
|
@@ -294,7 +324,12 @@ class StructuredIOStream(IOStream):
|
|
|
294
324
|
UserResponse
|
|
295
325
|
The structured user response.
|
|
296
326
|
"""
|
|
297
|
-
|
|
327
|
+
response_type: MessageType
|
|
328
|
+
_response_type = user_input.get("type", "input_response")
|
|
329
|
+
if _response_type not in ("input_response", "debug_input_response"):
|
|
330
|
+
response_type = "input_response"
|
|
331
|
+
else:
|
|
332
|
+
response_type = _response_type
|
|
298
333
|
if user_input.get("request_id") == request_id:
|
|
299
334
|
# We have a valid response to our request
|
|
300
335
|
data = user_input.get("data")
|
|
@@ -302,6 +337,7 @@ class StructuredIOStream(IOStream):
|
|
|
302
337
|
# let's check if text|image keys are sent (outside data)
|
|
303
338
|
if "image" in user_input or "text" in user_input:
|
|
304
339
|
return UserResponse(
|
|
340
|
+
type=response_type,
|
|
305
341
|
request_id=request_id,
|
|
306
342
|
data=self._format_multimedia_response(
|
|
307
343
|
request_id=request_id, data=user_input
|
|
@@ -311,10 +347,12 @@ class StructuredIOStream(IOStream):
|
|
|
311
347
|
return self._handle_list_response(
|
|
312
348
|
data, # pyright: ignore
|
|
313
349
|
request_id=request_id,
|
|
350
|
+
response_type=response_type,
|
|
314
351
|
)
|
|
315
352
|
if not data or not isinstance(data, (str, dict)):
|
|
316
353
|
# No / invalid data provided in the response
|
|
317
354
|
return UserResponse(
|
|
355
|
+
type=response_type,
|
|
318
356
|
request_id=request_id,
|
|
319
357
|
data="",
|
|
320
358
|
)
|
|
@@ -324,6 +362,7 @@ class StructuredIOStream(IOStream):
|
|
|
324
362
|
data = self._load_user_input(data)
|
|
325
363
|
if isinstance(data, dict):
|
|
326
364
|
return UserResponse(
|
|
365
|
+
type=response_type,
|
|
327
366
|
data=self._format_multimedia_response(
|
|
328
367
|
request_id=request_id,
|
|
329
368
|
data=data, # pyright: ignore
|
|
@@ -333,11 +372,14 @@ class StructuredIOStream(IOStream):
|
|
|
333
372
|
# For other types (numbers, bools ,...),
|
|
334
373
|
# let's just convert to string
|
|
335
374
|
return UserResponse(
|
|
336
|
-
|
|
375
|
+
type=response_type,
|
|
376
|
+
data=str(data),
|
|
377
|
+
request_id=request_id,
|
|
337
378
|
) # pragma: no cover
|
|
338
379
|
# This response doesn't match our request_id, log and return empty
|
|
339
380
|
self._log_mismatched_response(request_id, user_input)
|
|
340
381
|
return UserResponse(
|
|
382
|
+
type=response_type,
|
|
341
383
|
request_id=request_id,
|
|
342
384
|
data="",
|
|
343
385
|
)
|
|
@@ -347,10 +389,12 @@ class StructuredIOStream(IOStream):
|
|
|
347
389
|
self,
|
|
348
390
|
data: list[dict[str, Any]],
|
|
349
391
|
request_id: str,
|
|
392
|
+
response_type: MessageType,
|
|
350
393
|
) -> UserResponse:
|
|
351
394
|
if len(data) == 0: # pyright: ignore
|
|
352
395
|
# Empty list, return empty response
|
|
353
396
|
return UserResponse(
|
|
397
|
+
type=response_type,
|
|
354
398
|
request_id=request_id,
|
|
355
399
|
data="",
|
|
356
400
|
)
|
|
@@ -367,10 +411,12 @@ class StructuredIOStream(IOStream):
|
|
|
367
411
|
if not input_data: # pragma: no cover
|
|
368
412
|
# No valid data in the list, return empty response
|
|
369
413
|
return UserResponse(
|
|
414
|
+
type=response_type,
|
|
370
415
|
request_id=request_id,
|
|
371
416
|
data="",
|
|
372
417
|
)
|
|
373
418
|
return UserResponse(
|
|
419
|
+
type=response_type,
|
|
374
420
|
request_id=request_id,
|
|
375
421
|
data=input_data,
|
|
376
422
|
)
|
|
@@ -387,9 +433,16 @@ class StructuredIOStream(IOStream):
|
|
|
387
433
|
The response received
|
|
388
434
|
"""
|
|
389
435
|
# Create a log message
|
|
436
|
+
got_id: str | None = None
|
|
437
|
+
if isinstance(response, dict):
|
|
438
|
+
got_id = response.get("request_id") # pyright: ignore
|
|
439
|
+
response_str = str(response) # pyright: ignore
|
|
440
|
+
message = response_str[:100] + (
|
|
441
|
+
"..." if len(response_str) > 100 else ""
|
|
442
|
+
)
|
|
390
443
|
log_payload: dict[str, Any] = {
|
|
391
444
|
"type": "warning",
|
|
392
|
-
"id":
|
|
445
|
+
"id": gen_id(),
|
|
393
446
|
"timestamp": now(),
|
|
394
447
|
"data": {
|
|
395
448
|
"message": (
|
|
@@ -398,8 +451,8 @@ class StructuredIOStream(IOStream):
|
|
|
398
451
|
),
|
|
399
452
|
"details": {
|
|
400
453
|
"expected_id": expected_id,
|
|
401
|
-
"
|
|
402
|
-
|
|
454
|
+
"received_id": got_id,
|
|
455
|
+
"message": message,
|
|
403
456
|
},
|
|
404
457
|
},
|
|
405
458
|
}
|