waldiez 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez might be problematic. Click here for more details.
- waldiez/__init__.py +1 -1
- waldiez/_version.py +1 -1
- waldiez/cli.py +19 -7
- waldiez/cli_extras/jupyter.py +3 -0
- waldiez/cli_extras/runner.py +3 -1
- waldiez/cli_extras/studio.py +3 -1
- waldiez/exporter.py +9 -3
- waldiez/exporting/agent/exporter.py +15 -16
- waldiez/exporting/agent/extras/captain_agent_extras.py +6 -6
- waldiez/exporting/agent/extras/doc_agent_extras.py +6 -6
- waldiez/exporting/agent/extras/group_manager_agent_extas.py +40 -24
- waldiez/exporting/agent/extras/group_member_extras.py +6 -5
- waldiez/exporting/agent/extras/handoffs/after_work.py +2 -1
- waldiez/exporting/agent/extras/handoffs/available.py +2 -1
- waldiez/exporting/agent/extras/handoffs/condition.py +3 -2
- waldiez/exporting/agent/extras/handoffs/handoff.py +2 -1
- waldiez/exporting/agent/extras/handoffs/target.py +7 -4
- waldiez/exporting/agent/extras/rag/chroma_extras.py +27 -19
- waldiez/exporting/agent/extras/rag/mongo_extras.py +8 -8
- waldiez/exporting/agent/extras/rag/pgvector_extras.py +5 -5
- waldiez/exporting/agent/extras/rag/qdrant_extras.py +5 -4
- waldiez/exporting/agent/extras/rag/vector_db_extras.py +1 -1
- waldiez/exporting/agent/extras/rag_user_proxy_agent_extras.py +5 -7
- waldiez/exporting/agent/extras/reasoning_agent_extras.py +3 -5
- waldiez/exporting/agent/termination.py +1 -0
- waldiez/exporting/chats/exporter.py +4 -4
- waldiez/exporting/chats/processor.py +1 -2
- waldiez/exporting/chats/utils/common.py +89 -48
- waldiez/exporting/chats/utils/group.py +9 -9
- waldiez/exporting/chats/utils/nested.py +7 -7
- waldiez/exporting/chats/utils/sequential.py +1 -1
- waldiez/exporting/chats/utils/single.py +2 -2
- waldiez/exporting/core/constants.py +3 -1
- waldiez/exporting/core/content.py +7 -7
- waldiez/exporting/core/context.py +5 -3
- waldiez/exporting/core/exporter.py +5 -3
- waldiez/exporting/core/exporters.py +2 -2
- waldiez/exporting/core/extras/agent_extras/captain_extras.py +2 -2
- waldiez/exporting/core/extras/agent_extras/group_manager_extras.py +2 -2
- waldiez/exporting/core/extras/agent_extras/rag_user_extras.py +2 -2
- waldiez/exporting/core/extras/agent_extras/standard_extras.py +3 -8
- waldiez/exporting/core/extras/base.py +7 -5
- waldiez/exporting/core/extras/flow_extras.py +4 -5
- waldiez/exporting/core/extras/model_extras.py +2 -2
- waldiez/exporting/core/extras/path_resolver.py +1 -2
- waldiez/exporting/core/extras/serializer.py +13 -11
- waldiez/exporting/core/protocols.py +6 -5
- waldiez/exporting/core/result.py +25 -28
- waldiez/exporting/core/types.py +11 -10
- waldiez/exporting/core/utils/llm_config.py +4 -4
- waldiez/exporting/core/validation.py +10 -11
- waldiez/exporting/flow/execution_generator.py +99 -10
- waldiez/exporting/flow/exporter.py +2 -2
- waldiez/exporting/flow/factory.py +2 -2
- waldiez/exporting/flow/file_generator.py +4 -2
- waldiez/exporting/flow/merger.py +5 -3
- waldiez/exporting/flow/orchestrator.py +72 -2
- waldiez/exporting/flow/utils/common.py +6 -6
- waldiez/exporting/flow/utils/importing.py +7 -8
- waldiez/exporting/flow/utils/linting.py +25 -9
- waldiez/exporting/flow/utils/logging.py +5 -77
- waldiez/exporting/models/exporter.py +8 -8
- waldiez/exporting/models/processor.py +5 -5
- waldiez/exporting/tools/exporter.py +2 -2
- waldiez/exporting/tools/processor.py +7 -4
- waldiez/io/__init__.py +11 -5
- waldiez/io/_ws.py +12 -6
- waldiez/io/models/constants.py +10 -10
- waldiez/io/models/content/audio.py +1 -0
- waldiez/io/models/content/base.py +20 -18
- waldiez/io/models/content/file.py +1 -0
- waldiez/io/models/content/image.py +1 -0
- waldiez/io/models/content/text.py +1 -0
- waldiez/io/models/content/video.py +1 -0
- waldiez/io/models/user_input.py +10 -5
- waldiez/io/models/user_response.py +17 -16
- waldiez/io/mqtt.py +18 -31
- waldiez/io/redis.py +18 -22
- waldiez/io/structured.py +122 -70
- waldiez/io/utils.py +19 -10
- waldiez/io/ws.py +7 -3
- waldiez/logger.py +16 -3
- waldiez/models/agents/__init__.py +3 -0
- waldiez/models/agents/agent/agent.py +25 -17
- waldiez/models/agents/agent/agent_data.py +25 -22
- waldiez/models/agents/agent/code_execution.py +9 -11
- waldiez/models/agents/agent/termination_message.py +10 -12
- waldiez/models/agents/agent/update_system_message.py +2 -4
- waldiez/models/agents/agents.py +8 -8
- waldiez/models/agents/assistant/assistant.py +6 -3
- waldiez/models/agents/assistant/assistant_data.py +2 -2
- waldiez/models/agents/captain/captain_agent.py +7 -4
- waldiez/models/agents/captain/captain_agent_data.py +5 -7
- waldiez/models/agents/doc_agent/doc_agent.py +7 -4
- waldiez/models/agents/doc_agent/doc_agent_data.py +9 -10
- waldiez/models/agents/doc_agent/rag_query_engine.py +10 -12
- waldiez/models/agents/extra_requirements.py +3 -3
- waldiez/models/agents/group_manager/group_manager.py +12 -7
- waldiez/models/agents/group_manager/group_manager_data.py +13 -12
- waldiez/models/agents/group_manager/speakers.py +17 -19
- waldiez/models/agents/rag_user_proxy/rag_user_proxy.py +7 -4
- waldiez/models/agents/rag_user_proxy/rag_user_proxy_data.py +4 -1
- waldiez/models/agents/rag_user_proxy/retrieve_config.py +69 -63
- waldiez/models/agents/rag_user_proxy/vector_db_config.py +19 -19
- waldiez/models/agents/reasoning/reasoning_agent.py +7 -4
- waldiez/models/agents/reasoning/reasoning_agent_data.py +3 -2
- waldiez/models/agents/reasoning/reasoning_agent_reason_config.py +8 -8
- waldiez/models/agents/user_proxy/user_proxy.py +6 -3
- waldiez/models/agents/user_proxy/user_proxy_data.py +1 -1
- waldiez/models/chat/chat.py +28 -20
- waldiez/models/chat/chat_data.py +22 -21
- waldiez/models/chat/chat_message.py +9 -9
- waldiez/models/chat/chat_nested.py +9 -9
- waldiez/models/chat/chat_summary.py +6 -6
- waldiez/models/common/__init__.py +2 -0
- waldiez/models/common/ag2_version.py +2 -0
- waldiez/models/common/base.py +2 -0
- waldiez/models/common/dict_utils.py +8 -6
- waldiez/models/common/handoff.py +20 -17
- waldiez/models/common/method_utils.py +9 -7
- waldiez/models/common/naming.py +49 -0
- waldiez/models/flow/flow.py +11 -6
- waldiez/models/flow/flow_data.py +23 -17
- waldiez/models/flow/info.py +3 -3
- waldiez/models/flow/naming.py +2 -1
- waldiez/models/model/_aws.py +11 -13
- waldiez/models/model/_llm.py +8 -0
- waldiez/models/model/_price.py +2 -4
- waldiez/models/model/extra_requirements.py +1 -3
- waldiez/models/model/model.py +2 -2
- waldiez/models/model/model_data.py +21 -21
- waldiez/models/tool/extra_requirements.py +2 -4
- waldiez/models/tool/predefined/_duckduckgo.py +1 -0
- waldiez/models/tool/predefined/_email.py +4 -0
- waldiez/models/tool/predefined/_google.py +1 -0
- waldiez/models/tool/predefined/_perplexity.py +2 -1
- waldiez/models/tool/predefined/_searxng.py +2 -1
- waldiez/models/tool/predefined/_tavily.py +1 -0
- waldiez/models/tool/predefined/_wikipedia.py +2 -1
- waldiez/models/tool/predefined/_youtube.py +1 -0
- waldiez/models/tool/tool.py +8 -5
- waldiez/models/tool/tool_data.py +2 -2
- waldiez/models/waldiez.py +152 -4
- waldiez/runner.py +11 -5
- waldiez/running/async_utils.py +192 -0
- waldiez/running/base_runner.py +155 -241
- waldiez/running/dir_utils.py +52 -0
- waldiez/running/environment.py +10 -44
- waldiez/running/events_mixin.py +252 -0
- waldiez/running/exceptions.py +20 -0
- waldiez/running/gen_seq_diagram.py +18 -15
- waldiez/running/io_utils.py +216 -0
- waldiez/running/protocol.py +11 -5
- waldiez/running/requirements_mixin.py +65 -0
- waldiez/running/results_mixin.py +926 -0
- waldiez/running/standard_runner.py +24 -27
- waldiez/running/step_by_step/breakpoints_mixin.py +503 -47
- waldiez/running/step_by_step/command_handler.py +154 -0
- waldiez/running/step_by_step/events_processor.py +379 -0
- waldiez/running/step_by_step/step_by_step_models.py +425 -41
- waldiez/running/step_by_step/step_by_step_runner.py +437 -382
- waldiez/running/subprocess_runner/__base__.py +13 -8
- waldiez/running/subprocess_runner/_async_runner.py +6 -4
- waldiez/running/subprocess_runner/_sync_runner.py +11 -6
- waldiez/running/subprocess_runner/runner.py +48 -23
- waldiez/running/timeline_processor.py +1 -1
- waldiez/utils/__init__.py +2 -0
- waldiez/utils/conflict_checker.py +4 -4
- waldiez/utils/python_manager.py +415 -0
- waldiez/ws/__init__.py +8 -7
- waldiez/ws/_file_handler.py +18 -20
- waldiez/ws/_mock.py +75 -0
- waldiez/ws/cli.py +58 -10
- waldiez/ws/client_manager.py +77 -53
- waldiez/ws/errors.py +3 -0
- waldiez/ws/models.py +61 -53
- waldiez/ws/reloader.py +33 -4
- waldiez/ws/server.py +121 -52
- waldiez/ws/session_manager.py +8 -9
- waldiez/ws/session_stats.py +1 -1
- waldiez/ws/utils.py +33 -5
- {waldiez-0.5.10.dist-info → waldiez-0.6.1.dist-info}/METADATA +107 -109
- waldiez-0.6.1.dist-info/RECORD +254 -0
- waldiez/running/post_run.py +0 -180
- waldiez/running/pre_run.py +0 -159
- waldiez/running/run_results.py +0 -14
- waldiez/running/utils.py +0 -511
- waldiez-0.5.10.dist-info/RECORD +0 -248
- {waldiez-0.5.10.dist-info → waldiez-0.6.1.dist-info}/WHEEL +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.1.dist-info}/entry_points.txt +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {waldiez-0.5.10.dist-info → waldiez-0.6.1.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0.
|
|
2
2
|
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
|
|
4
|
+
# pyright: reportUnnecessaryIsInstance=false,reportUnknownVariableType=false
|
|
5
|
+
# pyright: reportUnknownArgumentType=false,reportArgumentType=false
|
|
6
|
+
|
|
3
7
|
"""User response model and validation."""
|
|
4
8
|
|
|
5
9
|
import json
|
|
6
10
|
from pathlib import Path
|
|
7
|
-
from typing import Any, Callable
|
|
11
|
+
from typing import Any, Callable
|
|
8
12
|
|
|
9
13
|
from pydantic import ValidationError, field_validator
|
|
10
14
|
|
|
@@ -26,7 +30,7 @@ class UserResponse(StructuredBase):
|
|
|
26
30
|
@classmethod
|
|
27
31
|
def validate_data(
|
|
28
32
|
cls, value: Any
|
|
29
|
-
) ->
|
|
33
|
+
) -> str | UserInputData | list[UserInputData] | None:
|
|
30
34
|
"""Validate the data field in UserResponse.
|
|
31
35
|
|
|
32
36
|
Parameters
|
|
@@ -51,16 +55,16 @@ class UserResponse(StructuredBase):
|
|
|
51
55
|
|
|
52
56
|
handlers: dict[
|
|
53
57
|
type,
|
|
54
|
-
Callable[[Any],
|
|
58
|
+
Callable[[Any], str | UserInputData | list[UserInputData]],
|
|
55
59
|
] = {
|
|
56
60
|
str: cls._handle_string,
|
|
57
61
|
dict: cls._handle_dict,
|
|
58
62
|
list: cls._handle_list,
|
|
59
63
|
}
|
|
60
64
|
|
|
61
|
-
value_type = type(value)
|
|
65
|
+
value_type = type(value)
|
|
62
66
|
handler = handlers.get(
|
|
63
|
-
value_type,
|
|
67
|
+
value_type,
|
|
64
68
|
cls._handle_default,
|
|
65
69
|
)
|
|
66
70
|
result = handler(value)
|
|
@@ -71,16 +75,13 @@ class UserResponse(StructuredBase):
|
|
|
71
75
|
"""Check if value is already a valid type."""
|
|
72
76
|
return isinstance(value, UserInputData) or (
|
|
73
77
|
isinstance(value, list)
|
|
74
|
-
and all(
|
|
75
|
-
isinstance(item, UserInputData)
|
|
76
|
-
for item in value # pyright: ignore
|
|
77
|
-
)
|
|
78
|
+
and all(isinstance(item, UserInputData) for item in value)
|
|
78
79
|
)
|
|
79
80
|
|
|
80
81
|
@classmethod
|
|
81
82
|
def _handle_string(
|
|
82
83
|
cls, value: str
|
|
83
|
-
) ->
|
|
84
|
+
) -> str | UserInputData | list[UserInputData]:
|
|
84
85
|
"""Handle string input.
|
|
85
86
|
|
|
86
87
|
Parameters
|
|
@@ -97,9 +98,9 @@ class UserResponse(StructuredBase):
|
|
|
97
98
|
try:
|
|
98
99
|
parsed_value = json.loads(value)
|
|
99
100
|
if isinstance(parsed_value, dict):
|
|
100
|
-
return cls._handle_dict(parsed_value)
|
|
101
|
+
return cls._handle_dict(parsed_value)
|
|
101
102
|
if isinstance(parsed_value, list):
|
|
102
|
-
return cls._handle_list(parsed_value)
|
|
103
|
+
return cls._handle_list(parsed_value)
|
|
103
104
|
return cls._create_text_input(str(parsed_value))
|
|
104
105
|
except json.JSONDecodeError:
|
|
105
106
|
return cls._create_text_input(value)
|
|
@@ -127,18 +128,18 @@ class UserResponse(StructuredBase):
|
|
|
127
128
|
@classmethod
|
|
128
129
|
def _handle_list(
|
|
129
130
|
cls, value: list[Any]
|
|
130
|
-
) ->
|
|
131
|
+
) -> UserInputData | list[UserInputData]:
|
|
131
132
|
result: list[UserInputData] = []
|
|
132
133
|
|
|
133
134
|
for item in value:
|
|
134
135
|
if isinstance(item, UserInputData):
|
|
135
136
|
result.append(item)
|
|
136
137
|
elif isinstance(item, dict):
|
|
137
|
-
result.append(cls._handle_dict(item))
|
|
138
|
+
result.append(cls._handle_dict(item))
|
|
138
139
|
elif isinstance(item, str):
|
|
139
140
|
result.append(cls._create_text_input(item))
|
|
140
141
|
elif isinstance(item, list):
|
|
141
|
-
nested_result = cls._handle_list(item)
|
|
142
|
+
nested_result = cls._handle_list(item)
|
|
142
143
|
if isinstance(nested_result, list):
|
|
143
144
|
result.extend(nested_result)
|
|
144
145
|
else:
|
|
@@ -207,7 +208,7 @@ class UserResponse(StructuredBase):
|
|
|
207
208
|
uploads_root=uploads_root, base_name=base_name
|
|
208
209
|
).strip()
|
|
209
210
|
# we have probably returned sth till here
|
|
210
|
-
if isinstance(self.data, str): #
|
|
211
|
+
if isinstance(self.data, str): # pragma: no cover
|
|
211
212
|
return self.data
|
|
212
213
|
# noinspection PyUnreachableCode
|
|
213
214
|
return ( # pragma: no cover
|
waldiez/io/mqtt.py
CHANGED
|
@@ -7,6 +7,9 @@
|
|
|
7
7
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
8
8
|
# pylint: disable=too-many-locals,too-many-instance-attributes
|
|
9
9
|
|
|
10
|
+
# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
|
|
11
|
+
# pyright: reportUnusedParameter=false
|
|
12
|
+
|
|
10
13
|
"""An MQTT I/O stream for handling print and input messages."""
|
|
11
14
|
|
|
12
15
|
import json
|
|
@@ -17,12 +20,7 @@ import uuid
|
|
|
17
20
|
from pathlib import Path
|
|
18
21
|
from threading import Event, Lock
|
|
19
22
|
from types import TracebackType
|
|
20
|
-
from typing import
|
|
21
|
-
Any,
|
|
22
|
-
Callable,
|
|
23
|
-
Optional,
|
|
24
|
-
Type,
|
|
25
|
-
)
|
|
23
|
+
from typing import Any, Callable
|
|
26
24
|
|
|
27
25
|
try:
|
|
28
26
|
from paho.mqtt import client as mqtt
|
|
@@ -44,7 +42,7 @@ from .models import (
|
|
|
44
42
|
UserInputRequest,
|
|
45
43
|
UserResponse,
|
|
46
44
|
)
|
|
47
|
-
from .utils import gen_id, now
|
|
45
|
+
from .utils import gen_id, get_message_dump, now
|
|
48
46
|
|
|
49
47
|
LOG = logging.getLogger(__name__)
|
|
50
48
|
|
|
@@ -62,8 +60,8 @@ class MqttIOStream(IOStream):
|
|
|
62
60
|
client: mqtt.Client
|
|
63
61
|
task_id: str
|
|
64
62
|
input_timeout: int
|
|
65
|
-
on_input_request:
|
|
66
|
-
|
|
63
|
+
on_input_request: Callable[[str, str, str], None] | None
|
|
64
|
+
on_input_response: Callable[[str, str], None] | None
|
|
67
65
|
max_retain_messages: int
|
|
68
66
|
output_topic: str
|
|
69
67
|
input_request_topic: str
|
|
@@ -87,8 +85,8 @@ class MqttIOStream(IOStream):
|
|
|
87
85
|
input_timeout: int = 120,
|
|
88
86
|
connect_timeout: int = 10,
|
|
89
87
|
max_retain_messages: int = 1000,
|
|
90
|
-
on_input_request:
|
|
91
|
-
on_input_response:
|
|
88
|
+
on_input_request: Callable[[str, str, str], None] | None = None,
|
|
89
|
+
on_input_response: Callable[[str, str], None] | None = None,
|
|
92
90
|
mqtt_client_kwargs: dict[str, Any] | None = None,
|
|
93
91
|
uploads_root: Path | str | None = None,
|
|
94
92
|
username: str | None = None,
|
|
@@ -173,9 +171,9 @@ class MqttIOStream(IOStream):
|
|
|
173
171
|
# Set up TLS
|
|
174
172
|
if use_tls:
|
|
175
173
|
if ca_cert_path:
|
|
176
|
-
self.client.tls_set(ca_cert_path)
|
|
174
|
+
self.client.tls_set(ca_cert_path)
|
|
177
175
|
else: # pragma: no cover
|
|
178
|
-
self.client.tls_set()
|
|
176
|
+
self.client.tls_set()
|
|
179
177
|
|
|
180
178
|
# Set up callbacks
|
|
181
179
|
self.client.on_connect = self._on_connect
|
|
@@ -320,11 +318,13 @@ class MqttIOStream(IOStream):
|
|
|
320
318
|
LOG.debug(
|
|
321
319
|
"Received message on topic %s: %s",
|
|
322
320
|
msg.topic,
|
|
323
|
-
msg.payload.decode(),
|
|
321
|
+
msg.payload.decode("utf-8", errors="replace"),
|
|
324
322
|
)
|
|
325
323
|
|
|
326
324
|
if msg.topic == self.input_response_topic: # pragma: no branch
|
|
327
|
-
self._handle_input_response(
|
|
325
|
+
self._handle_input_response(
|
|
326
|
+
msg.payload.decode("utf-8", errors="replace")
|
|
327
|
+
)
|
|
328
328
|
|
|
329
329
|
except Exception as e: # pragma: no cover
|
|
330
330
|
LOG.error("Error handling message: %s", e)
|
|
@@ -396,7 +396,7 @@ class MqttIOStream(IOStream):
|
|
|
396
396
|
|
|
397
397
|
def __exit__(
|
|
398
398
|
self,
|
|
399
|
-
exc_type:
|
|
399
|
+
exc_type: type[Exception] | None,
|
|
400
400
|
exc_value: Exception | None,
|
|
401
401
|
traceback: TracebackType | None,
|
|
402
402
|
) -> None:
|
|
@@ -487,23 +487,10 @@ class MqttIOStream(IOStream):
|
|
|
487
487
|
message : BaseEvent | BaseMessage
|
|
488
488
|
The message or event to send.
|
|
489
489
|
"""
|
|
490
|
-
|
|
491
|
-
message_dump = message.model_dump(mode="json")
|
|
492
|
-
except Exception:
|
|
493
|
-
try:
|
|
494
|
-
message_dump = message.model_dump(
|
|
495
|
-
serialize_as_any=True, mode="json", fallback=str
|
|
496
|
-
)
|
|
497
|
-
except Exception as e:
|
|
498
|
-
message_dump = {
|
|
499
|
-
"error": str(e),
|
|
500
|
-
"type": message.__class__.__name__,
|
|
501
|
-
}
|
|
502
|
-
|
|
490
|
+
message_dump = get_message_dump(message)
|
|
503
491
|
message_type = message_dump.get("type", None)
|
|
504
492
|
if not message_type: # pragma: no cover
|
|
505
493
|
message_type = message.__class__.__name__
|
|
506
|
-
|
|
507
494
|
self._print(
|
|
508
495
|
{
|
|
509
496
|
"type": message_type,
|
|
@@ -652,7 +639,7 @@ class MqttIOStream(IOStream):
|
|
|
652
639
|
@staticmethod
|
|
653
640
|
def _create_user_response(
|
|
654
641
|
message_data: dict[str, Any],
|
|
655
|
-
) ->
|
|
642
|
+
) -> UserResponse | None:
|
|
656
643
|
"""Create UserResponse object from validated data."""
|
|
657
644
|
try:
|
|
658
645
|
# Handle nested JSON in 'data' field
|
waldiez/io/redis.py
CHANGED
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
# flake8: noqa: E501
|
|
5
5
|
# pylint: disable=too-many-try-statements,broad-exception-caught
|
|
6
6
|
# pylint: disable=line-too-long,duplicate-code
|
|
7
|
+
# pyright: reportMissingTypeStubs=false,reportUnknownArgumentType=false
|
|
8
|
+
# pyright: reportUnknownMemberType=false
|
|
7
9
|
|
|
8
10
|
"""A Redis I/O stream for handling print and input messages."""
|
|
9
11
|
|
|
@@ -12,16 +14,10 @@ import logging
|
|
|
12
14
|
import time
|
|
13
15
|
import traceback as tb
|
|
14
16
|
import uuid
|
|
17
|
+
from collections.abc import Awaitable
|
|
15
18
|
from pathlib import Path
|
|
16
19
|
from types import TracebackType
|
|
17
|
-
from typing import
|
|
18
|
-
TYPE_CHECKING,
|
|
19
|
-
Any,
|
|
20
|
-
Awaitable,
|
|
21
|
-
Callable,
|
|
22
|
-
Optional,
|
|
23
|
-
Type,
|
|
24
|
-
)
|
|
20
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
25
21
|
|
|
26
22
|
try:
|
|
27
23
|
import redis
|
|
@@ -60,10 +56,10 @@ class RedisIOStream(IOStream):
|
|
|
60
56
|
redis: Redis
|
|
61
57
|
task_id: str
|
|
62
58
|
input_timeout: int
|
|
63
|
-
on_input_request:
|
|
64
|
-
|
|
59
|
+
on_input_request: Callable[[str, str, str], None] | None
|
|
60
|
+
on_input_response: Callable[[str, str], None] | None
|
|
65
61
|
max_stream_size: int
|
|
66
|
-
|
|
62
|
+
task_output_stream: str
|
|
67
63
|
input_request_channel: str
|
|
68
64
|
input_response_channel: str
|
|
69
65
|
|
|
@@ -73,8 +69,8 @@ class RedisIOStream(IOStream):
|
|
|
73
69
|
task_id: str | None = None,
|
|
74
70
|
input_timeout: int = 120,
|
|
75
71
|
max_stream_size: int = 1000,
|
|
76
|
-
on_input_request:
|
|
77
|
-
on_input_response:
|
|
72
|
+
on_input_request: Callable[[str, str, str], None] | None = None,
|
|
73
|
+
on_input_response: Callable[[str, str], None] | None = None,
|
|
78
74
|
redis_connection_kwargs: dict[str, Any] | None = None,
|
|
79
75
|
uploads_root: Path | str | None = None,
|
|
80
76
|
) -> None:
|
|
@@ -127,7 +123,7 @@ class RedisIOStream(IOStream):
|
|
|
127
123
|
|
|
128
124
|
def __exit__(
|
|
129
125
|
self,
|
|
130
|
-
exc_type:
|
|
126
|
+
exc_type: type[Exception] | None,
|
|
131
127
|
exc_value: Exception | None,
|
|
132
128
|
traceback: TracebackType | None,
|
|
133
129
|
) -> None:
|
|
@@ -165,7 +161,7 @@ class RedisIOStream(IOStream):
|
|
|
165
161
|
"""
|
|
166
162
|
LOG.debug("Sending print message: %s", payload)
|
|
167
163
|
RedisIOStream.try_do(
|
|
168
|
-
self.redis.xadd,
|
|
164
|
+
self.redis.xadd,
|
|
169
165
|
self.task_output_stream,
|
|
170
166
|
payload,
|
|
171
167
|
maxlen=self.max_stream_size,
|
|
@@ -182,7 +178,7 @@ class RedisIOStream(IOStream):
|
|
|
182
178
|
"""
|
|
183
179
|
LOG.debug("Sending print message: %s", payload)
|
|
184
180
|
RedisIOStream.try_do(
|
|
185
|
-
self.redis.xadd,
|
|
181
|
+
self.redis.xadd,
|
|
186
182
|
self.common_output_stream,
|
|
187
183
|
payload,
|
|
188
184
|
maxlen=self.max_stream_size,
|
|
@@ -414,7 +410,7 @@ class RedisIOStream(IOStream):
|
|
|
414
410
|
)
|
|
415
411
|
|
|
416
412
|
@staticmethod
|
|
417
|
-
def _extract_message_data(data: Any) ->
|
|
413
|
+
def _extract_message_data(data: Any) -> dict[str, Any] | None:
|
|
418
414
|
"""Extract and parse the message data field."""
|
|
419
415
|
message_data = data
|
|
420
416
|
|
|
@@ -431,7 +427,7 @@ class RedisIOStream(IOStream):
|
|
|
431
427
|
LOG.error("Invalid message data format: %s", message_data)
|
|
432
428
|
return None
|
|
433
429
|
|
|
434
|
-
return message_data # pyright: ignore
|
|
430
|
+
return message_data # pyright: ignore[reportUnknownVariableType]
|
|
435
431
|
|
|
436
432
|
@staticmethod
|
|
437
433
|
def _message_has_required_fields(message_data: dict[str, Any]) -> bool:
|
|
@@ -445,7 +441,7 @@ class RedisIOStream(IOStream):
|
|
|
445
441
|
@staticmethod
|
|
446
442
|
def _process_nested_data(
|
|
447
443
|
message_data: dict[str, Any],
|
|
448
|
-
) ->
|
|
444
|
+
) -> dict[str, Any] | None:
|
|
449
445
|
"""Process nested JSON data if present."""
|
|
450
446
|
# Create a copy to avoid modifying the original
|
|
451
447
|
processed_data = message_data.copy()
|
|
@@ -467,7 +463,7 @@ class RedisIOStream(IOStream):
|
|
|
467
463
|
@staticmethod
|
|
468
464
|
def _create_user_response(
|
|
469
465
|
message_data: dict[str, Any],
|
|
470
|
-
) ->
|
|
466
|
+
) -> UserResponse | None:
|
|
471
467
|
"""Create UserResponse object from validated data."""
|
|
472
468
|
try:
|
|
473
469
|
return UserResponse.model_validate(message_data)
|
|
@@ -684,7 +680,7 @@ class RedisIOStream(IOStream):
|
|
|
684
680
|
"""
|
|
685
681
|
for key in redis_client.scan_iter("task:*:output", count=100):
|
|
686
682
|
RedisIOStream.try_do(
|
|
687
|
-
redis_client.xtrim,
|
|
683
|
+
redis_client.xtrim,
|
|
688
684
|
key,
|
|
689
685
|
maxlen=maxlen,
|
|
690
686
|
approximate=approximate,
|
|
@@ -760,7 +756,7 @@ class RedisIOStream(IOStream):
|
|
|
760
756
|
): # pragma: no branch
|
|
761
757
|
before = await redis_client.xlen(key)
|
|
762
758
|
await RedisIOStream.a_try_do(
|
|
763
|
-
redis_client.xtrim,
|
|
759
|
+
redis_client.xtrim,
|
|
764
760
|
key,
|
|
765
761
|
maxlen=maxlen,
|
|
766
762
|
approximate=approximate,
|