puda-comms 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
puda_comms/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .machine_client import MachineClient
2
+ from .execution_state import ExecutionState
3
+ from .command_service import CommandService
4
+
5
+ __all__ = ["MachineClient", "ExecutionState", "CommandService"]
@@ -0,0 +1,635 @@
1
+ """
2
+ Service for sending commands to machines via NATS. Commands (e.g. AI-generated ones) are supplied as CommandRequest models.
3
+
4
+ This service handles:
5
+ - Connecting to NATS servers
6
+ - Parsing and sending commands to the correct topics (queue/immediate)
7
+ - Waiting for and handling responses
8
+ - Managing command lifecycle (run_id, step_number, etc.)
9
+ """
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ import os
14
+ import signal
15
+ from datetime import datetime, timezone
16
+ from typing import Dict, Any, Optional, Tuple
17
+ import nats
18
+ from nats.js.client import JetStreamContext
19
+ from nats.aio.msg import Msg
20
+ from puda_comms.models import CommandRequest, CommandResponse, CommandResponseStatus, NATSMessage, MessageHeader, MessageType
21
+
22
logger = logging.getLogger(__name__)

# ---- Subject namespace and JetStream stream names ----

# Prefix for every subject built in this module ("<NAMESPACE>.<machine_id>.cmd...").
NAMESPACE = "puda"

# Command stream names (not referenced by the code in this module).
STREAM_COMMAND_QUEUE = "COMMAND_QUEUE"
STREAM_COMMAND_IMMEDIATE = "COMMAND_IMMEDIATE"

# Response stream names; ResponseHandler subscribes against these streams.
STREAM_RESPONSE_QUEUE = "RESPONSE_QUEUE"
STREAM_RESPONSE_IMMEDIATE = "RESPONSE_IMMEDIATE"
30
+
31
+
32
class ResponseHandler:
    """
    Handles response messages from a specific machine.

    Subscribes to the machine's queue and immediate response subjects and
    routes each response to the waiting command identified by ``run_id`` and
    ``step_number``.
    """

    def __init__(self, js: JetStreamContext, machine_id: str):
        """
        Args:
            js: JetStream context used to create the response subscriptions.
            machine_id: Machine whose response subjects this handler listens on.
        """
        self.js = js
        self.machine_id = machine_id
        # "run_id:step_number" -> {'event': asyncio.Event, 'response': dict | None}
        self._pending_responses: Dict[str, Dict[str, Any]] = {}
        self._queue_consumer = None
        self._immediate_consumer = None
        self._initialized = False

    @staticmethod
    def _make_key(run_id: Any, step_number: Any) -> str:
        """Build the key used both to register a command and to match its response."""
        return f"{run_id}:{step_number}"

    async def initialize(self):
        """
        Subscribe to this machine's response subjects.

        Does nothing if already initialized. If either subscription fails, any
        subscription already created is removed before the error is re-raised,
        so a later retry starts from a clean state.

        Raises:
            Exception: whatever ``JetStreamContext.subscribe`` raised.
        """
        if self._initialized:
            return

        queue_subject = f"{NAMESPACE}.{self.machine_id}.cmd.response.queue"
        immediate_subject = f"{NAMESPACE}.{self.machine_id}.cmd.response.immediate"

        try:
            # Ephemeral push consumers (no durable name). The coroutine method is
            # passed directly as the callback instead of wrapping it in a task.
            self._queue_consumer = await self.js.subscribe(
                queue_subject,
                stream=STREAM_RESPONSE_QUEUE,
                cb=self._handle_message,
            )
            self._immediate_consumer = await self.js.subscribe(
                immediate_subject,
                stream=STREAM_RESPONSE_IMMEDIATE,
                cb=self._handle_message,
            )
        except Exception as e:
            logger.error("Failed to initialize response handler: %s", e)
            # Don't leave the first subscription dangling if the second failed.
            await self.cleanup()
            raise

        logger.info("Response handler initialized for machine: %s", self.machine_id)
        self._initialized = True

    @staticmethod
    async def _ack_quietly(msg: Msg) -> None:
        """Best-effort ack so a bad message is not redelivered; failures are only logged."""
        try:
            await msg.ack()
        except Exception as e:
            logger.debug("Failed to ack response message: %s", e)

    async def _handle_message(self, msg: Msg):
        """
        Route one incoming response message to its waiting command.

        - Matched response: stored (as ``model_dump()``), the waiter's event is
          set, and the message is acked.
        - Unmatched response (e.g. from a previous run): acked and dropped.
        - Missing ``run_id``/``step_number``: delivery is terminated, because
          redelivering the same message could never produce a match.
        - Any processing error: logged and acked (best effort).
        """
        try:
            message = NATSMessage.model_validate_json(msg.data)
            command_obj = message.command
            command = command_obj.name if command_obj is not None else None
            step_number = command_obj.step_number if command_obj is not None else None
            run_id = message.header.run_id

            # Check if we have required fields for matching
            if run_id is None or step_number is None:
                logger.error(
                    "Response missing required fields: command=%s, step_number=%s, run_id=%s - terminating delivery",
                    command, step_number, run_id
                )
                # nak() would redeliver this unmatchable message to the same
                # consumer indefinitely; term() stops redelivery.
                await msg.term()
                return

            pending = self._pending_responses.get(self._make_key(run_id, step_number))
            if pending is None:
                # No matching pending command - acknowledge to remove from queue.
                # This response is likely from a previous run or different session.
                logger.debug(
                    "Unmatched response (acknowledging): command=%s, step_number=%s, run_id=%s",
                    command, step_number, run_id
                )
                await msg.ack()
                return

            # message.response may be absent; reading .status on None used to
            # abort the match and leave the waiting command to time out.
            result = message.response
            status = result.status if result is not None else None
            logger.info(
                "Response received: command=%s, step_number=%s, run_id=%s, status=%s",
                command, step_number, run_id, status
            )
            if status == CommandResponseStatus.ERROR:
                logger.warning("Command failed: %s", result.message)

            # Store the full NATSMessage structure, then wake the waiter.
            # The entry is deleted by get_response()/remove_pending(), not here.
            pending['response'] = message.model_dump()
            pending['event'].set()

            # Acknowledge the message since we matched it
            await msg.ack()

        except (json.JSONDecodeError, KeyError, AttributeError) as e:
            logger.error("Error processing response message: %s", e)
            await self._ack_quietly(msg)
        except Exception as e:
            logger.error("Unexpected error processing response message: %s", e)
            await self._ack_quietly(msg)

    def register_pending(self, run_id: str, step_number: int) -> asyncio.Event:
        """
        Register a pending response and return its event.

        A previous registration with the same run_id/step_number is replaced.

        Args:
            run_id: Run ID for the command
            step_number: Step number for the command

        Returns:
            Event that will be set when the response is received
        """
        event = asyncio.Event()
        # The response slot stays None until _handle_message fills it in.
        self._pending_responses[self._make_key(run_id, step_number)] = {
            'event': event,
            'response': None,
        }
        return event

    def get_response(self, run_id: str, step_number: int) -> Optional[Dict[str, Any]]:
        """
        Retrieve and remove the response for a pending command.

        Args:
            run_id: Run ID for the command
            step_number: Step number for the command

        Returns:
            The NATSMessage dict structure if available, None otherwise
        """
        pending = self._pending_responses.pop(self._make_key(run_id, step_number), None)
        if pending is None:
            return None
        return pending.get('response')

    def remove_pending(self, run_id: str, step_number: int):
        """Remove a pending response registration (no-op if not registered)."""
        self._pending_responses.pop(self._make_key(run_id, step_number), None)

    async def cleanup(self):
        """
        Unsubscribe from both response subjects (best effort) and reset state.

        After cleanup, ``initialize()`` may be called again.
        """
        queue_sub, self._queue_consumer = self._queue_consumer, None
        immediate_sub, self._immediate_consumer = self._immediate_consumer, None
        self._initialized = False

        for sub in (queue_sub, immediate_sub):
            if sub is None:
                continue
            try:
                await sub.unsubscribe()
            except Exception as e:
                logger.debug("Failed to unsubscribe from response subject: %s", e)
194
+
195
+
196
class CommandService:
    """
    Service for sending commands to machines via NATS.

    Handles connection management, command payload construction, and response
    handling. Can send commands to multiple machines; a ResponseHandler is
    created lazily for each machine on first use.

    Supports async context manager usage for automatic cleanup:
        async with CommandService() as service:
            await service.send_queue_command(...)
        # Automatically disconnects on exit

    On construction, registers SIGTERM/SIGINT handlers that schedule
    disconnect(). Registration is skipped (with a warning) when the service
    is constructed outside the main thread.
    """

    # ==================== Initialization ====================

    def __init__(
        self,
        servers: Optional[list[str]] = None
    ):
        """
        Initialize NATS service.

        Args:
            servers: List of NATS server URLs. If None, reads the comma-separated
                NATS_SERVERS env var, falling back to a built-in default list.
        """
        if servers is None:
            nats_servers_env = os.getenv(
                "NATS_SERVERS",
                "nats://192.168.50.201:4222,nats://192.168.50.201:4223,nats://192.168.50.201:4224"
            )
            # Skip empty entries (e.g. from a trailing comma).
            servers = [s.strip() for s in nats_servers_env.split(",") if s.strip()]

        self.servers = servers
        self.nc: Optional[nats.NATS] = None
        self.js: Optional[JetStreamContext] = None
        # One response handler per machine_id, created on first use.
        self._response_handlers: Dict[str, ResponseHandler] = {}
        self._connected = False
        # Holds a reference to the disconnect task scheduled by a signal, so it
        # is not garbage-collected before it finishes.
        self._shutdown_task: Optional[asyncio.Task] = None

        self._register_signal_handlers()

    # ==================== Context Manager ====================

    async def __aenter__(self):
        """Async context manager entry: connect and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect; exceptions are not suppressed."""
        await self.disconnect()
        return False

    # ==================== Connection Management ====================

    async def connect(self) -> bool:
        """
        Connect to NATS servers.

        Returns:
            True if connected successfully (or already connected), False otherwise
        """
        if self._connected:
            return True

        try:
            self.nc = await nats.connect(servers=self.servers)
            self.js = self.nc.jetstream()

            self._connected = True
            logger.info("Connected to NATS servers: %s", self.servers)
            return True

        except Exception as e:
            logger.error("Failed to connect to NATS: %s", e)
            self._connected = False
            return False

    async def disconnect(self):
        """Disconnect from NATS servers and clean up all response handlers."""
        if not self._connected:
            return

        # Flip state and detach handlers before any await, so a concurrent call
        # (e.g. the signal handler racing __aexit__) returns immediately instead
        # of iterating a dict that is being cleared or closing twice.
        self._connected = False
        handlers = list(self._response_handlers.values())
        self._response_handlers.clear()
        nc, self.nc, self.js = self.nc, None, None

        for handler in handlers:
            await handler.cleanup()

        if nc is not None:
            await nc.close()

        logger.info("Disconnected from NATS")

    # ==================== Public Command Methods ====================

    async def send_queue_command(
        self,
        *,
        request: CommandRequest,
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional[NATSMessage]:
        """
        Send a queue command to the machine and wait for response.

        Args:
            request: CommandRequest model containing command details
            machine_id: Machine ID to send the command to
            run_id: Run ID for the command
            timeout: Maximum time to wait for response in seconds

        Returns:
            The response NATSMessage if received, None if publishing failed,
            the response could not be parsed, or the wait timed out

        Raises:
            RuntimeError: If not connected.
        """
        return await self._send_command(
            kind="queue",
            request=request,
            machine_id=machine_id,
            run_id=run_id,
            timeout=timeout,
        )

    async def send_queue_commands(
        self,
        *,
        requests: list[CommandRequest],
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional[NATSMessage]:
        """
        Send multiple queue commands sequentially and wait for responses.

        Sends commands one by one, waiting for each response before sending the next.
        If any command fails or times out, stops immediately and returns the error response.
        If all commands succeed, returns the last command's response.

        Args:
            requests: List of CommandRequest models to send sequentially
            machine_id: Machine ID to send the commands to
            run_id: Run ID for all commands
            timeout: Maximum time to wait for each response in seconds

        Returns:
            NATSMessage of the failed command if any command fails (or returns
            no response data), or the last command's response if all succeed.
            Returns None if ``requests`` is empty or a command times out/fails
            to send.

        Raises:
            RuntimeError: If not connected.
        """
        if not self._connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        if not requests:
            logger.warning("No commands to send")
            return None

        total = len(requests)
        logger.info(
            "Sending %d queue commands sequentially: machine_id=%s, run_id=%s",
            total, machine_id, run_id
        )

        last_response: Optional[NATSMessage] = None

        for idx, request in enumerate(requests, start=1):
            logger.info(
                "Sending command %d/%d: %s (step %s)",
                idx, total, request.name, request.step_number
            )

            response = await self.send_queue_command(
                request=request,
                machine_id=machine_id,
                run_id=run_id,
                timeout=timeout
            )

            # None means timeout or exception while sending/receiving.
            if response is None:
                logger.error(
                    "Command %d/%d failed or timed out: %s (step %s)",
                    idx, total, request.name, request.step_number
                )
                return None

            # A response message without response data (shouldn't happen) stops the batch.
            if response.response is None:
                logger.warning(
                    "Command %d/%d returned response with no response data: %s (step %s)",
                    idx, total, request.name, request.step_number
                )
                return response

            if response.response.status == CommandResponseStatus.ERROR:
                logger.error(
                    "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
                    idx, total, request.name, request.step_number,
                    response.response.code, response.response.message
                )
                return response

            last_response = response
            logger.info(
                "Command %d/%d succeeded: %s (step %s)",
                idx, total, request.name, request.step_number
            )

        logger.info("All %d commands completed successfully", total)
        return last_response

    async def send_immediate_command(
        self,
        *,
        request: CommandRequest,
        machine_id: str,
        run_id: str,
        timeout: int = 120
    ) -> Optional[NATSMessage]:
        """
        Send an immediate command (pause, resume, cancel) to the machine.

        Args:
            request: CommandRequest model containing command details
            machine_id: Machine ID to send the command to
            run_id: Run ID for the command
            timeout: Maximum time to wait for response in seconds

        Returns:
            The response NATSMessage if received, None if publishing failed,
            the response could not be parsed, or the wait timed out

        Raises:
            RuntimeError: If not connected.
        """
        return await self._send_command(
            kind="immediate",
            request=request,
            machine_id=machine_id,
            run_id=run_id,
            timeout=timeout,
        )

    # ==================== Private Helper Methods ====================

    async def _send_command(
        self,
        *,
        kind: str,
        request: CommandRequest,
        machine_id: str,
        run_id: str,
        timeout: int
    ) -> Optional[NATSMessage]:
        """
        Publish a command on ``<NAMESPACE>.<machine_id>.cmd.<kind>`` and wait for its response.

        Shared implementation of send_queue_command / send_immediate_command.

        Args:
            kind: Subject suffix, "queue" or "immediate".
            request: CommandRequest model containing command details
            machine_id: Machine ID to send the command to
            run_id: Run ID for the command
            timeout: Maximum time to wait for response in seconds

        Returns:
            The response NATSMessage, or None on publish error, parse error or timeout.

        Raises:
            RuntimeError: If not connected.
        """
        if not self._connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        subject = f"{NAMESPACE}.{machine_id}.cmd.{kind}"

        logger.info(
            "Sending %s command: machine_id=%s, command=%s, run_id=%s, step_number=%s",
            kind, machine_id, request.name, run_id, request.step_number
        )

        response_handler = await self._get_response_handler(machine_id)
        # Build the payload before registering, so a build error propagates
        # without leaving a stale pending entry behind.
        payload = self._build_command_payload(request, machine_id, run_id)
        response_event = response_handler.register_pending(run_id, request.step_number)

        try:
            await self.js.publish(subject, payload.model_dump_json().encode())
            logger.info(
                "Command published (step_number: %s), waiting for response...",
                request.step_number
            )

            try:
                await asyncio.wait_for(response_event.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                logger.error("Timeout waiting for response after %s seconds", timeout)
                return None

            # _handle_message stores the response before setting the event, so it
            # is already available here; no extra delay is needed.
            response_data = response_handler.get_response(run_id, request.step_number)
            if response_data is None:
                return None

            return NATSMessage.model_validate(response_data)

        except Exception as e:
            logger.error("Error sending %s command: %s", kind, e)
            return None
        finally:
            # Always drop the registration (no-op if get_response already did),
            # including on task cancellation, which bypasses `except Exception`.
            response_handler.remove_pending(run_id, request.step_number)

    def _register_signal_handlers(self):
        """
        Register SIGTERM/SIGINT handlers that schedule disconnect().

        signal.signal() raises ValueError outside the main thread; in that case
        registration is skipped with a warning instead of failing construction.

        NOTE(review): the handler replaces the default SIGINT behaviour, so
        Ctrl+C no longer raises KeyboardInterrupt; it only disconnects.
        """
        def signal_handler(signum, _frame):
            """Handle shutdown signals by scheduling (or running) disconnect."""
            logger.info("Received signal %s, initiating graceful shutdown...", signum)
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            try:
                if loop is not None:
                    def schedule_disconnect():
                        # Keep a reference so the task isn't garbage-collected mid-run.
                        self._shutdown_task = loop.create_task(self.disconnect())
                    loop.call_soon_threadsafe(schedule_disconnect)
                else:
                    # No running loop: run disconnect to completion on a new one.
                    asyncio.run(self.disconnect())
            except Exception as e:
                logger.error("Error during signal handler disconnect: %s", e)

        try:
            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)
        except ValueError as e:
            logger.warning("Could not register signal handlers: %s", e)
            return
        logger.debug("Signal handlers registered for SIGTERM and SIGINT")

    async def _get_response_handler(self, machine_id: str) -> ResponseHandler:
        """
        Get or create a response handler for the specified machine.

        If two calls race to create the first handler for a machine, the one
        stored first wins and the other is cleaned up, so no subscription leaks.

        Args:
            machine_id: Machine identifier

        Returns:
            ResponseHandler instance for the machine
        """
        handler = self._response_handlers.get(machine_id)
        if handler is not None:
            return handler

        new_handler = ResponseHandler(self.js, machine_id)
        await new_handler.initialize()
        handler = self._response_handlers.setdefault(machine_id, new_handler)
        if handler is not new_handler:
            await new_handler.cleanup()
        return handler

    def _build_command_payload(
        self,
        command_request: CommandRequest,
        machine_id: str,
        run_id: str
    ) -> NATSMessage:
        """
        Build a command payload in the expected format.

        Args:
            command_request: CommandRequest model containing command details
            machine_id: Machine ID for the command
            run_id: Run ID for the command

        Returns:
            NATSMessage object (COMMAND type, version "1.0", UTC timestamp)
            ready for NATS transmission
        """
        header = MessageHeader(
            message_type=MessageType.COMMAND,
            version="1.0",
            timestamp=datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
            machine_id=machine_id,
            run_id=run_id
        )

        return NATSMessage(
            header=header,
            command=command_request
        )