puda-comms 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,13 @@ from typing import Dict, Any, Optional
16
16
  import nats
17
17
  from nats.js.client import JetStreamContext
18
18
  from nats.aio.msg import Msg
19
- from puda_comms.models import CommandRequest, CommandResponseStatus, NATSMessage, MessageHeader, MessageType
19
+ from puda_comms.models import (
20
+ CommandRequest,
21
+ CommandResponseStatus,
22
+ NATSMessage,
23
+ MessageHeader,
24
+ MessageType,
25
+ )
20
26
 
21
27
  logger = logging.getLogger(__name__)
22
28
 
@@ -266,9 +272,9 @@ class CommandService:
266
272
  max_attempts = 3
267
273
  connect_timeout = 3 # 3 seconds timeout per connection attempt
268
274
 
269
- for attempt in range(1, max_attempts + 1):
275
+ for attempt in range(max_attempts):
270
276
  try:
271
- logger.info("Connection attempt %d/%d to NATS servers: %s", attempt, max_attempts, self.servers)
277
+ logger.info("Connection attempt %d/%d to NATS servers: %s", attempt + 1, max_attempts, self.servers)
272
278
  self.nc = await asyncio.wait_for(
273
279
  nats.connect(
274
280
  servers=self.servers,
@@ -285,14 +291,14 @@ class CommandService:
285
291
  return True
286
292
 
287
293
  except asyncio.TimeoutError:
288
- logger.warning("Connection attempt %d/%d timed out after %d seconds", attempt, max_attempts, connect_timeout)
289
- if attempt < max_attempts:
294
+ logger.warning("Connection attempt %d/%d timed out after %d seconds", attempt + 1, max_attempts, connect_timeout)
295
+ if attempt < max_attempts - 1:
290
296
  logger.info("Retrying connection...")
291
297
  else:
292
298
  logger.error("Failed to connect after %d attempts. Giving up.", max_attempts)
293
299
  except Exception as e:
294
- logger.warning("Connection attempt %d/%d failed: %s", attempt, max_attempts, e)
295
- if attempt < max_attempts:
300
+ logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_attempts, e)
301
+ if attempt < max_attempts - 1:
296
302
  logger.info("Retrying connection...")
297
303
  else:
298
304
  logger.error("Failed to connect after %d attempts. Giving up.", max_attempts)
@@ -408,6 +414,76 @@ class CommandService:
408
414
  response_handler.remove_pending(run_id, request.step_number)
409
415
  return None
410
416
 
417
+ async def start_run(
418
+ self,
419
+ machine_id: str,
420
+ run_id: str,
421
+ user_id: str,
422
+ username: str,
423
+ timeout: int = 120
424
+ ) -> Optional[NATSMessage]:
425
+ """
426
+ Send START immediate command to begin a run.
427
+
428
+ Args:
429
+ machine_id: Machine ID to send the command to
430
+ run_id: Run ID for the command
431
+ user_id: User ID who initiated the command
432
+ username: Username who initiated the command
433
+ timeout: Maximum time to wait for response in seconds
434
+
435
+ Returns:
436
+ NATSMessage if successful, None if failed or timeout
437
+ """
438
+ request = CommandRequest(
439
+ name="start",
440
+ params={},
441
+ step_number=0
442
+ )
443
+ return await self.send_immediate_command(
444
+ request=request,
445
+ machine_id=machine_id,
446
+ run_id=run_id,
447
+ user_id=user_id,
448
+ username=username,
449
+ timeout=timeout
450
+ )
451
+
452
+ async def complete_run(
453
+ self,
454
+ machine_id: str,
455
+ run_id: str,
456
+ user_id: str,
457
+ username: str,
458
+ timeout: int = 120
459
+ ) -> Optional[NATSMessage]:
460
+ """
461
+ Send COMPLETE immediate command to end a run.
462
+
463
+ Args:
464
+ machine_id: Machine ID to send the command to
465
+ run_id: Run ID for the command
466
+ user_id: User ID who initiated the command
467
+ username: Username who initiated the command
468
+ timeout: Maximum time to wait for response in seconds
469
+
470
+ Returns:
471
+ NATSMessage if successful, None if failed or timeout
472
+ """
473
+ request = CommandRequest(
474
+ name="complete",
475
+ params={},
476
+ step_number=0
477
+ )
478
+ return await self.send_immediate_command(
479
+ request=request,
480
+ machine_id=machine_id,
481
+ run_id=run_id,
482
+ user_id=user_id,
483
+ username=username,
484
+ timeout=timeout
485
+ )
486
+
411
487
  async def send_queue_commands(
412
488
  self,
413
489
  *,
@@ -421,9 +497,10 @@ class CommandService:
421
497
  """
422
498
  Send multiple queue commands sequentially and wait for responses.
423
499
 
424
- Sends commands one by one, waiting for each response before sending the next.
425
- If any command fails or times out, stops immediately and returns the error response.
426
- If all commands succeed, returns the last command's response.
500
+ Automatically sends START command before the sequence and COMPLETE command after
501
+ successful completion. Sends commands one by one, waiting for each response before
502
+ sending the next. If any command fails or times out, stops immediately and returns
503
+ the error response. If all commands succeed, returns the last command's response.
427
504
 
428
505
  Args:
429
506
  requests: List of CommandRequest models to send sequentially
@@ -451,76 +528,131 @@ class CommandService:
451
528
  run_id
452
529
  )
453
530
 
531
+ # Always send START command before sequence
532
+ logger.info("Sending START command before sequence")
533
+ start_response = await self.start_run(
534
+ machine_id=machine_id,
535
+ run_id=run_id,
536
+ user_id=user_id,
537
+ username=username,
538
+ timeout=timeout
539
+ )
540
+ if start_response is None:
541
+ logger.error("START command timed out")
542
+ return None
543
+ if start_response.response and start_response.response.status == CommandResponseStatus.ERROR:
544
+ logger.error("START command failed: %s", start_response.response.message)
545
+ return start_response
546
+
454
547
  last_response: Optional[NATSMessage] = None
455
548
 
456
- for idx, request in enumerate(requests, start=1):
457
- logger.info(
458
- "Sending command %d/%d: %s (step %s)",
459
- idx,
460
- len(requests),
461
- request.name,
462
- request.step_number
463
- )
464
-
465
- response = await self.send_queue_command(
466
- request=request,
467
- machine_id=machine_id,
468
- run_id=run_id,
469
- user_id=user_id,
470
- username=username,
471
- timeout=timeout
472
- )
473
-
474
- # Check if command failed (None means timeout or exception)
475
- if response is None:
476
- logger.error(
477
- "Command %d/%d failed or timed out: %s (step %s)",
549
+ try:
550
+ for idx, request in enumerate(requests, start=1):
551
+ # Validate request - convert dict to CommandRequest if needed
552
+ if isinstance(request, dict):
553
+ request = CommandRequest.model_validate(request)
554
+ elif not isinstance(request, CommandRequest):
555
+ raise ValueError(f"Request {idx} must be a CommandRequest or dict, got {type(request)}")
556
+
557
+ logger.info(
558
+ "Sending command %d/%d: %s (step %s)",
478
559
  idx,
479
560
  len(requests),
480
561
  request.name,
481
562
  request.step_number
482
563
  )
483
- return None
484
564
 
485
- # Check if command returned an error status
486
- if response.response is not None:
487
- if response.response.status == CommandResponseStatus.ERROR:
565
+ response = await self.send_queue_command(
566
+ request=request,
567
+ machine_id=machine_id,
568
+ run_id=run_id,
569
+ user_id=user_id,
570
+ username=username,
571
+ timeout=timeout
572
+ )
573
+
574
+ # Check if command failed (None means timeout or exception)
575
+ if response is None:
488
576
  logger.error(
489
- "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
577
+ "Command %d/%d failed or timed out: %s (step %s)",
578
+ idx,
579
+ len(requests),
580
+ request.name,
581
+ request.step_number
582
+ )
583
+ return None
584
+
585
+ # Check if command returned an error status
586
+ if response.response is not None:
587
+ if response.response.status == CommandResponseStatus.ERROR:
588
+ logger.error(
589
+ "Command %d/%d failed with error: %s (step %s) - code: %s, message: %s",
590
+ idx,
591
+ len(requests),
592
+ request.name,
593
+ request.step_number,
594
+ response.response.code,
595
+ response.response.message
596
+ )
597
+ return response
598
+
599
+ # Command succeeded, store as last response
600
+ last_response = response
601
+ logger.info(
602
+ "Command %d/%d succeeded: %s (step %s)",
603
+ idx,
604
+ len(requests),
605
+ request.name,
606
+ request.step_number
607
+ )
608
+ else:
609
+ # Response exists but has no response data (shouldn't happen, but handle it)
610
+ logger.warning(
611
+ "Command %d/%d returned response with no response data: %s (step %s)",
490
612
  idx,
491
613
  len(requests),
492
614
  request.name,
493
- request.step_number,
494
- response.response.code,
495
- response.response.message
615
+ request.step_number
496
616
  )
497
617
  return response
498
-
499
- # Command succeeded, store as last response
500
- last_response = response
501
- logger.info(
502
- "Command %d/%d succeeded: %s (step %s)",
503
- idx,
504
- len(requests),
505
- request.name,
506
- request.step_number
507
- )
508
- else:
509
- # Response exists but has no response data (shouldn't happen, but handle it)
510
- logger.warning(
511
- "Command %d/%d returned response with no response data: %s (step %s)",
512
- idx,
513
- len(requests),
514
- request.name,
515
- request.step_number
618
+
619
+ logger.info(
620
+ "All %d commands completed successfully",
621
+ len(requests)
622
+ )
623
+
624
+ # Always send COMPLETE command after successful sequence
625
+ logger.info("Sending COMPLETE command after successful sequence")
626
+ complete_response = await self.complete_run(
627
+ machine_id=machine_id,
628
+ run_id=run_id,
629
+ user_id=user_id,
630
+ username=username,
631
+ timeout=timeout
632
+ )
633
+ if complete_response is None:
634
+ logger.error("COMPLETE command timed out")
635
+ return None
636
+ if complete_response.response and complete_response.response.status == CommandResponseStatus.ERROR:
637
+ logger.error("COMPLETE command failed: %s", complete_response.response.message)
638
+ return complete_response
639
+ # Return the last command response, not the COMPLETE response
640
+ return last_response
641
+ except Exception as e:
642
+ # If an exception is raised during the sequence (error responses and timeouts return early above without sending COMPLETE), try to complete the run
643
+ # to clean up state (but don't fail if this also fails)
644
+ logger.warning("Error during command sequence, attempting to complete run: %s", e)
645
+ try:
646
+ await self.complete_run(
647
+ machine_id=machine_id,
648
+ run_id=run_id,
649
+ user_id=user_id,
650
+ username=username,
651
+ timeout=timeout
516
652
  )
517
- return response
518
-
519
- logger.info(
520
- "All %d commands completed successfully",
521
- len(requests)
522
- )
523
- return last_response
653
+ except Exception as cleanup_error:
654
+ logger.error("Failed to complete run during error cleanup: %s", cleanup_error)
655
+ raise
524
656
 
525
657
  async def send_immediate_command(
526
658
  self,
@@ -651,13 +783,16 @@ class CommandService:
651
783
  Args:
652
784
  command_request: CommandRequest model containing command details
653
785
  machine_id: Machine ID for the command
654
- run_id: Run ID for the command
786
+ run_id: Run ID for the command (empty string will be converted to None)
655
787
  user_id: User ID who initiated the command
656
788
  username: Username who initiated the command
657
789
 
658
790
  Returns:
659
791
  NATSMessage object ready for NATS transmission
660
792
  """
793
+ # Convert empty string to None for run_id
794
+ run_id_value = run_id if run_id else None
795
+
661
796
  header = MessageHeader(
662
797
  message_type=MessageType.COMMAND,
663
798
  version="1.0",
@@ -665,7 +800,7 @@ class CommandService:
665
800
  user_id=user_id,
666
801
  username=username,
667
802
  machine_id=machine_id,
668
- run_id=run_id
803
+ run_id=run_id_value
669
804
  )
670
805
 
671
806
  return NATSMessage(
@@ -19,6 +19,7 @@ from puda_comms.models import (
19
19
  MessageType,
20
20
  ImmediateCommand,
21
21
  )
22
+ from puda_comms.run_manager import RunManager
22
23
  from nats.js.client import JetStreamContext
23
24
  from nats.js.api import StreamConfig, ConsumerConfig
24
25
  from nats.js.errors import NotFoundError
@@ -80,7 +81,9 @@ class MachineClient:
80
81
  # Queue control state
81
82
  self._pause_lock = asyncio.Lock()
82
83
  self._is_paused = False
83
- self._cancelled_run_ids = set()
84
+
85
+ # Run state management
86
+ self.run_manager = RunManager(machine_id=machine_id)
84
87
 
85
88
  def _init_subjects(self):
86
89
  """Initialize all subject and stream names."""
@@ -423,7 +426,7 @@ class MachineClient:
423
426
  logger.error("Error publishing command response: %s", e)
424
427
 
425
428
  async def process_queue_cmd(
426
- self,
429
+ self,
427
430
  msg: Msg,
428
431
  handler: Callable[[NATSMessage], Awaitable[CommandResponse]]
429
432
  ) -> None:
@@ -432,32 +435,26 @@ class MachineClient:
432
435
 
433
436
  Args:
434
437
  msg: NATS message
435
- handler: Handler function that processes the message and returns CommandResponse
438
+ handler: Handler function that processes the message and returns a CommandResponse object
436
439
  """
440
+ # Initialize variables for exception handlers
441
+ run_id = None
442
+ step_number = None
443
+ command = None
444
+
437
445
  try:
438
446
  # Parse message
439
447
  message = NATSMessage.model_validate_json(msg.data)
440
448
  run_id = message.header.run_id
441
- step_number = message.command.step_number
442
- command = message.command.name
449
+ step_number = message.command.step_number if message.command else None
450
+ command = message.command.name if message.command else None
443
451
 
444
- # Check if cancelled
445
- if run_id and run_id in self._cancelled_run_ids:
446
- logger.info("Skipping cancelled command: run_id=%s, step_number=%s, command=%s", run_id, step_number, command)
447
- await msg.ack()
448
- await self._publish_command_response(
449
- msg=msg,
450
- response=CommandResponse(
451
- status=CommandResponseStatus.ERROR,
452
- code=CommandResponseCode.COMMAND_CANCELLED,
453
- message='Command cancelled'
454
- ),
455
- subject=self.response_queue
456
- )
457
- # Note: Final state update should be published by the handler with machine-specific data
458
- return
452
+ # For all commands, continue with normal processing:
453
+ # 1. Check if paused
454
+ # 2. Validate run_id matches active run
455
+ # 3. Execute handler
459
456
 
460
- # Check if paused (for queue messages)
457
+ # If machine is paused, publish error response and return
461
458
  async with self._pause_lock:
462
459
  if self._is_paused:
463
460
  await self._publish_command_response(
@@ -470,24 +467,42 @@ class MachineClient:
470
467
  subject=self.response_queue
471
468
  )
472
469
  return
473
- while self._is_paused:
474
- await msg.in_progress()
475
- await asyncio.sleep(1)
476
- # Re-check cancelled state in case it was cancelled while paused
477
- if run_id and run_id in self._cancelled_run_ids:
478
- logger.info("Command cancelled while paused: run_id=%s, step_number=%s, command=%s", run_id, step_number, command)
479
- await msg.ack()
480
- await self._publish_command_response(
481
- msg=msg,
482
- response=CommandResponse(
483
- status=CommandResponseStatus.ERROR,
484
- code=CommandResponseCode.COMMAND_CANCELLED,
485
- message='Command cancelled'
486
- ),
487
- subject=self.response_queue
488
- )
489
- # Note: Final state update should be published by the handler with machine-specific data
490
- return
470
+
471
+ # Wait while paused (release lock during wait so RESUME can acquire it)
472
+ while True:
473
+ async with self._pause_lock:
474
+ if not self._is_paused:
475
+ break
476
+ # Release lock before sleeping so RESUME can set _is_paused = False
477
+ await msg.in_progress()
478
+ await asyncio.sleep(1)
479
+
480
+ # Validate run_id matches active run (run_id is required)
481
+ if run_id is None:
482
+ await msg.ack()
483
+ await self._publish_command_response(
484
+ msg=msg,
485
+ response=CommandResponse(
486
+ status=CommandResponseStatus.ERROR,
487
+ code=CommandResponseCode.EXECUTION_ERROR,
488
+ message='Command requires run_id'
489
+ ),
490
+ subject=self.response_queue
491
+ )
492
+ return
493
+
494
+ if not await self.run_manager.validate_run_id(run_id):
495
+ await msg.ack()
496
+ await self._publish_command_response(
497
+ msg=msg,
498
+ response=CommandResponse(
499
+ status=CommandResponseStatus.ERROR,
500
+ code=CommandResponseCode.RUN_ID_MISMATCH,
501
+ message=f'Run ID mismatch: expected active run, got {run_id}'
502
+ ),
503
+ subject=self.response_queue
504
+ )
505
+ return
491
506
 
492
507
  # Execute handler with auto-heartbeat (task might take a while for machine to complete)
493
508
  # The handler should be defined in the machine-specific edge module.
@@ -539,34 +554,19 @@ class MachineClient:
539
554
  # This is a rare case - consider if handler should be called with None payload
540
555
 
541
556
  except Exception as e:
542
- # Check if cancelled before sending error response
543
- if run_id and run_id in self._cancelled_run_ids:
544
- logger.info("Command cancelled during execution (exception occurred): run_id=%s, step_number=%s, command=%s", run_id, step_number, command)
545
- await msg.ack()
546
- await self._publish_command_response(
547
- msg=msg,
548
- response=CommandResponse(
549
- status=CommandResponseStatus.ERROR,
550
- code=CommandResponseCode.COMMAND_CANCELLED,
551
- message='Command cancelled'
552
- ),
553
- subject=self.response_queue
554
- )
555
- # Note: Final state update should be published by the handler with machine-specific data
556
- else:
557
- # Terminate all errors to prevent infinite redelivery loops
558
- logger.error("Handler failed (terminating message): %s", e)
559
- await msg.term()
560
- await self._publish_command_response(
561
- msg=msg,
562
- response=CommandResponse(
563
- status=CommandResponseStatus.ERROR,
564
- code=CommandResponseCode.EXECUTION_ERROR,
565
- message=str(e)
566
- ),
567
- subject=self.response_queue
568
- )
569
- # Note: Final state update should be published by the handler with machine-specific data
557
+ # Terminate all errors to prevent infinite redelivery loops
558
+ logger.error("Handler failed (terminating message): %s", e)
559
+ await msg.term()
560
+ await self._publish_command_response(
561
+ msg=msg,
562
+ response=CommandResponse(
563
+ status=CommandResponseStatus.ERROR,
564
+ code=CommandResponseCode.EXECUTION_ERROR,
565
+ message=str(e)
566
+ ),
567
+ subject=self.response_queue
568
+ )
569
+ # Note: Final state update should be published by the handler with machine-specific data
570
570
 
571
571
  async def process_immediate_cmd(self, msg: Msg, handler: Callable[[CommandRequest], Awaitable[CommandResponse]]) -> None:
572
572
  """Process immediate commands (pause, cancel, resume, etc.)."""
@@ -581,8 +581,49 @@ class MachineClient:
581
581
  return
582
582
 
583
583
  command_name = message.command.name.lower()
584
+ run_id = message.header.run_id
585
+ response: CommandResponse
584
586
 
585
587
  match command_name:
588
+ case ImmediateCommand.START:
589
+ if run_id:
590
+ success = await self.run_manager.start_run(run_id)
591
+ if not success:
592
+ # Run already active
593
+ response = CommandResponse(
594
+ status=CommandResponseStatus.ERROR,
595
+ code=CommandResponseCode.RUN_ID_MISMATCH,
596
+ message='cannot start, another run is currently running'
597
+ )
598
+ else:
599
+ await self.publish_state({'state': 'active', 'run_id': run_id})
600
+ response = CommandResponse(status=CommandResponseStatus.SUCCESS)
601
+ else:
602
+ response = CommandResponse(
603
+ status=CommandResponseStatus.ERROR,
604
+ code=CommandResponseCode.MISSING_RUN_ID,
605
+ message='START command requires RUN_ID'
606
+ )
607
+
608
+ case ImmediateCommand.COMPLETE:
609
+ if not run_id:
610
+ response = CommandResponse(
611
+ status=CommandResponseStatus.ERROR,
612
+ code=CommandResponseCode.MISSING_RUN_ID,
613
+ message='COMPLETE command requires RUN_ID'
614
+ )
615
+ else:
616
+ success = await self.run_manager.complete_run(run_id)
617
+ if success:
618
+ await self.publish_state({'state': 'idle', 'run_id': None})
619
+ response = CommandResponse(status=CommandResponseStatus.SUCCESS)
620
+ else:
621
+ response = CommandResponse(
622
+ status=CommandResponseStatus.ERROR,
623
+ code=CommandResponseCode.RUN_ID_MISMATCH,
624
+ message=f'Run {run_id} not active'
625
+ )
626
+
586
627
  case ImmediateCommand.PAUSE:
587
628
  async with self._pause_lock:
588
629
  if not self._is_paused:
@@ -590,7 +631,7 @@ class MachineClient:
590
631
  logger.info("Queue paused")
591
632
  await self.publish_state({'state': 'paused', 'run_id': message.header.run_id})
592
633
  # Call handler and use its response
593
- response: CommandResponse = await handler(message)
634
+ response = await handler(message)
594
635
 
595
636
  case ImmediateCommand.RESUME:
596
637
  async with self._pause_lock:
@@ -599,19 +640,30 @@ class MachineClient:
599
640
  logger.info("Queue resumed")
600
641
  await self.publish_state({'state': 'idle', 'run_id': None})
601
642
  # Call handler and use its response
602
- response: CommandResponse = await handler(message)
643
+ response = await handler(message)
603
644
 
604
645
  case ImmediateCommand.CANCEL:
605
- if message.header.run_id:
606
- self._cancelled_run_ids.add(message.header.run_id)
607
- logger.info("Cancelling all commands with run_id: %s", message.header.run_id)
646
+ if not run_id:
647
+ response = CommandResponse(
648
+ status=CommandResponseStatus.ERROR,
649
+ code=CommandResponseCode.MISSING_RUN_ID,
650
+ message='CANCEL command requires RUN_ID'
651
+ )
652
+ else:
653
+ logger.info("Cancelling all commands with run_id: %s", run_id)
654
+ # Try to clear the active run_id when cancelling; complete_run only clears it when run_id matches the active run, and its result is ignored here
655
+ await self.run_manager.complete_run(run_id)
608
656
  await self.publish_state({'state': 'idle', 'run_id': None})
609
- # Call handler and use its response
610
- response: CommandResponse = await handler(message)
657
+ # Call handler and use its response
658
+ response = await handler(message)
611
659
 
612
660
  case _:
613
- # For other immediate commands, call the user-provided handler
614
- response: CommandResponse = await handler(message)
661
+ # Unknown immediate command
662
+ response = CommandResponse(
663
+ status=CommandResponseStatus.ERROR,
664
+ code=CommandResponseCode.UNKNOWN_COMMAND,
665
+ message=f'Unknown immediate command: {command_name}'
666
+ )
615
667
 
616
668
  await self._publish_command_response(
617
669
  msg=msg,
@@ -702,6 +754,9 @@ class MachineClient:
702
754
  if not self.js:
703
755
  logger.error("JetStream not available for queue subscription")
704
756
  return
757
+
758
+ # Store handler for reconnection
759
+ self._queue_handler = handler
705
760
 
706
761
  # Ensure stream exists before attempting to subscribe
707
762
  await self._ensure_all_streams()
@@ -744,12 +799,11 @@ class MachineClient:
744
799
  try:
745
800
  while True:
746
801
  try:
747
- # Fetch messages (batch of 1, timeout 1 second)
802
+ # Fetch one message (timeout 1 second)
748
803
  msgs = await self._cmd_queue_sub.fetch(batch=1, timeout=1.0)
749
804
  if msgs:
750
- logger.debug("Pulled %d message(s) from queue", len(msgs))
751
- for msg in msgs:
752
- await self.process_queue_cmd(msg, handler)
805
+ logger.debug("Pulled message from queue")
806
+ await self.process_queue_cmd(msgs[0], handler)
753
807
  except asyncio.TimeoutError:
754
808
  # Timeout is expected when no messages are available
755
809
  continue
@@ -780,8 +834,6 @@ class MachineClient:
780
834
  logger.error(" Stream verification failed: %s", stream_check_error)
781
835
  raise
782
836
 
783
- # Store handler for reconnection
784
- self._queue_handler = handler
785
837
  logger.info("Subscribed to queue commands: %s (durable: cmd_queue_%s, stream: %s, pull consumer)",
786
838
  self.cmd_queue, self.machine_id, self.STREAM_COMMAND_QUEUE)
787
839
 
puda_comms/models.py CHANGED
@@ -25,6 +25,7 @@ class CommandResponseCode(str, Enum):
25
25
  RESUME_ERROR = 'RESUME_ERROR'
26
26
  NO_EXECUTION = 'NO_EXECUTION'
27
27
  RUN_ID_MISMATCH = 'RUN_ID_MISMATCH'
28
+ MISSING_RUN_ID = 'MISSING_RUN_ID'
28
29
  CANCEL_ERROR = 'CANCEL_ERROR'
29
30
  MACHINE_PAUSED = 'MACHINE_PAUSED'
30
31
 
@@ -40,6 +41,8 @@ class MessageType(str, Enum):
40
41
 
41
42
  class ImmediateCommand(str, Enum):
42
43
  """Command names for immediate commands."""
44
+ START = 'start'
45
+ COMPLETE = 'complete'
43
46
  PAUSE = 'pause'
44
47
  RESUME = 'resume'
45
48
  CANCEL = 'cancel'
@@ -0,0 +1,112 @@
1
+ """
2
+ Run State Management
3
+ Provides asyncio-safe (coroutine-level, not OS-thread-safe) run state tracking and validation for machine commands.
4
+ """
5
+ import asyncio
6
+ import logging
7
+ from typing import Optional
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class RunManager:
13
+ """
14
+ Manages run state for a machine.
15
+
16
+ Tracks the active run_id and validates that commands match the active run.
17
+ Provides asyncio-safe operations for run lifecycle management (state is guarded by an asyncio.Lock, which is not thread-safe).
18
+ """
19
+
20
+ def __init__(self, machine_id: str):
21
+ """
22
+ Initialize RunManager for a machine.
23
+
24
+ Args:
25
+ machine_id: Machine identifier
26
+ """
27
+ self.machine_id = machine_id
28
+ self._active_run_id: Optional[str] = None
29
+ self._lock = asyncio.Lock()
30
+
31
+ async def start_run(self, run_id: str) -> bool:
32
+ """
33
+ Set active run_id. Returns True if successful, False if run already active.
34
+
35
+ Args:
36
+ run_id: Run ID to set as active
37
+
38
+ Returns:
39
+ True if run was started successfully, False if another run is already active
40
+ """
41
+ async with self._lock:
42
+ if self._active_run_id is not None:
43
+ logger.warning(
44
+ "Cannot start run %s: run %s is already active on machine %s",
45
+ run_id, self._active_run_id, self.machine_id
46
+ )
47
+ return False
48
+
49
+ self._active_run_id = run_id
50
+ logger.info("Started run %s on machine %s", run_id, self.machine_id)
51
+ return True
52
+
53
+ async def complete_run(self, run_id: str) -> bool:
54
+ """
55
+ Clear run_id if it matches. Returns True if successful.
56
+
57
+ Args:
58
+ run_id: Run ID to complete
59
+
60
+ Returns:
61
+ True if run was completed successfully, False if run_id doesn't match active run
62
+ """
63
+ async with self._lock:
64
+ if self._active_run_id != run_id:
65
+ logger.warning(
66
+ "Cannot complete run %s: active run is %s on machine %s",
67
+ run_id, self._active_run_id, self.machine_id
68
+ )
69
+ return False
70
+
71
+ self._active_run_id = None
72
+ logger.info("Completed run %s on machine %s", run_id, self.machine_id)
73
+ return True
74
+
75
+ async def validate_run_id(self, run_id: str) -> bool:
76
+ """
77
+ Check if run_id matches active run. Returns True if valid.
78
+
79
+ Args:
80
+ run_id: Run ID to validate (required)
81
+
82
+ Returns:
83
+ True if run_id matches active run, False otherwise
84
+ """
85
+ async with self._lock:
86
+ # If no active run, any run_id is invalid
87
+ if self._active_run_id is None:
88
+ logger.warning(
89
+ "Run ID validation failed: no active run, got %s on machine %s",
90
+ run_id, self.machine_id
91
+ )
92
+ return False
93
+
94
+ # Run_id must match active run
95
+ if self._active_run_id != run_id:
96
+ logger.warning(
97
+ "Run ID validation failed: expected %s, got %s on machine %s",
98
+ self._active_run_id, run_id, self.machine_id
99
+ )
100
+ return False
101
+
102
+ return True
103
+
104
+ def get_active_run_id(self) -> Optional[str]:
105
+ """
106
+ Get current active run_id.
107
+
108
+ Returns:
109
+ Active run_id if one exists, None otherwise
110
+ """
111
+ return self._active_run_id
112
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: puda-comms
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Communication library for the PUDA platform.
5
5
  Author: zhao
6
6
  Author-email: zhao <20024592+agentzhao@users.noreply.github.com>
@@ -0,0 +1,9 @@
1
+ puda_comms/__init__.py,sha256=lntvVFJJez_rv5lZy5mYj4_43B9Y3NRNzxWfBuSAQ1M,194
2
+ puda_comms/command_service.py,sha256=Lxk-CUan_DwftBZlSYO3VnddxaM9fYKxxhWF8VCqABY,30423
3
+ puda_comms/execution_state.py,sha256=aTaejCnJgg1y_FP-ymIC1GQzqC81FIWo0RZ18XzAQnA,2881
4
+ puda_comms/machine_client.py,sha256=OnA8we1c62n1aEFr0NfiapklHWXR-WFzq5FXQrvuUM8,39378
5
+ puda_comms/models.py,sha256=CfXq_Wxqk5OQo5VknXR-BdLIT2SM69s8cGxGYr9T8WI,3701
6
+ puda_comms/run_manager.py,sha256=_s4VYVGwtRMcduz95_DPIObso4uWRS24n5NH7AiGgjI,3591
7
+ puda_comms-0.0.5.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
8
+ puda_comms-0.0.5.dist-info/METADATA,sha256=REBvcpJsUCxiFCKihVVReP0lh6IkJcBl4I8XohjhSHE,11512
9
+ puda_comms-0.0.5.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- puda_comms/__init__.py,sha256=lntvVFJJez_rv5lZy5mYj4_43B9Y3NRNzxWfBuSAQ1M,194
2
- puda_comms/command_service.py,sha256=KFremcEGfsTeUVQMIhyk1knYmUCvRYQ12vS_jy_14wA,25193
3
- puda_comms/execution_state.py,sha256=aTaejCnJgg1y_FP-ymIC1GQzqC81FIWo0RZ18XzAQnA,2881
4
- puda_comms/machine_client.py,sha256=wj6t_QHGs7l1Oc8JQ6hq2hqBd5C14TCPA_dTU9qOLzw,37430
5
- puda_comms/models.py,sha256=9ZGX0PR7SgMBOL5zVLrPuSUhZqutQU96PubyjyQLhf8,3617
6
- puda_comms-0.0.4.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
7
- puda_comms-0.0.4.dist-info/METADATA,sha256=0cMHDub_3NZt7Cj5U1jzrQXI8atQqpMM-i3vSMrT5lo,11512
8
- puda_comms-0.0.4.dist-info/RECORD,,