waldiez 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waldiez might be problematic. Click here for more details.
- waldiez/__init__.py +5 -5
- waldiez/_version.py +1 -1
- waldiez/cli.py +112 -73
- waldiez/exporter.py +61 -19
- waldiez/exporting/__init__.py +25 -6
- waldiez/exporting/agent/__init__.py +7 -3
- waldiez/exporting/agent/code_execution.py +114 -0
- waldiez/exporting/agent/exporter.py +354 -0
- waldiez/exporting/agent/extras/__init__.py +15 -0
- waldiez/exporting/agent/extras/captain_agent_extras.py +315 -0
- waldiez/exporting/agent/extras/group/target.py +178 -0
- waldiez/exporting/agent/extras/group_manager_agent_extas.py +500 -0
- waldiez/exporting/agent/extras/group_member_extras.py +181 -0
- waldiez/exporting/agent/extras/handoffs/__init__.py +19 -0
- waldiez/exporting/agent/extras/handoffs/after_work.py +78 -0
- waldiez/exporting/agent/extras/handoffs/available.py +74 -0
- waldiez/exporting/agent/extras/handoffs/condition.py +158 -0
- waldiez/exporting/agent/extras/handoffs/handoff.py +171 -0
- waldiez/exporting/agent/extras/handoffs/target.py +189 -0
- waldiez/exporting/agent/extras/rag/__init__.py +10 -0
- waldiez/exporting/agent/{utils/rag_user/chroma_utils.py → extras/rag/chroma_extras.py} +16 -15
- waldiez/exporting/agent/{utils/rag_user/mongo_utils.py → extras/rag/mongo_extras.py} +10 -10
- waldiez/exporting/agent/{utils/rag_user/pgvector_utils.py → extras/rag/pgvector_extras.py} +13 -13
- waldiez/exporting/agent/{utils/rag_user/qdrant_utils.py → extras/rag/qdrant_extras.py} +13 -13
- waldiez/exporting/agent/{utils/rag_user/vector_db.py → extras/rag/vector_db_extras.py} +59 -46
- waldiez/exporting/agent/extras/rag_user_proxy_agent_extras.py +245 -0
- waldiez/exporting/agent/extras/reasoning_agent_extras.py +88 -0
- waldiez/exporting/agent/factory.py +95 -0
- waldiez/exporting/agent/processor.py +150 -0
- waldiez/exporting/agent/system_message.py +36 -0
- waldiez/exporting/agent/termination.py +50 -0
- waldiez/exporting/chats/__init__.py +7 -3
- waldiez/exporting/chats/exporter.py +97 -0
- waldiez/exporting/chats/factory.py +65 -0
- waldiez/exporting/chats/processor.py +226 -0
- waldiez/exporting/chats/utils/__init__.py +6 -5
- waldiez/exporting/chats/utils/common.py +11 -45
- waldiez/exporting/chats/utils/group.py +55 -0
- waldiez/exporting/chats/utils/nested.py +37 -52
- waldiez/exporting/chats/utils/sequential.py +72 -61
- waldiez/exporting/chats/utils/{single_chat.py → single.py} +48 -50
- waldiez/exporting/core/__init__.py +196 -0
- waldiez/exporting/core/constants.py +17 -0
- waldiez/exporting/core/content.py +69 -0
- waldiez/exporting/core/context.py +244 -0
- waldiez/exporting/core/enums.py +89 -0
- waldiez/exporting/core/errors.py +19 -0
- waldiez/exporting/core/exporter.py +390 -0
- waldiez/exporting/core/exporters.py +67 -0
- waldiez/exporting/core/extras/__init__.py +39 -0
- waldiez/exporting/core/extras/agent_extras/__init__.py +27 -0
- waldiez/exporting/core/extras/agent_extras/captain_extras.py +57 -0
- waldiez/exporting/core/extras/agent_extras/group_manager_extras.py +102 -0
- waldiez/exporting/core/extras/agent_extras/rag_user_extras.py +53 -0
- waldiez/exporting/core/extras/agent_extras/reasoning_extras.py +68 -0
- waldiez/exporting/core/extras/agent_extras/standard_extras.py +263 -0
- waldiez/exporting/core/extras/base.py +241 -0
- waldiez/exporting/core/extras/chat_extras.py +118 -0
- waldiez/exporting/core/extras/flow_extras.py +70 -0
- waldiez/exporting/core/extras/model_extras.py +73 -0
- waldiez/exporting/core/extras/path_resolver.py +93 -0
- waldiez/exporting/core/extras/serializer.py +138 -0
- waldiez/exporting/core/extras/tool_extras.py +82 -0
- waldiez/exporting/core/protocols.py +259 -0
- waldiez/exporting/core/result.py +705 -0
- waldiez/exporting/core/types.py +329 -0
- waldiez/exporting/core/utils/__init__.py +11 -0
- waldiez/exporting/core/utils/comment.py +33 -0
- waldiez/exporting/core/utils/llm_config.py +117 -0
- waldiez/exporting/core/validation.py +96 -0
- waldiez/exporting/flow/__init__.py +6 -2
- waldiez/exporting/flow/execution_generator.py +193 -0
- waldiez/exporting/flow/exporter.py +107 -0
- waldiez/exporting/flow/factory.py +94 -0
- waldiez/exporting/flow/file_generator.py +214 -0
- waldiez/exporting/flow/merger.py +387 -0
- waldiez/exporting/flow/orchestrator.py +411 -0
- waldiez/exporting/flow/utils/__init__.py +9 -36
- waldiez/exporting/flow/utils/common.py +206 -0
- waldiez/exporting/flow/utils/importing.py +373 -0
- waldiez/exporting/flow/utils/linting.py +200 -0
- waldiez/exporting/flow/utils/{logging_utils.py → logging.py} +23 -9
- waldiez/exporting/models/__init__.py +3 -1
- waldiez/exporting/models/exporter.py +233 -0
- waldiez/exporting/models/factory.py +66 -0
- waldiez/exporting/models/processor.py +139 -0
- waldiez/exporting/tools/__init__.py +11 -0
- waldiez/exporting/tools/exporter.py +207 -0
- waldiez/exporting/tools/factory.py +57 -0
- waldiez/exporting/tools/processor.py +248 -0
- waldiez/exporting/tools/registration.py +133 -0
- waldiez/io/__init__.py +128 -0
- waldiez/io/_ws.py +199 -0
- waldiez/io/models/__init__.py +60 -0
- waldiez/io/models/base.py +66 -0
- waldiez/io/models/constants.py +78 -0
- waldiez/io/models/content/__init__.py +23 -0
- waldiez/io/models/content/audio.py +43 -0
- waldiez/io/models/content/base.py +45 -0
- waldiez/io/models/content/file.py +43 -0
- waldiez/io/models/content/image.py +96 -0
- waldiez/io/models/content/text.py +37 -0
- waldiez/io/models/content/video.py +43 -0
- waldiez/io/models/user_input.py +269 -0
- waldiez/io/models/user_response.py +215 -0
- waldiez/io/mqtt.py +681 -0
- waldiez/io/redis.py +782 -0
- waldiez/io/structured.py +419 -0
- waldiez/io/utils.py +184 -0
- waldiez/io/ws.py +298 -0
- waldiez/logger.py +481 -0
- waldiez/models/__init__.py +108 -51
- waldiez/models/agents/__init__.py +34 -70
- waldiez/models/agents/agent/__init__.py +10 -4
- waldiez/models/agents/agent/agent.py +466 -65
- waldiez/models/agents/agent/agent_data.py +119 -47
- waldiez/models/agents/agent/agent_type.py +13 -2
- waldiez/models/agents/agent/code_execution.py +12 -12
- waldiez/models/agents/agent/human_input_mode.py +8 -0
- waldiez/models/agents/agent/{linked_skill.py → linked_tool.py} +7 -7
- waldiez/models/agents/agent/nested_chat.py +35 -7
- waldiez/models/agents/agent/termination_message.py +30 -22
- waldiez/models/agents/{swarm_agent → agent}/update_system_message.py +22 -22
- waldiez/models/agents/agents.py +58 -63
- waldiez/models/agents/assistant/assistant.py +4 -4
- waldiez/models/agents/assistant/assistant_data.py +13 -1
- waldiez/models/agents/{captain_agent → captain}/captain_agent.py +5 -5
- waldiez/models/agents/{captain_agent → captain}/captain_agent_data.py +5 -5
- waldiez/models/agents/extra_requirements.py +11 -16
- waldiez/models/agents/group_manager/group_manager.py +103 -13
- waldiez/models/agents/group_manager/group_manager_data.py +36 -14
- waldiez/models/agents/group_manager/speakers.py +77 -24
- waldiez/models/agents/{rag_user → rag_user_proxy}/__init__.py +16 -16
- waldiez/models/agents/rag_user_proxy/rag_user_proxy.py +64 -0
- waldiez/models/agents/{rag_user/rag_user_data.py → rag_user_proxy/rag_user_proxy_data.py} +6 -5
- waldiez/models/agents/{rag_user → rag_user_proxy}/retrieve_config.py +182 -114
- waldiez/models/agents/{rag_user → rag_user_proxy}/vector_db_config.py +13 -13
- waldiez/models/agents/reasoning/reasoning_agent.py +6 -6
- waldiez/models/agents/reasoning/reasoning_agent_data.py +110 -63
- waldiez/models/agents/reasoning/reasoning_agent_reason_config.py +38 -10
- waldiez/models/agents/user_proxy/user_proxy.py +11 -7
- waldiez/models/agents/user_proxy/user_proxy_data.py +2 -2
- waldiez/models/chat/__init__.py +2 -1
- waldiez/models/chat/chat.py +166 -87
- waldiez/models/chat/chat_data.py +99 -136
- waldiez/models/chat/chat_message.py +33 -23
- waldiez/models/chat/chat_nested.py +31 -30
- waldiez/models/chat/chat_summary.py +10 -8
- waldiez/models/common/__init__.py +52 -2
- waldiez/models/common/ag2_version.py +1 -1
- waldiez/models/common/base.py +38 -7
- waldiez/models/common/dict_utils.py +42 -17
- waldiez/models/common/handoff.py +459 -0
- waldiez/models/common/id_generator.py +19 -0
- waldiez/models/common/method_utils.py +130 -68
- waldiez/{exporting/base/utils → models/common}/naming.py +38 -61
- waldiez/models/common/waldiez_version.py +37 -0
- waldiez/models/flow/__init__.py +9 -2
- waldiez/models/flow/connection.py +18 -0
- waldiez/models/flow/flow.py +311 -215
- waldiez/models/flow/flow_data.py +207 -40
- waldiez/models/flow/info.py +85 -0
- waldiez/models/flow/naming.py +131 -0
- waldiez/models/model/__init__.py +7 -1
- waldiez/models/model/extra_requirements.py +3 -12
- waldiez/models/model/model.py +76 -21
- waldiez/models/model/model_data.py +108 -20
- waldiez/models/tool/__init__.py +16 -0
- waldiez/models/tool/extra_requirements.py +36 -0
- waldiez/models/{skill/skill.py → tool/tool.py} +88 -88
- waldiez/models/tool/tool_data.py +51 -0
- waldiez/models/tool/tool_type.py +8 -0
- waldiez/models/waldiez.py +97 -80
- waldiez/runner.py +114 -49
- waldiez/running/__init__.py +1 -1
- waldiez/running/environment.py +49 -68
- waldiez/running/gen_seq_diagram.py +16 -14
- waldiez/running/running.py +53 -34
- waldiez/utils/__init__.py +0 -4
- waldiez/utils/cli_extras/jupyter.py +5 -3
- waldiez/utils/cli_extras/runner.py +6 -4
- waldiez/utils/cli_extras/studio.py +6 -4
- waldiez/utils/conflict_checker.py +15 -9
- waldiez/utils/flaml_warnings.py +5 -5
- {waldiez-0.4.6.dist-info → waldiez-0.4.8.dist-info}/METADATA +235 -91
- waldiez-0.4.8.dist-info/RECORD +200 -0
- waldiez/exporting/agent/agent_exporter.py +0 -297
- waldiez/exporting/agent/utils/__init__.py +0 -23
- waldiez/exporting/agent/utils/captain_agent.py +0 -263
- waldiez/exporting/agent/utils/code_execution.py +0 -65
- waldiez/exporting/agent/utils/group_manager.py +0 -220
- waldiez/exporting/agent/utils/rag_user/__init__.py +0 -7
- waldiez/exporting/agent/utils/rag_user/rag_user.py +0 -209
- waldiez/exporting/agent/utils/reasoning.py +0 -36
- waldiez/exporting/agent/utils/swarm_agent.py +0 -469
- waldiez/exporting/agent/utils/teachability.py +0 -41
- waldiez/exporting/agent/utils/termination_message.py +0 -44
- waldiez/exporting/base/__init__.py +0 -25
- waldiez/exporting/base/agent_position.py +0 -75
- waldiez/exporting/base/base_exporter.py +0 -118
- waldiez/exporting/base/export_position.py +0 -48
- waldiez/exporting/base/import_position.py +0 -23
- waldiez/exporting/base/mixin.py +0 -137
- waldiez/exporting/base/utils/__init__.py +0 -18
- waldiez/exporting/base/utils/comments.py +0 -96
- waldiez/exporting/base/utils/path_check.py +0 -68
- waldiez/exporting/base/utils/to_string.py +0 -84
- waldiez/exporting/chats/chats_exporter.py +0 -240
- waldiez/exporting/chats/utils/swarm.py +0 -210
- waldiez/exporting/flow/flow_exporter.py +0 -528
- waldiez/exporting/flow/utils/agent_utils.py +0 -204
- waldiez/exporting/flow/utils/chat_utils.py +0 -71
- waldiez/exporting/flow/utils/def_main.py +0 -77
- waldiez/exporting/flow/utils/flow_content.py +0 -202
- waldiez/exporting/flow/utils/flow_names.py +0 -116
- waldiez/exporting/flow/utils/importing_utils.py +0 -227
- waldiez/exporting/models/models_exporter.py +0 -199
- waldiez/exporting/models/utils.py +0 -174
- waldiez/exporting/skills/__init__.py +0 -9
- waldiez/exporting/skills/skills_exporter.py +0 -176
- waldiez/exporting/skills/utils.py +0 -369
- waldiez/models/agents/agent/teachability.py +0 -70
- waldiez/models/agents/rag_user/rag_user.py +0 -60
- waldiez/models/agents/swarm_agent/__init__.py +0 -50
- waldiez/models/agents/swarm_agent/after_work.py +0 -179
- waldiez/models/agents/swarm_agent/on_condition.py +0 -105
- waldiez/models/agents/swarm_agent/on_condition_available.py +0 -142
- waldiez/models/agents/swarm_agent/on_condition_target.py +0 -40
- waldiez/models/agents/swarm_agent/swarm_agent.py +0 -107
- waldiez/models/agents/swarm_agent/swarm_agent_data.py +0 -124
- waldiez/models/flow/utils.py +0 -232
- waldiez/models/skill/__init__.py +0 -16
- waldiez/models/skill/extra_requirements.py +0 -36
- waldiez/models/skill/skill_data.py +0 -53
- waldiez/models/skill/skill_type.py +0 -8
- waldiez/utils/pysqlite3_checker.py +0 -308
- waldiez/utils/rdps_checker.py +0 -122
- waldiez-0.4.6.dist-info/RECORD +0 -149
- /waldiez/models/agents/{captain_agent → captain}/__init__.py +0 -0
- /waldiez/models/agents/{captain_agent → captain}/captain_agent_lib_entry.py +0 -0
- {waldiez-0.4.6.dist-info → waldiez-0.4.8.dist-info}/WHEEL +0 -0
- {waldiez-0.4.6.dist-info → waldiez-0.4.8.dist-info}/entry_points.txt +0 -0
- {waldiez-0.4.6.dist-info → waldiez-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {waldiez-0.4.6.dist-info → waldiez-0.4.8.dist-info}/licenses/NOTICE.md +0 -0
waldiez/io/mqtt.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0.
|
|
2
|
+
# Copyright (c) 2024 - 2025 Waldiez and contributors.
|
|
3
|
+
|
|
4
|
+
# flake8: noqa: E501
|
|
5
|
+
# pylint: disable=too-many-try-statements,broad-exception-caught,
|
|
6
|
+
# pylint: disable=line-too-long,unused-argument,too-many-instance-attributes
|
|
7
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
8
|
+
|
|
9
|
+
"""An MQTT I/O stream for handling print and input messages."""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import time
|
|
14
|
+
import traceback as tb
|
|
15
|
+
import uuid
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from threading import Event, Lock
|
|
18
|
+
from types import TracebackType
|
|
19
|
+
from typing import (
|
|
20
|
+
Any,
|
|
21
|
+
Callable,
|
|
22
|
+
Optional,
|
|
23
|
+
Type,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
from paho.mqtt import client as mqtt
|
|
28
|
+
from paho.mqtt.enums import CallbackAPIVersion
|
|
29
|
+
from paho.mqtt.reasoncodes import ReasonCode
|
|
30
|
+
except ImportError as error: # pragma: no cover
|
|
31
|
+
raise ImportError(
|
|
32
|
+
"MQTT client not installed. Please install paho-mqtt with `pip install paho-mqtt`."
|
|
33
|
+
) from error
|
|
34
|
+
|
|
35
|
+
from autogen.io import IOStream # type: ignore
|
|
36
|
+
from autogen.messages import BaseMessage # type: ignore
|
|
37
|
+
|
|
38
|
+
from .models import (
|
|
39
|
+
PrintMessage,
|
|
40
|
+
TextMediaContent,
|
|
41
|
+
UserInputData,
|
|
42
|
+
UserInputRequest,
|
|
43
|
+
UserResponse,
|
|
44
|
+
)
|
|
45
|
+
from .utils import gen_id, now
|
|
46
|
+
|
|
47
|
+
LOG = logging.getLogger(__name__)

# Reconnection policy used by ``MqttIOStream._on_disconnect`` after an
# unexpected disconnect: the first retry waits MQTT_FIRST_RECONNECT_DELAY
# seconds, each failed attempt multiplies the wait by MQTT_RECONNECT_RATE,
# the wait is capped at MQTT_MAX_RECONNECT_DELAY seconds, and at most
# MQTT_MAX_RECONNECT_COUNT attempts are made before giving up.
MQTT_FIRST_RECONNECT_DELAY = 1
MQTT_RECONNECT_RATE = 2
MQTT_MAX_RECONNECT_COUNT = 12
MQTT_MAX_RECONNECT_DELAY = 60
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MqttIOStream(IOStream):
|
|
57
|
+
"""MQTT I/O stream."""
|
|
58
|
+
|
|
59
|
+
client: mqtt.Client
|
|
60
|
+
task_id: str
|
|
61
|
+
input_timeout: int
|
|
62
|
+
on_input_request: Optional[Callable[[str, str, str], None]]
|
|
63
|
+
on_input_received: Optional[Callable[[str, str], None]]
|
|
64
|
+
max_retain_messages: int
|
|
65
|
+
output_topic: str
|
|
66
|
+
input_request_topic: str
|
|
67
|
+
input_response_topic: str
|
|
68
|
+
common_output_topic: str
|
|
69
|
+
broker_host: str
|
|
70
|
+
broker_port: int
|
|
71
|
+
|
|
72
|
+
# Thread safety and input handling
|
|
73
|
+
_input_responses: dict[str, str]
|
|
74
|
+
_input_lock: Lock
|
|
75
|
+
_input_events: dict[str, Event]
|
|
76
|
+
_processed_requests: set[str]
|
|
77
|
+
_connected: bool
|
|
78
|
+
|
|
79
|
+
def __init__(
    self,
    broker_host: str = "localhost",
    broker_port: int = 1883,
    task_id: str | None = None,
    input_timeout: int = 120,
    max_retain_messages: int = 1000,
    on_input_request: Optional[Callable[[str, str, str], None]] = None,
    on_input_response: Optional[Callable[[str, str], None]] = None,
    mqtt_client_kwargs: dict[str, Any] | None = None,
    uploads_root: Path | str | None = None,
    username: str | None = None,
    password: str | None = None,
    use_tls: bool = False,
    ca_cert_path: str | None = None,
) -> None:
    """Initialize the MQTT I/O stream and connect to the broker.

    Parameters
    ----------
    broker_host : str, optional
        The MQTT broker host, by default "localhost".
    broker_port : int, optional
        The MQTT broker port, by default 1883.
    task_id : str, optional
        An ID to use for the topics. If not provided, a random UUID will be generated.
    input_timeout : int, optional
        The time to wait for user input in seconds, by default 120.
    max_retain_messages : int, optional
        Maximum number of retained messages per topic, by default 1000.
    on_input_request : Optional[Callable[[str, str, str], None]], optional
        Callback for input request, by default None
        parameters: prompt, request_id, task_id
    on_input_response : Optional[Callable[[str, str], None]], optional
        Callback for input response, by default None.
        parameters: user_input, task_id
    mqtt_client_kwargs : dict[str, Any] | None, optional
        Additional MQTT client kwargs, by default None. The dict is copied,
        never modified. ``callback_api_version`` defaults to
        ``CallbackAPIVersion.VERSION2`` when not given.
    uploads_root : Path | str | None, optional
        The root directory for uploads, by default None. Created if missing.
    username : str | None, optional
        MQTT broker username, by default None. When set, credentials are
        sent even if ``password`` is None (MQTT permits username-only auth).
    password : str | None, optional
        MQTT broker password, by default None.
    use_tls : bool, optional
        Whether to use TLS connection, by default False.
    ca_cert_path : str | None, optional
        Path to CA certificate file for TLS, by default None.

    Raises
    ------
    ConnectionError
        If the broker connection is not established within the connect
        timeout. Other errors raised while connecting also propagate.
    """
    self.broker_host = broker_host
    self.broker_port = broker_port
    self.task_id = task_id or uuid.uuid4().hex
    self.input_timeout = input_timeout
    self.on_input_request = on_input_request
    self.on_input_response = on_input_response
    self.max_retain_messages = max_retain_messages

    # Topic structure
    self.output_topic = f"task/{self.task_id}/output"
    self.input_request_topic = f"task/{self.task_id}/input_request"
    self.input_response_topic = f"task/{self.task_id}/input_response"
    self.common_output_topic = "task/output"

    # Thread safety
    self._input_responses = {}
    self._input_lock = Lock()
    self._input_events = {}
    self._processed_requests = set()
    self._connected = False

    # Uploads
    self.uploads_root = (
        Path(uploads_root).resolve() if uploads_root else None
    )
    if self.uploads_root and not self.uploads_root.exists():
        self.uploads_root.mkdir(parents=True, exist_ok=True)

    # Initialize MQTT client. Work on a copy so that filling in the default
    # callback API version never mutates the caller's dict.
    client_kwargs: dict[str, Any] = dict(mqtt_client_kwargs or {})
    client_kwargs.setdefault(
        "callback_api_version", CallbackAPIVersion.VERSION2
    )
    self.client = mqtt.Client(**client_kwargs)

    # Set up authentication (a username without a password is valid MQTT).
    if username:
        self.client.username_pw_set(username, password)

    # Set up TLS
    if use_tls:
        if ca_cert_path:
            self.client.tls_set(ca_cert_path)  # pyright: ignore
        else:  # pragma: no cover
            self.client.tls_set()  # pyright: ignore

    # Set up callbacks
    self.client.on_connect = self._on_connect
    self.client.on_disconnect = self._on_disconnect
    self.client.on_message = self._on_message
    self.client.on_log = self._on_log

    # Connect to broker
    self._connect()
|
|
181
|
+
|
|
182
|
+
def _connect(self) -> None:
    """Connect to the MQTT broker and start paho's network loop.

    Blocks for up to 10 seconds waiting for the client to report a
    connection.

    Raises
    ------
    ConnectionError
        If the client is not connected within the timeout. Any error raised
        by ``client.connect`` is logged and re-raised as well.
    """
    try:
        LOG.debug(
            "Connecting to MQTT broker at %s:%d",
            self.broker_host,
            self.broker_port,
        )
        self.client.connect(self.broker_host, self.broker_port, 60)
        self.client.loop_start()

        # Wait for connection; a monotonic clock is immune to wall-clock jumps.
        timeout = 10  # seconds
        deadline = time.monotonic() + timeout
        while not self.client.is_connected() and time.monotonic() < deadline:
            time.sleep(0.1)

        if not self.client.is_connected():
            raise ConnectionError(
                "Failed to connect to MQTT broker within timeout"
            )

    except Exception as e:
        LOG.error("Failed to connect to MQTT broker: %s", e)
        # Don't leave paho's background network thread running for a
        # connection attempt we are reporting as failed.
        try:
            self.client.loop_stop()
        except Exception:  # pragma: no cover - best-effort cleanup
            LOG.debug("Could not stop MQTT loop after failed connect", exc_info=True)
        raise
|
|
210
|
+
|
|
211
|
+
def _on_connect(
    self,
    client: mqtt.Client,
    userdata: Any,
    flags: Any,
    reason_code: ReasonCode | int,
    properties: Any = None,
) -> None:
    """Handle MQTT connection event.

    This stream creates its client with ``CallbackAPIVersion.VERSION2`` by
    default. With that API version paho-mqtt calls
    ``on_connect(client, userdata, connect_flags, reason_code, properties)``.
    The VERSION1 / MQTT v3 form omits ``properties``, so that parameter is
    optional and both call shapes are accepted.

    Parameters
    ----------
    client : mqtt.Client
        The MQTT client instance.
    userdata : Any
        User-defined data of any type (not used here).
    flags : Any
        Response flags from the broker (a dict for VERSION1, a
        ``ConnectFlags`` object for VERSION2).
    reason_code : ReasonCode | int
        The connection reason code.
    properties : Any, optional
        MQTT v5 properties (passed by VERSION2 / MQTT v5 callbacks), unused.

    Raises
    ------
    ConnectionError
        If the reason code reports a failure or the client is not connected.
    """
    if isinstance(reason_code, ReasonCode):  # pragma: no cover
        failed = reason_code.is_failure
    else:
        failed = reason_code != mqtt.MQTT_ERR_SUCCESS

    if failed or not client.is_connected():
        LOG.error(
            "Failed to connect to MQTT broker (reason code: %s)",
            reason_code,
        )
        self._connected = False
        raise ConnectionError(
            f"MQTT connection failed with reason code {reason_code}"
        )

    LOG.debug("Connected to MQTT broker successfully")
    self._connected = True

    # Subscribe here, on each successful connect, so the subscription is
    # re-established whenever paho reports a new connection.
    client.subscribe(self.input_response_topic, qos=1)
    LOG.debug(
        "Subscribed to input response topic: %s",
        self.input_response_topic,
    )
|
|
255
|
+
|
|
256
|
+
def _on_disconnect(
    self,
    client: mqtt.Client,
    userdata: Any,
    reason_code: Any,
    *args: Any,
) -> None:
    """Handle MQTT disconnection event, retrying on abnormal disconnects.

    paho-mqtt calls this callback with different signatures depending on the
    callback API version:

    * VERSION1 / MQTT v3: ``(client, userdata, rc)``
    * VERSION1 / MQTT v5: ``(client, userdata, reason_code, properties)``
    * VERSION2 (this class's default):
      ``(client, userdata, disconnect_flags, reason_code, properties)``

    Extra positional arguments are accepted, and for VERSION2 the real
    reason code is taken from the position after the disconnect flags.

    On a non-success reason code, reconnection is attempted up to
    ``MQTT_MAX_RECONNECT_COUNT`` times. The wait starts at
    ``MQTT_FIRST_RECONNECT_DELAY`` seconds and grows by
    ``MQTT_RECONNECT_RATE``, capped at ``MQTT_MAX_RECONNECT_DELAY``.

    Parameters
    ----------
    client : mqtt.Client
        The MQTT client instance.
    userdata : Any
        User-defined data of any type (not used here).
    reason_code : Any
        The disconnection reason code (``ReasonCode`` or int). Under
        VERSION2 this position holds the disconnect flags instead.
    *args : Any
        Remaining positional arguments supplied by paho (reason code
        and/or properties, depending on the API version).
    """
    if args and not isinstance(reason_code, (ReasonCode, int)):
        # VERSION2 call: the third positional argument is the disconnect
        # flags object; the actual reason code follows it.
        reason_code = args[0]

    self._connected = False
    if isinstance(reason_code, ReasonCode):  # pragma: no cover
        is_normal_disconnect = reason_code.value == mqtt.MQTT_ERR_SUCCESS
    else:
        is_normal_disconnect = reason_code == mqtt.MQTT_ERR_SUCCESS
    if is_normal_disconnect:  # pragma: no cover
        LOG.debug("Disconnected from MQTT broker normally")
        return

    LOG.warning("Disconnected with reason: %s", str(reason_code))
    reconnect_count, reconnect_delay = 0, MQTT_FIRST_RECONNECT_DELAY
    while reconnect_count < MQTT_MAX_RECONNECT_COUNT:
        LOG.info("Reconnecting in %d seconds...", reconnect_delay)
        time.sleep(reconnect_delay)
        # pylint: disable=broad-except
        try:
            client.reconnect()
        except Exception as err:
            LOG.error("%s. Reconnect failed. Retrying...", err)
        else:  # pragma: no cover
            LOG.info("Reconnected successfully!")
            return

        reconnect_delay *= MQTT_RECONNECT_RATE
        reconnect_delay = min(reconnect_delay, MQTT_MAX_RECONNECT_DELAY)
        reconnect_count += 1
    LOG.info("Reconnect failed after %s attempts.", reconnect_count)
|
|
296
|
+
|
|
297
|
+
def _on_message(
    self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage
) -> None:
    """Dispatch an incoming MQTT message.

    Only messages on this stream's input-response topic are acted upon;
    anything else is just logged. Errors are logged rather than raised.

    Parameters
    ----------
    client : mqtt.Client
        The MQTT client instance.
    userdata : Any
        User-defined data of any type (not used here).
    msg : mqtt.MQTTMessage
        The received MQTT message.
    """
    try:
        text = msg.payload.decode()
        LOG.debug("Received message on topic %s: %s", msg.topic, text)
        if msg.topic != self.input_response_topic:  # pragma: no cover
            return
        self._handle_input_response(text)
    except Exception as e:  # pragma: no cover
        LOG.error("Error handling message: %s", e)
|
|
323
|
+
|
|
324
|
+
def _on_log(
    self,
    client: mqtt.Client,
    userdata: Any,
    level: int,
    buf: str,
) -> None:  # pragma: no cover
    """Forward paho-mqtt log messages to the Python logger and common output.

    Forwarding a log line means calling ``publish()``, and paho-mqtt itself
    emits log lines (e.g. "Sending PUBLISH ...") from inside ``publish()``.
    Without a guard each forwarded line would trigger another forward,
    recursing without bound. While a forward is in progress, further log
    lines therefore go only to the Python logger. The flag is per instance,
    not per thread, so a log line from another thread during a forward is
    also only logged locally.

    Parameters
    ----------
    client : mqtt.Client
        The MQTT client instance.
    userdata : Any
        User-defined data of any type (not used here).
    level : int
        The log level.
    buf : str
        The log message.
    """
    payload: dict[str, Any] = {
        "level": level,
        "message": buf,
    }
    LOG.debug("MQTT log: %s", payload)
    if getattr(self, "_forwarding_log", False):
        return
    # pylint: disable=attribute-defined-outside-init
    self._forwarding_log = True
    try:
        print_message = PrintMessage(data=buf)
        self._print_to_common_output(
            payload=print_message.model_dump(mode="json")
        )
    finally:
        self._forwarding_log = False
|
|
353
|
+
|
|
354
|
+
def _handle_input_response(self, payload: str) -> None:
|
|
355
|
+
"""Handle input response message."""
|
|
356
|
+
try:
|
|
357
|
+
message_data = json.loads(payload)
|
|
358
|
+
response = self._create_user_response(message_data)
|
|
359
|
+
|
|
360
|
+
if not response or not response.request_id:
|
|
361
|
+
return
|
|
362
|
+
|
|
363
|
+
# Check if already processed
|
|
364
|
+
if response.request_id in self._processed_requests:
|
|
365
|
+
return
|
|
366
|
+
|
|
367
|
+
with self._input_lock:
|
|
368
|
+
self._processed_requests.add(response.request_id)
|
|
369
|
+
user_input = self._get_user_input(response)
|
|
370
|
+
self._input_responses[response.request_id] = user_input
|
|
371
|
+
|
|
372
|
+
# Signal waiting thread
|
|
373
|
+
if (
|
|
374
|
+
response.request_id in self._input_events
|
|
375
|
+
): # pragma: no branch
|
|
376
|
+
self._input_events[response.request_id].set()
|
|
377
|
+
|
|
378
|
+
except Exception as e:
|
|
379
|
+
LOG.error("Error handling input response: %s", e)
|
|
380
|
+
|
|
381
|
+
def __enter__(self) -> "MqttIOStream":
    """Enable context manager usage.

    The broker connection is already made in ``__init__``, so entering
    only returns the stream itself.
    """
    return self
|
|
384
|
+
|
|
385
|
+
def __exit__(
    self,
    exc_type: Type[BaseException] | None,
    exc_value: BaseException | None,
    traceback: TracebackType | None,
) -> None:
    """Exit the context manager by closing the MQTT client.

    Returns None, so any exception raised inside the ``with`` block is
    not suppressed.
    """
    self.close()
|
|
393
|
+
|
|
394
|
+
def close(self) -> None:
    """Disconnect from the broker and stop paho's network loop.

    Disconnecting first, while the loop is still running, lets the loop
    flush the DISCONNECT packet; this is the order shown in the paho-mqtt
    documentation. Each step is attempted independently, so a failing
    ``disconnect()`` still stops the background thread. Errors are logged,
    not raised. This is a no-op if no ``client`` was created.
    """
    if not hasattr(self, "client"):  # pragma: no cover
        return
    for step in (self.client.disconnect, self.client.loop_stop):
        try:
            step()
        except Exception as e:
            LOG.error("Error closing MQTT client: %s", e)
|
|
402
|
+
|
|
403
|
+
def _publish_message(
    self, topic: str, payload: dict[str, Any], retain: bool = False
) -> None:
    """JSON-encode *payload* and publish it on *topic* with QoS 1.

    Failures, whether raised or reported via the publish result code, are
    logged and never propagated.

    Parameters
    ----------
    topic : str
        Destination MQTT topic.
    payload : dict[str, Any]
        JSON-serializable message body.
    retain : bool, optional
        Ask the broker to retain the message, by default False.
    """
    try:
        encoded = json.dumps(payload)
        LOG.debug("Publishing to %s: %s", topic, encoded)
        info = self.client.publish(topic, encoded, qos=1, retain=retain)
        if info.rc == mqtt.MQTT_ERR_SUCCESS:
            return
        LOG.error("Failed to publish message to %s: %s", topic, info.rc)
    except Exception as e:
        LOG.error("Error publishing message: %s", e)
|
|
431
|
+
|
|
432
|
+
def _print_to_task_output(self, payload: dict[str, Any]) -> None:
|
|
433
|
+
"""Print message to the task output topic."""
|
|
434
|
+
self._publish_message(self.output_topic, payload, retain=True)
|
|
435
|
+
|
|
436
|
+
def _print_to_common_output(self, payload: dict[str, Any]) -> None:
|
|
437
|
+
"""Print message to the common output topic."""
|
|
438
|
+
self._publish_message(self.common_output_topic, payload, retain=False)
|
|
439
|
+
|
|
440
|
+
def _print(self, payload: dict[str, Any]) -> None:
|
|
441
|
+
"""Print message to MQTT topics."""
|
|
442
|
+
if "id" not in payload:
|
|
443
|
+
payload["id"] = gen_id()
|
|
444
|
+
payload["task_id"] = self.task_id
|
|
445
|
+
if "timestamp" not in payload:
|
|
446
|
+
payload["timestamp"] = now()
|
|
447
|
+
|
|
448
|
+
self._print_to_task_output(payload)
|
|
449
|
+
self._print_to_common_output(payload)
|
|
450
|
+
|
|
451
|
+
def print(self, *args: Any, **kwargs: Any) -> None:
    """Publish a ``print()``-style message to the MQTT output topics.

    Parameters
    ----------
    args : Any
        The values to print, forwarded to ``PrintMessage.create``.
    kwargs : Any
        Extra keyword arguments, forwarded to ``PrintMessage.create``.
    """
    message = PrintMessage.create(*args, **kwargs)
    self._print(message.model_dump(mode="json"))
|
|
464
|
+
|
|
465
|
+
def send(self, message: BaseMessage) -> None:
    """Publish an autogen structured message to the MQTT output topics.

    The message is dumped with ``model_dump(mode="json")``. If that fails,
    a dict with the error text and the class name is used instead. The dump
    is JSON-encoded into ``data`` and tagged with its ``type``, falling back
    to the message's class name when the dump has no truthy ``type``.

    Parameters
    ----------
    message : BaseMessage
        The message to send.
    """
    fallback_type = message.__class__.__name__
    try:
        dumped = message.model_dump(mode="json")
    except Exception as exc:  # pragma: no cover
        dumped = {
            "error": str(exc),
            "type": fallback_type,
        }

    self._print(
        {
            "type": dumped.get("type", None) or fallback_type,
            "data": json.dumps(dumped),
        }
    )
|
|
491
|
+
|
|
492
|
+
def input(
|
|
493
|
+
self,
|
|
494
|
+
prompt: str = "",
|
|
495
|
+
*,
|
|
496
|
+
password: bool = False,
|
|
497
|
+
request_id: str | None = None,
|
|
498
|
+
) -> str:
|
|
499
|
+
"""Request input via MQTT and wait for response.
|
|
500
|
+
|
|
501
|
+
Parameters
|
|
502
|
+
----------
|
|
503
|
+
prompt : str, optional
|
|
504
|
+
The prompt message, by default "".
|
|
505
|
+
password : bool, optional
|
|
506
|
+
Whether input is masked, by default False.
|
|
507
|
+
request_id : str, optional
|
|
508
|
+
The request ID (for testing), by default None.
|
|
509
|
+
|
|
510
|
+
Returns
|
|
511
|
+
-------
|
|
512
|
+
str
|
|
513
|
+
The received user input, or empty string if timeout occurs.
|
|
514
|
+
"""
|
|
515
|
+
request_id = request_id or gen_id()
|
|
516
|
+
|
|
517
|
+
input_request = UserInputRequest(
|
|
518
|
+
request_id=request_id,
|
|
519
|
+
prompt=prompt,
|
|
520
|
+
password=password,
|
|
521
|
+
)
|
|
522
|
+
|
|
523
|
+
payload = input_request.model_dump(mode="json")
|
|
524
|
+
payload["task_id"] = self.task_id
|
|
525
|
+
payload["password"] = str(password).lower()
|
|
526
|
+
|
|
527
|
+
LOG.debug("Requesting input via MQTT: %s", payload)
|
|
528
|
+
|
|
529
|
+
# Create event for this request
|
|
530
|
+
with self._input_lock:
|
|
531
|
+
self._input_events[request_id] = Event()
|
|
532
|
+
|
|
533
|
+
# Publish input request
|
|
534
|
+
self._print(payload)
|
|
535
|
+
self._publish_message(self.input_request_topic, payload)
|
|
536
|
+
|
|
537
|
+
if self.on_input_request:
|
|
538
|
+
self.on_input_request(prompt, request_id, self.task_id)
|
|
539
|
+
|
|
540
|
+
user_input = self._wait_for_input(request_id)
|
|
541
|
+
|
|
542
|
+
if self.on_input_response:
|
|
543
|
+
self.on_input_response(user_input, self.task_id)
|
|
544
|
+
|
|
545
|
+
# Send response confirmation
|
|
546
|
+
text_response = UserInputData(
|
|
547
|
+
content=TextMediaContent(text=user_input),
|
|
548
|
+
)
|
|
549
|
+
user_response = UserResponse(
|
|
550
|
+
request_id=request_id,
|
|
551
|
+
type="input_response",
|
|
552
|
+
data=text_response,
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
payload = user_response.model_dump(mode="json")
|
|
556
|
+
payload["task_id"] = self.task_id
|
|
557
|
+
payload["data"] = json.dumps(payload["data"])
|
|
558
|
+
|
|
559
|
+
LOG.debug("Sending input response: %s", payload)
|
|
560
|
+
self._print(payload)
|
|
561
|
+
|
|
562
|
+
return user_input
|
|
563
|
+
|
|
564
|
+
def _wait_for_input(self, request_id: str) -> str:
    """Wait for user input.

    Blocks on the event registered for ``request_id`` for up to
    ``self.input_timeout`` seconds. Whatever the outcome (response, timeout,
    missing event, or error), the request's event and any stored response
    are removed afterwards. A reply that lands after the timeout therefore
    does not stay in ``_input_responses`` forever.

    Parameters
    ----------
    request_id : str
        The request ID.

    Returns
    -------
    str
        The user input, or ``""`` on timeout, missing event, or error.
    """
    try:
        with self._input_lock:
            event = self._input_events.get(request_id)
        if not event:
            LOG.error("No event found for request %s", request_id)
            return ""

        if event.wait(timeout=self.input_timeout):
            # Got response
            with self._input_lock:
                return self._input_responses.pop(request_id, "")

        # Timeout
        LOG.warning(
            "No input received for %ds on task %s, assuming empty string",
            self.input_timeout,
            self.task_id,
        )
        return ""

    except Exception as e:
        LOG.error("Error in _wait_for_input: %s", e)
        return ""
    finally:
        # Always drop per-request state; a response stored after the timeout
        # would otherwise never be consumed.
        with self._input_lock:
            self._input_events.pop(request_id, None)
            self._input_responses.pop(request_id, None)
|
|
603
|
+
|
|
604
|
+
def _get_user_input(self, response: UserResponse) -> str:
    """Extract the user's text from a ``UserResponse``.

    Parameters
    ----------
    response : UserResponse
        The user response.

    Returns
    -------
    str
        ``""`` when the response has no data. The data itself when it is
        already a string. Otherwise the result of ``response.to_string``,
        using this stream's ``uploads_root`` and the request id as base name.
    """
    payload = response.data
    if not payload:
        return ""
    if isinstance(payload, str):  # pragma: no cover
        return payload
    rendered = response.to_string(
        uploads_root=self.uploads_root,
        base_name=response.request_id,
    )
    return rendered
|
|
625
|
+
|
|
626
|
+
@staticmethod
def _create_user_response(
    message_data: dict[str, Any],
) -> Optional["UserResponse"]:
    """Create UserResponse object from validated data.

    The ``data`` field may arrive as a JSON-encoded string; ``input`` itself
    publishes response payloads that way. In that case it is decoded before
    validation. Decoding is done on a shallow copy, so the caller's dict is
    left unchanged.

    Parameters
    ----------
    message_data : dict[str, Any]
        The decoded message received from the broker.

    Returns
    -------
    Optional[UserResponse]
        The validated response, or ``None`` if the nested JSON is invalid or
        validation fails. Both failure cases are logged.
    """
    try:
        # Work on a copy: the caller's dict must not be rewritten.
        data = dict(message_data)
        nested = data.get("data")
        # Handle nested JSON in 'data' field
        if isinstance(nested, str):
            try:
                data["data"] = json.loads(nested)
            except json.JSONDecodeError:
                LOG.error(
                    "Invalid JSON in nested data field: %s", message_data
                )
                return None

        return UserResponse.model_validate(data)
    except Exception as e:
        LOG.error(
            "Error parsing user input response: %s - %s",
            message_data,
            str(e),
        )
        return None
|
|
650
|
+
|
|
651
|
+
@staticmethod
def try_do(func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
    """Call ``func(*args, **kwargs)``, logging and suppressing errors.

    Only ``Exception`` subclasses are suppressed. ``KeyboardInterrupt``,
    ``SystemExit`` and other ``BaseException``-only signals still propagate,
    so a best-effort call cannot swallow an interrupt or exit request.

    Parameters
    ----------
    func : Callable[..., Any]
        The function to call.
    args : Any
        The function's positional arguments.
    kwargs : Any
        The function's keyword arguments.
    """
    try:
        func(*args, **kwargs)
    except Exception:  # pragma: no cover
        LOG.error("Error on try_do:")
        LOG.error(tb.format_exc())
|
|
669
|
+
|
|
670
|
+
def cleanup_task_data(self) -> None:
    """Reset the locally tracked input state for this task.

    Note: MQTT has no built-in cleanup comparable to Redis streams. This
    method therefore only empties the in-memory response, event and
    processed-request stores. It can be extended for broker-specific cleanup.
    """
    with self._input_lock:
        stores = (
            self._input_responses,
            self._input_events,
            self._processed_requests,
        )
        for store in stores:
            store.clear()

    LOG.debug("Cleaned up task data for %s", self.task_id)
|