puda-comms 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
puda_comms/__init__.py CHANGED
@@ -1,5 +1,9 @@
1
+ # Import models first to ensure they're initialized before other modules that depend on them
2
+ from . import models
3
+
1
4
  from .machine_client import MachineClient
2
5
  from .execution_state import ExecutionState
3
6
  from .command_service import CommandService
7
+ from .stream_subscriber import StreamSubscriber
4
8
 
5
- __all__ = ["MachineClient", "ExecutionState", "CommandService"]
9
+ __all__ = ["MachineClient", "ExecutionState", "CommandService", "StreamSubscriber", "models"]
@@ -16,7 +16,13 @@ from typing import Dict, Any, Optional
16
16
  import nats
17
17
  from nats.js.client import JetStreamContext
18
18
  from nats.aio.msg import Msg
19
- from puda_comms.models import CommandRequest, CommandResponseStatus, NATSMessage, MessageHeader, MessageType
19
+ from .models import (
20
+ CommandRequest,
21
+ CommandResponseStatus,
22
+ NATSMessage,
23
+ MessageHeader,
24
+ MessageType,
25
+ )
20
26
 
21
27
  logger = logging.getLogger(__name__)
22
28
 
@@ -98,7 +104,7 @@ class ResponseHandler:
98
104
  command, step_number, run_id, message.response.status
99
105
  )
100
106
  if message.response.status == CommandResponseStatus.ERROR:
101
- logger.warning("Command failed: %s", message.response.message)
107
+ logger.error("Error Code: %s, Message: %s", message.response.code.name, message.response.message)
102
108
 
103
109
  # Get the pending response
104
110
  pending = self._pending_responses[key]
@@ -266,9 +272,9 @@ class CommandService:
266
272
  max_attempts = 3
267
273
  connect_timeout = 3 # 3 seconds timeout per connection attempt
268
274
 
269
- for attempt in range(1, max_attempts + 1):
275
+ for attempt in range(max_attempts):
270
276
  try:
271
- logger.info("Connection attempt %d/%d to NATS servers: %s", attempt, max_attempts, self.servers)
277
+ logger.info("Connection attempt %d/%d to NATS servers: %s", attempt + 1, max_attempts, self.servers)
272
278
  self.nc = await asyncio.wait_for(
273
279
  nats.connect(
274
280
  servers=self.servers,
@@ -285,14 +291,14 @@ class CommandService:
285
291
  return True
286
292
 
287
293
  except asyncio.TimeoutError:
288
- logger.warning("Connection attempt %d/%d timed out after %d seconds", attempt, max_attempts, connect_timeout)
289
- if attempt < max_attempts:
294
+ logger.warning("Connection attempt %d/%d timed out after %d seconds", attempt + 1, max_attempts, connect_timeout)
295
+ if attempt < max_attempts - 1:
290
296
  logger.info("Retrying connection...")
291
297
  else:
292
298
  logger.error("Failed to connect after %d attempts. Giving up.", max_attempts)
293
299
  except Exception as e:
294
- logger.warning("Connection attempt %d/%d failed: %s", attempt, max_attempts, e)
295
- if attempt < max_attempts:
300
+ logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_attempts, e)
301
+ if attempt < max_attempts - 1:
296
302
  logger.info("Retrying connection...")
297
303
  else:
298
304
  logger.error("Failed to connect after %d attempts. Giving up.", max_attempts)
@@ -341,7 +347,6 @@ class CommandService:
341
347
  self,
342
348
  *,
343
349
  request: CommandRequest,
344
- machine_id: str,
345
350
  run_id: str,
346
351
  user_id: str,
347
352
  username: str,
@@ -351,8 +356,7 @@ class CommandService:
351
356
  Send a queue command to the machine and wait for response.
352
357
 
353
358
  Args:
354
- request: CommandRequest model containing command details
355
- machine_id: Machine ID to send the command to
359
+ request: CommandRequest model containing command details (must include machine_id)
356
360
  run_id: Run ID for the command
357
361
  user_id: User ID who initiated the command
358
362
  username: Username who initiated the command
@@ -364,8 +368,8 @@ class CommandService:
364
368
  if not self._connected or not self.js:
365
369
  raise RuntimeError("Not connected to NATS. Call connect() first.")
366
370
 
367
- # Determine subject
368
- subject = f"{NAMESPACE}.{machine_id}.cmd.queue"
371
+ # Determine subject using machine_id from request
372
+ subject = f"{NAMESPACE}.{request.machine_id}.cmd.queue"
369
373
 
370
374
  logger.info(
371
375
  "Sending queue command: subject=%s, command=%s, run_id=%s, step_number=%s",
@@ -373,12 +377,12 @@ class CommandService:
373
377
  )
374
378
 
375
379
  # Get or create response handler for this machine
376
- response_handler = await self._get_response_handler(machine_id)
380
+ response_handler = await self._get_response_handler(request.machine_id)
377
381
  # Register pending response
378
382
  response_event = response_handler.register_pending(run_id, request.step_number)
379
383
 
380
384
  # Build payload
381
- payload = self._build_command_payload(request, machine_id, run_id, user_id, username)
385
+ payload = self._build_command_payload(request, request.machine_id, run_id, user_id, username)
382
386
 
383
387
  try:
384
388
  # Publish to JetStream
@@ -408,11 +412,80 @@ class CommandService:
408
412
  response_handler.remove_pending(run_id, request.step_number)
409
413
  return None
410
414
 
415
+ async def start_run(
416
+ self,
417
+ machine_id: str,
418
+ run_id: str,
419
+ user_id: str,
420
+ username: str,
421
+ timeout: int = 120
422
+ ) -> Optional[NATSMessage]:
423
+ """
424
+ Send START immediate command to begin a run.
425
+
426
+ Args:
427
+ machine_id: Machine ID to send the command to
428
+ run_id: Run ID for the command
429
+ user_id: User ID who initiated the command
430
+ username: Username who initiated the command
431
+ timeout: Maximum time to wait for response in seconds
432
+
433
+ Returns:
434
+ NATSMessage if successful, None if failed or timeout
435
+ """
436
+ request = CommandRequest(
437
+ name="start",
438
+ machine_id=machine_id,
439
+ params={},
440
+ step_number=0
441
+ )
442
+ return await self.send_immediate_command(
443
+ request=request,
444
+ run_id=run_id,
445
+ user_id=user_id,
446
+ username=username,
447
+ timeout=timeout
448
+ )
449
+
450
+ async def complete_run(
451
+ self,
452
+ machine_id: str,
453
+ run_id: str,
454
+ user_id: str,
455
+ username: str,
456
+ timeout: int = 120
457
+ ) -> Optional[NATSMessage]:
458
+ """
459
+ Send COMPLETE immediate command to end a run.
460
+
461
+ Args:
462
+ machine_id: Machine ID to send the command to
463
+ run_id: Run ID for the command
464
+ user_id: User ID who initiated the command
465
+ username: Username who initiated the command
466
+ timeout: Maximum time to wait for response in seconds
467
+
468
+ Returns:
469
+ NATSMessage if successful, None if failed or timeout
470
+ """
471
+ request = CommandRequest(
472
+ name="complete",
473
+ machine_id=machine_id,
474
+ params={},
475
+ step_number=0
476
+ )
477
+ return await self.send_immediate_command(
478
+ request=request,
479
+ run_id=run_id,
480
+ user_id=user_id,
481
+ username=username,
482
+ timeout=timeout
483
+ )
484
+
411
485
  async def send_queue_commands(
412
486
  self,
413
487
  *,
414
488
  requests: list[CommandRequest],
415
- machine_id: str,
416
489
  run_id: str,
417
490
  user_id: str,
418
491
  username: str,
@@ -421,13 +494,18 @@ class CommandService:
421
494
  """
422
495
  Send multiple queue commands sequentially and wait for responses.
423
496
 
497
+ Automatically sends START commands to all unique machine_ids before the sequence
498
+ and COMPLETE commands to all unique machine_ids after successful completion.
424
499
  Sends commands one by one, waiting for each response before sending the next.
425
- If any command fails or times out, stops immediately and returns the error response.
426
- If all commands succeed, returns the last command's response.
500
+ If any command fails or times out, stops immediately, completes all started runs,
501
+ and returns the error response. If all commands succeed, returns the last command's response.
502
+
503
+ Each command must specify its own machine_id. Commands with different machine_ids
504
+ will be sent to their respective machines. All machines involved will receive
505
+ START commands at the beginning and COMPLETE commands at the end.
427
506
 
428
507
  Args:
429
- requests: List of CommandRequest models to send sequentially
430
- machine_id: Machine ID to send the commands to
508
+ requests: List of CommandRequest models to send sequentially (each must include machine_id)
431
509
  run_id: Run ID for all commands
432
510
  user_id: User ID who initiated the commands
433
511
  username: Username who initiated the commands
@@ -444,89 +522,186 @@ class CommandService:
444
522
  logger.warning("No commands to send")
445
523
  return None
446
524
 
525
+ # Collect all unique machine_ids from requests
526
+ machine_ids = set()
527
+ for req in requests:
528
+ if isinstance(req, dict):
529
+ req = CommandRequest.model_validate(req)
530
+ elif not isinstance(req, CommandRequest):
531
+ raise ValueError(f"Request must be a CommandRequest or dict, got {type(req)}")
532
+ machine_ids.add(req.machine_id)
533
+
534
+ machine_ids_list = sorted(list(machine_ids)) # Sort for consistent logging
535
+
447
536
  logger.info(
448
- "Sending %d queue commands sequentially: machine_id=%s, run_id=%s",
537
+ "Sending %d queue commands sequentially to machines: %s, run_id=%s",
449
538
  len(requests),
450
- machine_id,
539
+ machine_ids_list,
451
540
  run_id
452
541
  )
453
542
 
454
- last_response: Optional[NATSMessage] = None
455
-
456
- for idx, request in enumerate(requests, start=1):
457
- logger.info(
458
- "Sending command %d/%d: %s (step %s)",
459
- idx,
460
- len(requests),
461
- request.name,
462
- request.step_number
463
- )
464
-
465
- response = await self.send_queue_command(
466
- request=request,
543
+ # Send START commands to all unique machine_ids before sequence
544
+ logger.info("Sending START commands to all machines: %s", machine_ids_list)
545
+ started_machines = set()
546
+ for machine_id in machine_ids_list:
547
+ start_response = await self.start_run(
467
548
  machine_id=machine_id,
468
549
  run_id=run_id,
469
550
  user_id=user_id,
470
551
  username=username,
471
552
  timeout=timeout
472
553
  )
473
-
474
- # Check if command failed (None means timeout or exception)
475
- if response is None:
476
- logger.error(
477
- "Command %d/%d failed or timed out: %s (step %s)",
554
+ if start_response is None:
555
+ logger.error("START command timed out for machine: %s, aborting", machine_id)
556
+ return None
557
+ if start_response.response and start_response.response.status == CommandResponseStatus.ERROR:
558
+ logger.error("START command failed for machine %s: %s, aborting", machine_id, start_response.response.message)
559
+ return start_response
560
+ started_machines.add(machine_id)
561
+
562
+ last_response: Optional[NATSMessage] = None
563
+
564
+ try:
565
+ for idx, request in enumerate(requests, start=1):
566
+ # Validate request - convert dict to CommandRequest if needed
567
+ if isinstance(request, dict):
568
+ request = CommandRequest.model_validate(request)
569
+ elif not isinstance(request, CommandRequest):
570
+ raise ValueError(f"Request {idx} must be a CommandRequest or dict, got {type(request)}")
571
+
572
+ logger.info(
573
+ "Sending command %d/%d: %s (step %s) to machine %s",
478
574
  idx,
479
575
  len(requests),
480
576
  request.name,
481
- request.step_number
577
+ request.step_number,
578
+ request.machine_id
579
+ )
580
+
581
+ response = await self.send_queue_command(
582
+ request=request,
583
+ run_id=run_id,
584
+ user_id=user_id,
585
+ username=username,
586
+ timeout=timeout
482
587
  )
483
- return None
484
588
 
485
- # Check if command returned an error status
486
- if response.response is not None:
487
- if response.response.status == CommandResponseStatus.ERROR:
589
+ # Check if command failed (None means timeout or exception)
590
+ if response is None:
488
591
  logger.error(
489
- "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
592
+ "Command %d/%d failed or timed out: %s (step %s)",
593
+ idx,
594
+ len(requests),
595
+ request.name,
596
+ request.step_number
597
+ )
598
+ return None
599
+
600
+ # Check if command returned an error status
601
+ if response.response is not None:
602
+ if response.response.status == CommandResponseStatus.ERROR:
603
+ logger.error(
604
+ "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
605
+ idx,
606
+ len(requests),
607
+ request.name,
608
+ request.step_number,
609
+ response.response.code.name,
610
+ response.response.message
611
+ )
612
+ # Complete the run on all machines that were started
613
+ logger.info("Completing runs on all machines due to error")
614
+ for machine_id_to_complete in started_machines:
615
+ try:
616
+ await self.complete_run(
617
+ machine_id=machine_id_to_complete,
618
+ run_id=run_id,
619
+ user_id=user_id,
620
+ username=username,
621
+ timeout=timeout
622
+ )
623
+ except Exception as e:
624
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, e)
625
+ return response
626
+
627
+ # Command succeeded, store as last response
628
+ last_response = response
629
+ logger.info(
630
+ "Command %d/%d succeeded: %s (step %s)",
631
+ idx,
632
+ len(requests),
633
+ request.name,
634
+ request.step_number
635
+ )
636
+ else:
637
+ # Response exists but has no response data (shouldn't happen, but handle it)
638
+ logger.warning(
639
+ "Command %d/%d returned response with no response data: %s (step %s)",
490
640
  idx,
491
641
  len(requests),
492
642
  request.name,
493
- request.step_number,
494
- response.response.code,
495
- response.response.message
643
+ request.step_number
496
644
  )
645
+ # Complete the run on all machines that were started
646
+ logger.info("Completing runs on all machines due to error")
647
+ for machine_id_to_complete in started_machines:
648
+ try:
649
+ await self.complete_run(
650
+ machine_id=machine_id_to_complete,
651
+ run_id=run_id,
652
+ user_id=user_id,
653
+ username=username,
654
+ timeout=timeout
655
+ )
656
+ except Exception as e:
657
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, e)
497
658
  return response
498
-
499
- # Command succeeded, store as last response
500
- last_response = response
501
- logger.info(
502
- "Command %d/%d succeeded: %s (step %s)",
503
- idx,
504
- len(requests),
505
- request.name,
506
- request.step_number
507
- )
508
- else:
509
- # Response exists but has no response data (shouldn't happen, but handle it)
510
- logger.warning(
511
- "Command %d/%d returned response with no response data: %s (step %s)",
512
- idx,
513
- len(requests),
514
- request.name,
515
- request.step_number
659
+
660
+ logger.info(
661
+ "All %d commands completed successfully",
662
+ len(requests)
663
+ )
664
+
665
+ # Always send COMPLETE commands to all machines after successful sequence
666
+ logger.info("Sending COMPLETE commands to all machines: %s", machine_ids_list)
667
+ for machine_id_to_complete in machine_ids_list:
668
+ complete_response = await self.complete_run(
669
+ machine_id=machine_id_to_complete,
670
+ run_id=run_id,
671
+ user_id=user_id,
672
+ username=username,
673
+ timeout=timeout
516
674
  )
517
- return response
518
-
519
- logger.info(
520
- "All %d commands completed successfully",
521
- len(requests)
522
- )
523
- return last_response
675
+ if complete_response is None:
676
+ logger.error("COMPLETE command timed out for machine: %s, aborting", machine_id_to_complete)
677
+ return None
678
+ if complete_response.response and complete_response.response.status == CommandResponseStatus.ERROR:
679
+ logger.error("COMPLETE command failed for machine %s: %s, aborting", machine_id_to_complete, complete_response.response.message)
680
+ return complete_response
681
+
682
+ # Return the last command response, not the COMPLETE response
683
+ return last_response
684
+ except Exception as e:
685
+ # If any error occurs during command execution, try to complete the run
686
+ # on all machines that were started to clean up state
687
+ logger.warning("Error during command sequence, attempting to complete runs on all machines: %s", e)
688
+ for machine_id_to_complete in started_machines:
689
+ try:
690
+ await self.complete_run(
691
+ machine_id=machine_id_to_complete,
692
+ run_id=run_id,
693
+ user_id=user_id,
694
+ username=username,
695
+ timeout=timeout
696
+ )
697
+ except Exception as cleanup_error:
698
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, cleanup_error)
699
+ raise
524
700
 
525
701
  async def send_immediate_command(
526
702
  self,
527
703
  *,
528
704
  request: CommandRequest,
529
- machine_id: str,
530
705
  run_id: str,
531
706
  user_id: str,
532
707
  username: str,
@@ -536,8 +711,7 @@ class CommandService:
536
711
  Send an immediate command (pause, resume, cancel) to the machine.
537
712
 
538
713
  Args:
539
- request: CommandRequest model containing command details
540
- machine_id: Machine ID to send the command to
714
+ request: CommandRequest model containing command details (must include machine_id)
541
715
  run_id: Run ID for the command
542
716
  user_id: User ID who initiated the command
543
717
  username: Username who initiated the command
@@ -549,23 +723,22 @@ class CommandService:
549
723
  if not self._connected or not self.js:
550
724
  raise RuntimeError("Not connected to NATS. Call connect() first.")
551
725
 
552
-
553
- # Determine subject
554
- subject = f"{NAMESPACE}.{machine_id}.cmd.immediate"
726
+ # Determine subject using machine_id from request
727
+ subject = f"{NAMESPACE}.{request.machine_id}.cmd.immediate"
555
728
 
556
729
  logger.info(
557
730
  "Sending immediate command: machine_id=%s, command=%s, run_id=%s, step_number=%s",
558
- machine_id, request.name, run_id, request.step_number
731
+ request.machine_id, request.name, run_id, request.step_number
559
732
  )
560
733
 
561
734
  # Get or create response handler for this machine
562
- response_handler = await self._get_response_handler(machine_id)
735
+ response_handler = await self._get_response_handler(request.machine_id)
563
736
 
564
737
  # Register pending response
565
738
  response_received = response_handler.register_pending(run_id, request.step_number)
566
739
 
567
740
  # Build payload
568
- payload = self._build_command_payload(request, machine_id, run_id, user_id, username)
741
+ payload = self._build_command_payload(request, request.machine_id, run_id, user_id, username)
569
742
 
570
743
  try:
571
744
  # Publish to JetStream
@@ -651,13 +824,16 @@ class CommandService:
651
824
  Args:
652
825
  command_request: CommandRequest model containing command details
653
826
  machine_id: Machine ID for the command
654
- run_id: Run ID for the command
827
+ run_id: Run ID for the command (empty string will be converted to None)
655
828
  user_id: User ID who initiated the command
656
829
  username: Username who initiated the command
657
830
 
658
831
  Returns:
659
832
  NATSMessage object ready for NATS transmission
660
833
  """
834
+ # Convert empty string to None for run_id
835
+ run_id_value = run_id if run_id else None
836
+
661
837
  header = MessageHeader(
662
838
  message_type=MessageType.COMMAND,
663
839
  version="1.0",
@@ -665,7 +841,7 @@ class CommandService:
665
841
  user_id=user_id,
666
842
  username=username,
667
843
  machine_id=machine_id,
668
- run_id=run_id
844
+ run_id=run_id_value
669
845
  )
670
846
 
671
847
  return NATSMessage(