puda-comms 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- puda_comms/__init__.py +5 -0
- puda_comms/command_service.py +635 -0
- puda_comms/execution_state.py +89 -0
- puda_comms/machine_client.py +771 -0
- puda_comms/models.py +88 -0
- puda_comms-0.0.2.dist-info/METADATA +310 -0
- puda_comms-0.0.2.dist-info/RECORD +8 -0
- puda_comms-0.0.2.dist-info/WHEEL +4 -0
puda_comms/command_service.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Service for sending commands to machines via NATS. Should take in AI generated commands as CommandRequest models.
|
|
3
|
+
|
|
4
|
+
This service handles:
|
|
5
|
+
- Connecting to NATS servers
|
|
6
|
+
- Parsing and sending commands to the correct topics (queue/immediate)
|
|
7
|
+
- Waiting for and handling responses
|
|
8
|
+
- Managing command lifecycle (run_id, step_number, etc.)
|
|
9
|
+
"""
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import signal
|
|
15
|
+
from datetime import datetime, timezone
|
|
16
|
+
from typing import Dict, Any, Optional, Tuple
|
|
17
|
+
import nats
|
|
18
|
+
from nats.js.client import JetStreamContext
|
|
19
|
+
from nats.aio.msg import Msg
|
|
20
|
+
from puda_comms.models import CommandRequest, CommandResponse, CommandResponseStatus, NATSMessage, MessageHeader, MessageType
|
|
21
|
+
|
|
22
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Subject prefix: every subject built in this module starts with "<NAMESPACE>.<machine_id>.".
NAMESPACE = "puda"

# Stream names. The two RESPONSE_* names are passed as ``stream=`` when
# subscribing for responses; the COMMAND_* names are not referenced in this
# module (presumably used by other modules -- TODO confirm).
STREAM_COMMAND_QUEUE = "COMMAND_QUEUE"
STREAM_COMMAND_IMMEDIATE = "COMMAND_IMMEDIATE"
STREAM_RESPONSE_QUEUE = "RESPONSE_QUEUE"
STREAM_RESPONSE_IMMEDIATE = "RESPONSE_IMMEDIATE"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ResponseHandler:
    """
    Handles response messages from a specific machine.

    Routes responses to waiting commands based on run_id and step_number:
    a caller registers a pending entry with ``register_pending()``, waits on
    the returned event, and collects the stored message with
    ``get_response()`` once the event has been set.
    """

    def __init__(self, js: "JetStreamContext", machine_id: str):
        """
        Args:
            js: JetStream context used to create the response subscriptions.
            machine_id: Identifier of the machine whose responses are handled.
        """
        self.js = js
        self.machine_id = machine_id
        # "<run_id>:<step_number>" -> {'event': asyncio.Event, 'response': dict or None}
        self._pending_responses: Dict[str, Dict[str, Any]] = {}
        self._queue_consumer = None
        self._immediate_consumer = None
        self._initialized = False

    @staticmethod
    def _key(run_id: str, step_number: int) -> str:
        """Build the lookup key shared by registration and message matching."""
        return f"{run_id}:{step_number}"

    @staticmethod
    async def _ack_quietly(msg: "Msg") -> None:
        """Acknowledge ``msg``; a failure to do so is logged, never raised."""
        try:
            await msg.ack()
        except Exception as ack_error:
            logger.debug("Could not acknowledge response message: %s", ack_error)

    async def initialize(self):
        """
        Initialize the response handler by subscribing to response streams.

        Calling this again after a successful initialization is a no-op.

        Raises:
            Exception: Whatever ``js.subscribe`` raised. Any subscription that
                had already been created is released before re-raising, so a
                later retry starts from a clean state.
        """
        if self._initialized:
            return

        # push consumers with ephemeral subscriptions
        queue_subject = f"{NAMESPACE}.{self.machine_id}.cmd.response.queue"
        immediate_subject = f"{NAMESPACE}.{self.machine_id}.cmd.response.immediate"

        try:
            # The bound coroutine method is handed over directly: the client
            # expects a coroutine callback, so no lambda/create_task wrapper
            # is needed.
            self._queue_consumer = await self.js.subscribe(
                queue_subject,
                stream=STREAM_RESPONSE_QUEUE,
                cb=self._handle_message,
            )
            self._immediate_consumer = await self.js.subscribe(
                immediate_subject,
                stream=STREAM_RESPONSE_IMMEDIATE,
                cb=self._handle_message,
            )
        except Exception as e:
            logger.error("Failed to initialize response handler: %s", e)
            # Do not leave a half-subscribed handler behind.
            await self.cleanup()
            raise

        logger.info("Response handler initialized for machine: %s", self.machine_id)
        self._initialized = True

    async def _handle_message(self, msg: "Msg"):
        """
        Handle one incoming response message.

        The message is decoded as a ``NATSMessage``. If a caller is waiting on
        its ``run_id``/``step_number`` the decoded message is stored for that
        caller and its event is set. Messages nobody is waiting for, and
        messages that cannot be processed, are acknowledged so they are not
        delivered again; messages lacking ``run_id`` or ``step_number`` are
        negatively acknowledged.
        """
        try:
            message = NATSMessage.model_validate_json(msg.data)
            command = message.command.name
            run_id = message.header.run_id
            step_number = message.command.step_number

            # Check if we have required fields for matching
            if run_id is None or step_number is None:
                logger.error(
                    "Response missing required fields: command=%s, step_number=%s, run_id=%s - putting back in queue",
                    command, step_number, run_id
                )
                await msg.nak()
                return

            pending = self._pending_responses.get(self._key(run_id, step_number))
            if pending is None:
                # No matching pending command - acknowledge to remove from queue
                # This response is likely from a previous run or different session
                logger.debug(
                    "Unmatched response (acknowledging): command=%s, step_number=%s, run_id=%s",
                    command, step_number, run_id
                )
                await msg.ack()
                return

            # A message without a response payload is still delivered to the
            # waiter (callers check for ``response is None`` themselves)
            # instead of failing here and leaving the waiter to time out.
            response = message.response
            status = response.status if response is not None else None
            logger.info(
                "Response received: command=%s, step_number=%s, run_id=%s, status=%s",
                command, step_number, run_id, status
            )
            if status == CommandResponseStatus.ERROR:
                logger.warning("Command failed: %s", response.message)

            # Store the full NATSMessage structure first, then signal.
            # Don't delete here - let get_response() delete it after retrieval
            pending['response'] = message.model_dump()
            pending['event'].set()

            # Acknowledge the message since we matched it
            await msg.ack()

        except (json.JSONDecodeError, KeyError, AttributeError) as e:
            logger.error("Error processing response message: %s", e)
            await self._ack_quietly(msg)
        except Exception as e:
            logger.error("Unexpected error processing response message: %s", e)
            await self._ack_quietly(msg)

    def register_pending(self, run_id: str, step_number: int) -> asyncio.Event:
        """
        Register a pending response and return event.

        Args:
            run_id: Run ID for the command
            step_number: Step number for the command

        Returns:
            Event that will be set when the response is received
        """
        event = asyncio.Event()
        # 'response' stays None until a matching message arrives.
        self._pending_responses[self._key(run_id, step_number)] = {
            'event': event,
            'response': None
        }
        return event

    def get_response(self, run_id: str, step_number: int) -> Optional[Dict[str, Any]]:
        """
        Get the response for a pending command and drop its registration.

        Args:
            run_id: Run ID for the command
            step_number: Step number for the command

        Returns:
            The NATSMessage dict structure if available, None otherwise
        """
        entry = self._pending_responses.pop(self._key(run_id, step_number), None)
        if entry is None:
            return None
        return entry.get('response')

    def remove_pending(self, run_id: str, step_number: int):
        """Remove a pending response registration (no-op if it does not exist)."""
        self._pending_responses.pop(self._key(run_id, step_number), None)

    async def cleanup(self):
        """
        Clean up subscriptions.

        Unsubscribe errors are logged and otherwise ignored. Afterwards the
        handler holds no subscriptions and may be initialized again.
        """
        for attr in ("_queue_consumer", "_immediate_consumer"):
            consumer = getattr(self, attr)
            if not consumer:
                continue
            try:
                await consumer.unsubscribe()
            except Exception as e:
                logger.debug("Ignoring error while unsubscribing %s: %s", attr, e)
            setattr(self, attr, None)
        self._initialized = False
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class CommandService:
    """
    Service for sending commands to machines via NATS.

    Handles connection management, command parsing, and response handling.
    Can send commands to multiple machines.

    Supports async context manager usage for automatic cleanup:
        async with CommandService() as service:
            await service.send_queue_command(...)
        # Automatically disconnects on exit

    Registers signal handlers (SIGTERM, SIGINT) for graceful shutdown when
    constructed in the main thread.
    """

    # Fallback used when neither ``servers`` nor the NATS_SERVERS environment
    # variable is provided.
    # NOTE(review): these are private-network addresses baked into the
    # package -- confirm this default is intended.
    _DEFAULT_SERVERS = (
        "nats://192.168.50.201:4222,"
        "nats://192.168.50.201:4223,"
        "nats://192.168.50.201:4224"
    )

    # ==================== Initialization ====================

    def __init__(
        self,
        servers: Optional[list[str]] = None
    ):
        """
        Initialize NATS service.

        Args:
            servers: List of NATS server URLs. If None, reads from NATS_SERVERS env var.
        """
        if servers is None:
            servers = self._servers_from_env()

        self.servers = servers
        self.nc: Optional[nats.NATS] = None
        self.js: Optional[JetStreamContext] = None
        self._response_handlers: Dict[str, ResponseHandler] = {}  # stores response handlers for each machine
        self._connected = False
        # Holds the disconnect task scheduled by the signal handler, if any.
        self._shutdown_task = None

        # Always register signal handlers for graceful shutdown
        self._register_signal_handlers()

    @classmethod
    def _servers_from_env(cls) -> list[str]:
        """Read the comma-separated NATS_SERVERS value, dropping blank entries."""
        raw = os.getenv("NATS_SERVERS", cls._DEFAULT_SERVERS)
        stripped = (part.strip() for part in raw.split(","))
        return [url for url in stripped if url]

    # ==================== Context Manager ====================

    async def __aenter__(self):
        """Async context manager entry: connect and return the service."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect."""
        await self.disconnect()
        return False  # Don't suppress exceptions

    # ==================== Connection Management ====================

    async def connect(self) -> bool:
        """
        Connect to NATS servers.

        Returns:
            True if connected successfully (or already connected), False otherwise
        """
        if self._connected:
            return True

        try:
            self.nc = await nats.connect(servers=self.servers)
            self.js = self.nc.jetstream()

            self._connected = True
            logger.info("Connected to NATS servers: %s", self.servers)
            return True

        except Exception as e:
            logger.error("Failed to connect to NATS: %s", e)
            self._connected = False
            return False

    async def _get_response_handler(self, machine_id: str) -> "ResponseHandler":
        """
        Get or create a response handler for the specified machine.

        Args:
            machine_id: Machine identifier

        Returns:
            ResponseHandler instance for the machine
        """
        if machine_id not in self._response_handlers:
            handler = ResponseHandler(self.js, machine_id)
            await handler.initialize()
            self._response_handlers[machine_id] = handler

        return self._response_handlers[machine_id]

    async def disconnect(self):
        """
        Disconnect from NATS servers and cleanup.

        Does nothing when not connected. The connection state is reset even
        if closing the connection raises; that exception still propagates.
        """
        if not self._connected:
            return

        # Flip the flag first so an overlapping call (for example a signal
        # handler racing ``__aexit__``) returns instead of closing twice.
        self._connected = False

        try:
            # Cleanup all response handlers
            for handler in self._response_handlers.values():
                await handler.cleanup()

            if self.nc:
                await self.nc.close()
        finally:
            self._response_handlers.clear()
            self.nc = None
            self.js = None

        logger.info("Disconnected from NATS")

    # ==================== Public Command Methods ====================

    async def send_queue_command(
        self,
        *,
        request: "CommandRequest",
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional["NATSMessage"]:
        """
        Send a queue command to the machine and wait for response.

        Args:
            request: CommandRequest model containing command details
            machine_id: Machine ID to send the command to
            run_id: Run ID for the command
            timeout: Maximum time to wait for response in seconds

        Returns:
            NATSMessage carrying the response, or None if publishing failed
            or no response arrived within ``timeout``

        Raises:
            RuntimeError: If the service is not connected.
        """
        if not self._connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        # Determine subject
        subject = f"{NAMESPACE}.{machine_id}.cmd.queue"

        logger.info(
            "Sending queue command: machine_id=%s, command=%s, run_id=%s, step_number=%s",
            machine_id, request.name, run_id, request.step_number
        )

        # Build the payload before registering, so a payload error cannot
        # leave a pending registration behind.
        payload = self._build_command_payload(request, machine_id, run_id)

        # Get or create response handler for this machine
        response_handler = await self._get_response_handler(machine_id)
        # Register pending response
        response_event = response_handler.register_pending(run_id, request.step_number)

        try:
            # Publish to JetStream
            await self.js.publish(
                subject,
                payload.model_dump_json().encode()
            )

            logger.info("Command published (step_number: %s), waiting for response...", request.step_number)

            return await self._wait_for_response(
                response_handler, response_event, run_id, request.step_number, timeout
            )

        except Exception as e:
            logger.error("Error sending queue command: %s", e)
            return None
        finally:
            # Runs on success, failure and task cancellation alike; a no-op
            # when the response was already collected.
            response_handler.remove_pending(run_id, request.step_number)

    async def send_queue_commands(
        self,
        *,
        requests: list["CommandRequest"],
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional["NATSMessage"]:
        """
        Send multiple queue commands sequentially and wait for responses.

        Sends commands one by one, waiting for each response before sending the next.
        If any command fails or times out, stops immediately and returns the error response.
        If all commands succeed, returns the last command's response.

        Args:
            requests: List of CommandRequest models to send sequentially
            machine_id: Machine ID to send the commands to
            run_id: Run ID for all commands
            timeout: Maximum time to wait for each response in seconds

        Returns:
            NATSMessage of the failed command if any command fails, or the last
            command's response if all succeed. Returns None if a command times
            out or ``requests`` is empty.

        Raises:
            RuntimeError: If the service is not connected.
        """
        if not self._connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        if not requests:
            logger.warning("No commands to send")
            return None

        total = len(requests)
        logger.info(
            "Sending %d queue commands sequentially: machine_id=%s, run_id=%s",
            total,
            machine_id,
            run_id
        )

        last_response: Optional[NATSMessage] = None

        for idx, request in enumerate(requests, start=1):
            logger.info(
                "Sending command %d/%d: %s (step %s)",
                idx,
                total,
                request.name,
                request.step_number
            )

            response = await self.send_queue_command(
                request=request,
                machine_id=machine_id,
                run_id=run_id,
                timeout=timeout
            )

            # None means timeout or exception
            if response is None:
                logger.error(
                    "Command %d/%d failed or timed out: %s (step %s)",
                    idx,
                    total,
                    request.name,
                    request.step_number
                )
                return None

            # Response exists but has no response data (shouldn't happen, but handle it)
            if response.response is None:
                logger.warning(
                    "Command %d/%d returned response with no response data: %s (step %s)",
                    idx,
                    total,
                    request.name,
                    request.step_number
                )
                return response

            # Command returned an error status: stop and hand it back
            if response.response.status == CommandResponseStatus.ERROR:
                logger.error(
                    "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
                    idx,
                    total,
                    request.name,
                    request.step_number,
                    response.response.code,
                    response.response.message
                )
                return response

            # Command succeeded, store as last response
            last_response = response
            logger.info(
                "Command %d/%d succeeded: %s (step %s)",
                idx,
                total,
                request.name,
                request.step_number
            )

        logger.info(
            "All %d commands completed successfully",
            total
        )
        return last_response

    async def send_immediate_command(
        self,
        *,
        request: "CommandRequest",
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional["NATSMessage"]:
        """
        Send an immediate command (pause, resume, cancel) to the machine.

        Args:
            request: CommandRequest model containing command details
            machine_id: Machine ID to send the command to
            run_id: Run ID for the command
            timeout: Maximum time to wait for response in seconds

        Returns:
            NATSMessage if successful, None if failed or timeout

        Raises:
            RuntimeError: If the service is not connected.
        """
        if not self._connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        # Determine subject
        subject = f"{NAMESPACE}.{machine_id}.cmd.immediate"

        logger.info(
            "Sending immediate command: machine_id=%s, command=%s, run_id=%s, step_number=%s",
            machine_id, request.name, run_id, request.step_number
        )

        # Build the payload before registering, so a payload error cannot
        # leave a pending registration behind.
        payload = self._build_command_payload(request, machine_id, run_id)

        # Get or create response handler for this machine
        response_handler = await self._get_response_handler(machine_id)

        # Register pending response
        response_received = response_handler.register_pending(run_id, request.step_number)

        try:
            # Publish to JetStream
            await self.js.publish(
                subject,
                payload.model_dump_json().encode()
            )

            logger.info("Command published, waiting for response...")

            return await self._wait_for_response(
                response_handler, response_received, run_id, request.step_number, timeout
            )

        except Exception as e:
            logger.error("Error sending immediate command: %s", e)
            return None
        finally:
            # Runs on success, failure and task cancellation alike; a no-op
            # when the response was already collected.
            response_handler.remove_pending(run_id, request.step_number)

    # ==================== Private Helper Methods ====================

    async def _wait_for_response(
        self,
        response_handler: "ResponseHandler",
        response_event: asyncio.Event,
        run_id: str,
        step_number: int,
        timeout: int
    ) -> Optional["NATSMessage"]:
        """Wait up to ``timeout`` seconds for the event, then collect and validate the response."""
        try:
            await asyncio.wait_for(response_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.error("Timeout waiting for response after %s seconds", timeout)
            return None

        # Give a small delay to ensure any pending messages are processed
        await asyncio.sleep(0.1)

        response_data = response_handler.get_response(run_id, step_number)
        if response_data is None:
            return None

        return NATSMessage.model_validate(response_data)

    def _register_signal_handlers(self):
        """
        Register signal handlers for graceful shutdown.

        Installs one handler for SIGTERM and SIGINT that schedules
        ``disconnect()``. ``signal.signal`` refuses to run outside the main
        thread; in that case registration is skipped with a warning instead
        of failing construction.
        """
        def signal_handler(signum, _frame):
            """Handle shutdown signals by scheduling disconnect."""
            logger.info("Received signal %s, initiating graceful shutdown...", signum)
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            try:
                if loop is not None:
                    def schedule_disconnect():
                        # Keep a reference so the task is not garbage
                        # collected before it finishes.
                        self._shutdown_task = loop.create_task(self.disconnect())
                    loop.call_soon_threadsafe(schedule_disconnect)
                else:
                    # No running loop, create a new one and run disconnect
                    asyncio.run(self.disconnect())
            except Exception as e:
                # Covers both paths, so nothing escapes the signal handler.
                logger.error("Error during signal handler disconnect: %s", e)

        # Register handlers for SIGTERM and SIGINT
        try:
            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)
        except ValueError as e:
            logger.warning("Signal handlers not registered: %s", e)
            return
        logger.debug("Signal handlers registered for SIGTERM and SIGINT")

    def _build_command_payload(
        self,
        command_request: "CommandRequest",
        machine_id: str,
        run_id: str
    ) -> "NATSMessage":
        """
        Build a command payload in the expected format.

        Args:
            command_request: CommandRequest model containing command details
            machine_id: Machine ID for the command
            run_id: Run ID for the command

        Returns:
            NATSMessage object ready for NATS transmission
        """
        header = MessageHeader(
            message_type=MessageType.COMMAND,
            version="1.0",
            # UTC, second precision, e.g. 2024-01-31T12:00:00Z
            timestamp=datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
            machine_id=machine_id,
            run_id=run_id
        )

        return NATSMessage(
            header=header,
            command=command_request
        )
|