puda-comms 0.0.5__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,10 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: puda-comms
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Communication library for the PUDA platform.
5
5
  Author: zhao
6
6
  Author-email: zhao <20024592+agentzhao@users.noreply.github.com>
7
7
  Requires-Dist: nats-py>=2.12.0
8
- Requires-Dist: puda-drivers
9
8
  Requires-Dist: pydantic>=2.12.5
10
9
  Requires-Python: >=3.14
11
10
  Description-Content-Type: text/markdown
@@ -73,6 +72,7 @@ Represents a command to be sent to a machine.
73
72
 
74
73
  **Fields:**
75
74
  - `name` (str): The command name to execute
75
+ - `machine_id` (str): Machine ID to send the command to (required)
76
76
  - `params` (Dict[str, Any]): Command parameters (default: empty dict)
77
77
  - `step_number` (int): Execution step number for tracking progress
78
78
  - `version` (str): Command version (default: "1.0")
@@ -81,7 +81,8 @@ Represents a command to be sent to a machine.
81
81
  ```python
82
82
  command = CommandRequest(
83
83
  name="attach_tip",
84
- params={"slot": "A3", "well": "G8"},
84
+ machine_id="first",
85
+ params={"deck_slot": "A3", "well_name": "G8"},
85
86
  step_number=2,
86
87
  version="1.0"
87
88
  )
@@ -109,7 +110,7 @@ response = CommandResponse(
109
110
  error_response = CommandResponse(
110
111
  status=CommandResponseStatus.ERROR,
111
112
  code="EXECUTION_ERROR",
112
- message="Failed to attach tip: slot A3 not found",
113
+ message="Failed to attach tip: deck_slot A3 not found",
113
114
  completed_at="2026-01-20T02:00:46Z"
114
115
  )
115
116
  ```
@@ -166,8 +167,8 @@ Complete NATS message structure combining header with optional command or respon
166
167
  "command": {
167
168
  "name": "attach_tip",
168
169
  "params": {
169
- "slot": "A3",
170
- "well": "G8"
170
+ "deck_slot": "A3",
171
+ "well_name": "G8"
171
172
  },
172
173
  "step_number": 2,
173
174
  "version": "1.0"
@@ -230,10 +231,9 @@ Queue commands are regular commands that are executed in sequence. Use `send_que
230
231
  `send_queue_command()`, `send_queue_commands()`, and `send_immediate_command()` all accept an optional `timeout` parameter (default: 120 seconds):
231
232
 
232
233
  ```python
233
- # Single command
234
+ # Single command (machine_id must be in CommandRequest)
234
235
  reply = await service.send_queue_command(
235
- request=request,
236
- machine_id="first",
236
+ request=request, # request.machine_id must be set
237
237
  run_id=run_id,
238
238
  user_id="user123",
239
239
  username="John Doe",
@@ -241,9 +241,9 @@ reply = await service.send_queue_command(
241
241
  )
242
242
 
243
243
  # Multiple commands (timeout applies to each command)
244
+ # Each command in the list must have machine_id set
244
245
  reply = await service.send_queue_commands(
245
- requests=commands,
246
- machine_id="first",
246
+ requests=commands, # Each CommandRequest must have machine_id
247
247
  run_id=run_id,
248
248
  user_id="user123",
249
249
  username="John Doe",
@@ -282,8 +282,7 @@ Always check the response status and handle errors appropriately:
282
282
 
283
283
  ```python
284
284
  reply: NATSMessage = await service.send_queue_command(
285
- request=request,
286
- machine_id="first",
285
+ request=request, # request.machine_id must be set
287
286
  run_id=run_id,
288
287
  user_id="user123",
289
288
  username="John Doe"
@@ -61,6 +61,7 @@ Represents a command to be sent to a machine.
61
61
 
62
62
  **Fields:**
63
63
  - `name` (str): The command name to execute
64
+ - `machine_id` (str): Machine ID to send the command to (required)
64
65
  - `params` (Dict[str, Any]): Command parameters (default: empty dict)
65
66
  - `step_number` (int): Execution step number for tracking progress
66
67
  - `version` (str): Command version (default: "1.0")
@@ -69,7 +70,8 @@ Represents a command to be sent to a machine.
69
70
  ```python
70
71
  command = CommandRequest(
71
72
  name="attach_tip",
72
- params={"slot": "A3", "well": "G8"},
73
+ machine_id="first",
74
+ params={"deck_slot": "A3", "well_name": "G8"},
73
75
  step_number=2,
74
76
  version="1.0"
75
77
  )
@@ -97,7 +99,7 @@ response = CommandResponse(
97
99
  error_response = CommandResponse(
98
100
  status=CommandResponseStatus.ERROR,
99
101
  code="EXECUTION_ERROR",
100
- message="Failed to attach tip: slot A3 not found",
102
+ message="Failed to attach tip: deck_slot A3 not found",
101
103
  completed_at="2026-01-20T02:00:46Z"
102
104
  )
103
105
  ```
@@ -154,8 +156,8 @@ Complete NATS message structure combining header with optional command or respon
154
156
  "command": {
155
157
  "name": "attach_tip",
156
158
  "params": {
157
- "slot": "A3",
158
- "well": "G8"
159
+ "deck_slot": "A3",
160
+ "well_name": "G8"
159
161
  },
160
162
  "step_number": 2,
161
163
  "version": "1.0"
@@ -218,10 +220,9 @@ Queue commands are regular commands that are executed in sequence. Use `send_que
218
220
  `send_queue_command()`, `send_queue_commands()`, and `send_immediate_command()` all accept an optional `timeout` parameter (default: 120 seconds):
219
221
 
220
222
  ```python
221
- # Single command
223
+ # Single command (machine_id must be in CommandRequest)
222
224
  reply = await service.send_queue_command(
223
- request=request,
224
- machine_id="first",
225
+ request=request, # request.machine_id must be set
225
226
  run_id=run_id,
226
227
  user_id="user123",
227
228
  username="John Doe",
@@ -229,9 +230,9 @@ reply = await service.send_queue_command(
229
230
  )
230
231
 
231
232
  # Multiple commands (timeout applies to each command)
233
+ # Each command in the list must have machine_id set
232
234
  reply = await service.send_queue_commands(
233
- requests=commands,
234
- machine_id="first",
235
+ requests=commands, # Each CommandRequest must have machine_id
235
236
  run_id=run_id,
236
237
  user_id="user123",
237
238
  username="John Doe",
@@ -270,8 +271,7 @@ Always check the response status and handle errors appropriately:
270
271
 
271
272
  ```python
272
273
  reply: NATSMessage = await service.send_queue_command(
273
- request=request,
274
- machine_id="first",
274
+ request=request, # request.machine_id must be set
275
275
  run_id=run_id,
276
276
  user_id="user123",
277
277
  username="John Doe"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "puda-comms"
3
- version = "0.0.5"
3
+ version = "0.0.6"
4
4
  description = "Communication library for the PUDA platform."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -9,14 +9,9 @@ authors = [
9
9
  requires-python = ">=3.14"
10
10
  dependencies = [
11
11
  "nats-py>=2.12.0",
12
- # "puda-drivers>=0.0.16",
13
- "puda-drivers",
14
12
  "pydantic>=2.12.5",
15
13
  ]
16
14
 
17
- [tool.uv.sources]
18
- puda-drivers = {workspace = true}
19
-
20
15
  [tool.ruff]
21
16
  line-length = 100
22
17
 
@@ -0,0 +1,9 @@
1
+ # Import models first to ensure they're initialized before other modules that depend on them
2
+ from . import models
3
+
4
+ from .machine_client import MachineClient
5
+ from .execution_state import ExecutionState
6
+ from .command_service import CommandService
7
+ from .stream_subscriber import StreamSubscriber
8
+
9
+ __all__ = ["MachineClient", "ExecutionState", "CommandService", "StreamSubscriber", "models"]
@@ -16,7 +16,7 @@ from typing import Dict, Any, Optional
16
16
  import nats
17
17
  from nats.js.client import JetStreamContext
18
18
  from nats.aio.msg import Msg
19
- from puda_comms.models import (
19
+ from .models import (
20
20
  CommandRequest,
21
21
  CommandResponseStatus,
22
22
  NATSMessage,
@@ -104,7 +104,7 @@ class ResponseHandler:
104
104
  command, step_number, run_id, message.response.status
105
105
  )
106
106
  if message.response.status == CommandResponseStatus.ERROR:
107
- logger.warning("Command failed: %s", message.response.message)
107
+ logger.error("Error Code: %s, Message: %s", message.response.code.name, message.response.message)
108
108
 
109
109
  # Get the pending response
110
110
  pending = self._pending_responses[key]
@@ -347,7 +347,6 @@ class CommandService:
347
347
  self,
348
348
  *,
349
349
  request: CommandRequest,
350
- machine_id: str,
351
350
  run_id: str,
352
351
  user_id: str,
353
352
  username: str,
@@ -357,8 +356,7 @@ class CommandService:
357
356
  Send a queue command to the machine and wait for response.
358
357
 
359
358
  Args:
360
- request: CommandRequest model containing command details
361
- machine_id: Machine ID to send the command to
359
+ request: CommandRequest model containing command details (must include machine_id)
362
360
  run_id: Run ID for the command
363
361
  user_id: User ID who initiated the command
364
362
  username: Username who initiated the command
@@ -370,8 +368,8 @@ class CommandService:
370
368
  if not self._connected or not self.js:
371
369
  raise RuntimeError("Not connected to NATS. Call connect() first.")
372
370
 
373
- # Determine subject
374
- subject = f"{NAMESPACE}.{machine_id}.cmd.queue"
371
+ # Determine subject using machine_id from request
372
+ subject = f"{NAMESPACE}.{request.machine_id}.cmd.queue"
375
373
 
376
374
  logger.info(
377
375
  "Sending queue command: subject=%s, command=%s, run_id=%s, step_number=%s",
@@ -379,12 +377,12 @@ class CommandService:
379
377
  )
380
378
 
381
379
  # Get or create response handler for this machine
382
- response_handler = await self._get_response_handler(machine_id)
380
+ response_handler = await self._get_response_handler(request.machine_id)
383
381
  # Register pending response
384
382
  response_event = response_handler.register_pending(run_id, request.step_number)
385
383
 
386
384
  # Build payload
387
- payload = self._build_command_payload(request, machine_id, run_id, user_id, username)
385
+ payload = self._build_command_payload(request, request.machine_id, run_id, user_id, username)
388
386
 
389
387
  try:
390
388
  # Publish to JetStream
@@ -437,12 +435,12 @@ class CommandService:
437
435
  """
438
436
  request = CommandRequest(
439
437
  name="start",
438
+ machine_id=machine_id,
440
439
  params={},
441
440
  step_number=0
442
441
  )
443
442
  return await self.send_immediate_command(
444
443
  request=request,
445
- machine_id=machine_id,
446
444
  run_id=run_id,
447
445
  user_id=user_id,
448
446
  username=username,
@@ -472,12 +470,12 @@ class CommandService:
472
470
  """
473
471
  request = CommandRequest(
474
472
  name="complete",
473
+ machine_id=machine_id,
475
474
  params={},
476
475
  step_number=0
477
476
  )
478
477
  return await self.send_immediate_command(
479
478
  request=request,
480
- machine_id=machine_id,
481
479
  run_id=run_id,
482
480
  user_id=user_id,
483
481
  username=username,
@@ -488,7 +486,6 @@ class CommandService:
488
486
  self,
489
487
  *,
490
488
  requests: list[CommandRequest],
491
- machine_id: str,
492
489
  run_id: str,
493
490
  user_id: str,
494
491
  username: str,
@@ -497,14 +494,18 @@ class CommandService:
497
494
  """
498
495
  Send multiple queue commands sequentially and wait for responses.
499
496
 
500
- Automatically sends START command before the sequence and COMPLETE command after
501
- successful completion. Sends commands one by one, waiting for each response before
502
- sending the next. If any command fails or times out, stops immediately and returns
503
- the error response. If all commands succeed, returns the last command's response.
497
+ Automatically sends START commands to all unique machine_ids before the sequence
498
+ and COMPLETE commands to all unique machine_ids after successful completion.
499
+ Sends commands one by one, waiting for each response before sending the next.
500
+ If any command fails or times out, stops immediately, completes all started runs,
501
+ and returns the error response. If all commands succeed, returns the last command's response.
502
+
503
+ Each command must specify its own machine_id. Commands with different machine_ids
504
+ will be sent to their respective machines. All machines involved will receive
505
+ START commands at the beginning and COMPLETE commands at the end.
504
506
 
505
507
  Args:
506
- requests: List of CommandRequest models to send sequentially
507
- machine_id: Machine ID to send the commands to
508
+ requests: List of CommandRequest models to send sequentially (each must include machine_id)
508
509
  run_id: Run ID for all commands
509
510
  user_id: User ID who initiated the commands
510
511
  username: Username who initiated the commands
@@ -521,28 +522,42 @@ class CommandService:
521
522
  logger.warning("No commands to send")
522
523
  return None
523
524
 
525
+ # Collect all unique machine_ids from requests
526
+ machine_ids = set()
527
+ for req in requests:
528
+ if isinstance(req, dict):
529
+ req = CommandRequest.model_validate(req)
530
+ elif not isinstance(req, CommandRequest):
531
+ raise ValueError(f"Request must be a CommandRequest or dict, got {type(req)}")
532
+ machine_ids.add(req.machine_id)
533
+
534
+ machine_ids_list = sorted(list(machine_ids)) # Sort for consistent logging
535
+
524
536
  logger.info(
525
- "Sending %d queue commands sequentially: machine_id=%s, run_id=%s",
537
+ "Sending %d queue commands sequentially to machines: %s, run_id=%s",
526
538
  len(requests),
527
- machine_id,
539
+ machine_ids_list,
528
540
  run_id
529
541
  )
530
542
 
531
- # Always send START command before sequence
532
- logger.info("Sending START command before sequence")
533
- start_response = await self.start_run(
534
- machine_id=machine_id,
535
- run_id=run_id,
536
- user_id=user_id,
537
- username=username,
538
- timeout=timeout
539
- )
540
- if start_response is None:
541
- logger.error("START command timed out")
542
- return None
543
- if start_response.response and start_response.response.status == CommandResponseStatus.ERROR:
544
- logger.error("START command failed: %s", start_response.response.message)
545
- return start_response
543
+ # Send START commands to all unique machine_ids before sequence
544
+ logger.info("Sending START commands to all machines: %s", machine_ids_list)
545
+ started_machines = set()
546
+ for machine_id in machine_ids_list:
547
+ start_response = await self.start_run(
548
+ machine_id=machine_id,
549
+ run_id=run_id,
550
+ user_id=user_id,
551
+ username=username,
552
+ timeout=timeout
553
+ )
554
+ if start_response is None:
555
+ logger.error("START command timed out for machine: %s, aborting", machine_id)
556
+ return None
557
+ if start_response.response and start_response.response.status == CommandResponseStatus.ERROR:
558
+ logger.error("START command failed for machine %s: %s, aborting", machine_id, start_response.response.message)
559
+ return start_response
560
+ started_machines.add(machine_id)
546
561
 
547
562
  last_response: Optional[NATSMessage] = None
548
563
 
@@ -555,16 +570,16 @@ class CommandService:
555
570
  raise ValueError(f"Request {idx} must be a CommandRequest or dict, got {type(request)}")
556
571
 
557
572
  logger.info(
558
- "Sending command %d/%d: %s (step %s)",
573
+ "Sending command %d/%d: %s (step %s) to machine %s",
559
574
  idx,
560
575
  len(requests),
561
576
  request.name,
562
- request.step_number
577
+ request.step_number,
578
+ request.machine_id
563
579
  )
564
580
 
565
581
  response = await self.send_queue_command(
566
582
  request=request,
567
- machine_id=machine_id,
568
583
  run_id=run_id,
569
584
  user_id=user_id,
570
585
  username=username,
@@ -591,9 +606,22 @@ class CommandService:
591
606
  len(requests),
592
607
  request.name,
593
608
  request.step_number,
594
- response.response.code,
609
+ response.response.code.name,
595
610
  response.response.message
596
611
  )
612
+ # Complete the run on all machines that were started
613
+ logger.info("Completing runs on all machines due to error")
614
+ for machine_id_to_complete in started_machines:
615
+ try:
616
+ await self.complete_run(
617
+ machine_id=machine_id_to_complete,
618
+ run_id=run_id,
619
+ user_id=user_id,
620
+ username=username,
621
+ timeout=timeout
622
+ )
623
+ except Exception as e:
624
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, e)
597
625
  return response
598
626
 
599
627
  # Command succeeded, store as last response
@@ -614,6 +642,19 @@ class CommandService:
614
642
  request.name,
615
643
  request.step_number
616
644
  )
645
+ # Complete the run on all machines that were started
646
+ logger.info("Completing runs on all machines due to error")
647
+ for machine_id_to_complete in started_machines:
648
+ try:
649
+ await self.complete_run(
650
+ machine_id=machine_id_to_complete,
651
+ run_id=run_id,
652
+ user_id=user_id,
653
+ username=username,
654
+ timeout=timeout
655
+ )
656
+ except Exception as e:
657
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, e)
617
658
  return response
618
659
 
619
660
  logger.info(
@@ -621,44 +662,46 @@ class CommandService:
621
662
  len(requests)
622
663
  )
623
664
 
624
- # Always send COMPLETE command after successful sequence
625
- logger.info("Sending COMPLETE command after successful sequence")
626
- complete_response = await self.complete_run(
627
- machine_id=machine_id,
628
- run_id=run_id,
629
- user_id=user_id,
630
- username=username,
631
- timeout=timeout
632
- )
633
- if complete_response is None:
634
- logger.error("COMPLETE command timed out")
635
- return None
636
- if complete_response.response and complete_response.response.status == CommandResponseStatus.ERROR:
637
- logger.error("COMPLETE command failed: %s", complete_response.response.message)
638
- return complete_response
639
- # Return the last command response, not the COMPLETE response
640
- return last_response
641
- except Exception as e:
642
- # If any error occurs during command execution, try to complete the run
643
- # to clean up state (but don't fail if this also fails)
644
- logger.warning("Error during command sequence, attempting to complete run: %s", e)
645
- try:
646
- await self.complete_run(
647
- machine_id=machine_id,
665
+ # Always send COMPLETE commands to all machines after successful sequence
666
+ logger.info("Sending COMPLETE commands to all machines: %s", machine_ids_list)
667
+ for machine_id_to_complete in machine_ids_list:
668
+ complete_response = await self.complete_run(
669
+ machine_id=machine_id_to_complete,
648
670
  run_id=run_id,
649
671
  user_id=user_id,
650
672
  username=username,
651
673
  timeout=timeout
652
674
  )
653
- except Exception as cleanup_error:
654
- logger.error("Failed to complete run during error cleanup: %s", cleanup_error)
675
+ if complete_response is None:
676
+ logger.error("COMPLETE command timed out for machine: %s, aborting", machine_id_to_complete)
677
+ return None
678
+ if complete_response.response and complete_response.response.status == CommandResponseStatus.ERROR:
679
+ logger.error("COMPLETE command failed for machine %s: %s, aborting", machine_id_to_complete, complete_response.response.message)
680
+ return complete_response
681
+
682
+ # Return the last command response, not the COMPLETE response
683
+ return last_response
684
+ except Exception as e:
685
+ # If any error occurs during command execution, try to complete the run
686
+ # on all machines that were started to clean up state
687
+ logger.warning("Error during command sequence, attempting to complete runs on all machines: %s", e)
688
+ for machine_id_to_complete in started_machines:
689
+ try:
690
+ await self.complete_run(
691
+ machine_id=machine_id_to_complete,
692
+ run_id=run_id,
693
+ user_id=user_id,
694
+ username=username,
695
+ timeout=timeout
696
+ )
697
+ except Exception as cleanup_error:
698
+ logger.error("Failed to complete run for machine %s during error cleanup: %s", machine_id_to_complete, cleanup_error)
655
699
  raise
656
700
 
657
701
  async def send_immediate_command(
658
702
  self,
659
703
  *,
660
704
  request: CommandRequest,
661
- machine_id: str,
662
705
  run_id: str,
663
706
  user_id: str,
664
707
  username: str,
@@ -668,8 +711,7 @@ class CommandService:
668
711
  Send an immediate command (pause, resume, cancel) to the machine.
669
712
 
670
713
  Args:
671
- request: CommandRequest model containing command details
672
- machine_id: Machine ID to send the command to
714
+ request: CommandRequest model containing command details (must include machine_id)
673
715
  run_id: Run ID for the command
674
716
  user_id: User ID who initiated the command
675
717
  username: Username who initiated the command
@@ -681,23 +723,22 @@ class CommandService:
681
723
  if not self._connected or not self.js:
682
724
  raise RuntimeError("Not connected to NATS. Call connect() first.")
683
725
 
684
-
685
- # Determine subject
686
- subject = f"{NAMESPACE}.{machine_id}.cmd.immediate"
726
+ # Determine subject using machine_id from request
727
+ subject = f"{NAMESPACE}.{request.machine_id}.cmd.immediate"
687
728
 
688
729
  logger.info(
689
730
  "Sending immediate command: machine_id=%s, command=%s, run_id=%s, step_number=%s",
690
- machine_id, request.name, run_id, request.step_number
731
+ request.machine_id, request.name, run_id, request.step_number
691
732
  )
692
733
 
693
734
  # Get or create response handler for this machine
694
- response_handler = await self._get_response_handler(machine_id)
735
+ response_handler = await self._get_response_handler(request.machine_id)
695
736
 
696
737
  # Register pending response
697
738
  response_received = response_handler.register_pending(run_id, request.step_number)
698
739
 
699
740
  # Build payload
700
- payload = self._build_command_payload(request, machine_id, run_id, user_id, username)
741
+ payload = self._build_command_payload(request, request.machine_id, run_id, user_id, username)
701
742
 
702
743
  try:
703
744
  # Publish to JetStream
@@ -10,7 +10,7 @@ import logging
10
10
  from typing import Dict, Any, Optional, Callable, Awaitable
11
11
  from datetime import datetime, timezone
12
12
  import nats
13
- from puda_comms.models import (
13
+ from .models import (
14
14
  CommandResponseStatus,
15
15
  CommandResponse,
16
16
  CommandResponseCode,
@@ -19,10 +19,10 @@ from puda_comms.models import (
19
19
  MessageType,
20
20
  ImmediateCommand,
21
21
  )
22
- from puda_comms.run_manager import RunManager
22
+ from .run_manager import RunManager
23
23
  from nats.js.client import JetStreamContext
24
24
  from nats.js.api import StreamConfig, ConsumerConfig
25
- from nats.js.errors import NotFoundError
25
+ from nats.js.errors import NotFoundError, Error as NATSError
26
26
  from nats.aio.msg import Msg
27
27
 
28
28
  logger = logging.getLogger(__name__)
@@ -491,6 +491,21 @@ class MachineClient:
491
491
  )
492
492
  return
493
493
 
494
+ # If active run_id is None, return error response
495
+ if self.run_manager.get_active_run_id() is None:
496
+ await msg.ack()
497
+ await self._publish_command_response(
498
+ msg=msg,
499
+ response=CommandResponse(
500
+ status=CommandResponseStatus.ERROR,
501
+ code=CommandResponseCode.RUN_ID_MISMATCH,
502
+ message='Send START command to start a run before sending commands'
503
+ ),
504
+ subject=self.response_queue
505
+ )
506
+ return
507
+
508
+ # If run_id does not match active run_id, return error response
494
509
  if not await self.run_manager.validate_run_id(run_id):
495
510
  await msg.ack()
496
511
  await self._publish_command_response(
@@ -512,7 +527,9 @@ class MachineClient:
512
527
  # Finalize message state based on response
513
528
  if response.status == CommandResponseStatus.SUCCESS:
514
529
  await msg.ack()
515
- else:
530
+ elif response.status == CommandResponseStatus.ERROR:
531
+ # just complete the run if the command failed
532
+ await self.run_manager.complete_run(run_id)
516
533
  await msg.term()
517
534
 
518
535
  await self._publish_command_response(
@@ -526,6 +543,7 @@ class MachineClient:
526
543
  # Handler was cancelled (e.g., via task cancellation)
527
544
  logger.info("Handler execution cancelled: run_id=%s, step_number=%s, command=%s", run_id, step_number, command)
528
545
  await msg.ack()
546
+ await self.run_manager.complete_run(run_id)
529
547
  await self._publish_command_response(
530
548
  msg=msg,
531
549
  response=CommandResponse(
@@ -540,6 +558,7 @@ class MachineClient:
540
558
  except json.JSONDecodeError as e:
541
559
  logger.error("JSON Decode Error. Terminating message.")
542
560
  await msg.term()
561
+ await self.run_manager.complete_run(run_id)
543
562
  await self._publish_command_response(
544
563
  msg=msg,
545
564
  response=CommandResponse(
@@ -557,6 +576,7 @@ class MachineClient:
557
576
  # Terminate all errors to prevent infinite redelivery loops
558
577
  logger.error("Handler failed (terminating message): %s", e)
559
578
  await msg.term()
579
+ await self.run_manager.complete_run(run_id)
560
580
  await self._publish_command_response(
561
581
  msg=msg,
562
582
  response=CommandResponse(
@@ -593,7 +613,7 @@ class MachineClient:
593
613
  response = CommandResponse(
594
614
  status=CommandResponseStatus.ERROR,
595
615
  code=CommandResponseCode.RUN_ID_MISMATCH,
596
- message='cannot start, another run is currently running'
616
+ message=f'cannot start, {self.run_manager.get_active_run_id()} is currently running'
597
617
  )
598
618
  else:
599
619
  await self.publish_state({'state': 'active', 'run_id': run_id})
@@ -862,21 +882,76 @@ class MachineClient:
862
882
  retention='workqueue'
863
883
  )
864
884
 
885
+ durable_name = f"cmd_immed_{self.machine_id}"
886
+
887
+ # Try to unsubscribe from existing subscription if it exists
888
+ if self._cmd_immediate_sub:
889
+ try:
890
+ await self._cmd_immediate_sub.unsubscribe()
891
+ logger.info("Unsubscribed from existing immediate command subscription")
892
+ except Exception as e:
893
+ logger.debug("Error unsubscribing from existing subscription: %s", e)
894
+ self._cmd_immediate_sub = None
895
+
896
+ # Try to delete existing consumer if it's bound (from previous run)
897
+ try:
898
+ await self.js.delete_consumer(self.STREAM_COMMAND_IMMEDIATE, durable_name)
899
+ logger.info("Deleted existing immediate consumer: %s", durable_name)
900
+ except NotFoundError:
901
+ # Consumer doesn't exist, which is fine
902
+ logger.debug("Consumer %s does not exist, will be created", durable_name)
903
+ except Exception as e:
904
+ error_msg = str(e).lower()
905
+ if "bound" in error_msg or "in use" in error_msg:
906
+ # Consumer is bound and the first delete failed - wait briefly, then retry the delete
907
+ logger.warning("Consumer %s is bound to a subscription. Attempting to force delete...", durable_name)
908
+ # Wait a moment for any pending operations to complete
909
+ await asyncio.sleep(0.5)
910
+ try:
911
+ await self.js.delete_consumer(self.STREAM_COMMAND_IMMEDIATE, durable_name)
912
+ logger.info("Successfully deleted bound consumer: %s", durable_name)
913
+ except Exception as delete_error:
914
+ logger.warning("Could not delete bound consumer %s: %s. Will attempt to subscribe anyway.",
915
+ durable_name, delete_error)
916
+ else:
917
+ logger.warning("Error checking/deleting consumer %s: %s", durable_name, e)
918
+
865
919
  try:
866
920
  self._cmd_immediate_sub = await self.js.subscribe(
867
921
  subject=self.cmd_immediate,
868
922
  stream=self.STREAM_COMMAND_IMMEDIATE,
869
- durable=f"cmd_immed_{self.machine_id}",
923
+ durable=durable_name,
870
924
  cb=message_handler # required for push consumer to handle messages
871
925
  )
926
+ except NATSError as e:
927
+ error_msg = str(e).lower()
928
+ if "bound" in error_msg or "already bound" in error_msg:
929
+ # Consumer is still bound - try to delete it and retry
930
+ logger.warning("Consumer %s is still bound. Attempting to delete and retry...", durable_name)
931
+ try:
932
+ await self.js.delete_consumer(self.STREAM_COMMAND_IMMEDIATE, durable_name)
933
+ await asyncio.sleep(0.5) # Brief wait for cleanup
934
+ # Retry subscription
935
+ self._cmd_immediate_sub = await self.js.subscribe(
936
+ subject=self.cmd_immediate,
937
+ stream=self.STREAM_COMMAND_IMMEDIATE,
938
+ durable=durable_name,
939
+ cb=message_handler
940
+ )
941
+ logger.info("Successfully subscribed after deleting bound consumer")
942
+ except Exception as retry_error:
943
+ logger.error("Failed to subscribe after deleting bound consumer: %s", retry_error)
944
+ raise
945
+ else:
946
+ raise
872
947
  except NotFoundError:
873
948
  # Stream still not found after ensuring it exists - this shouldn't happen
874
949
  # but handle it gracefully
875
- logger.error("Stream %s not found even after creation attempt. Check NATS server configuration.",
950
+ logger.error("Stream %s not found even after creation attempt. Check NATS server configuration.",
876
951
  self.STREAM_COMMAND_IMMEDIATE)
877
952
  raise
878
953
 
879
- logger.info("Subscribed to immediate commands: %s (durable: cmd_immed_%s, stream: %s)",
954
+ logger.info("Subscribed to immediate commands: %s (durable: cmd_immed_%s, stream: %s)",
880
955
  self.cmd_immediate, self.machine_id, self.STREAM_COMMAND_IMMEDIATE)
881
956
 
882
957
 
@@ -57,8 +57,10 @@ class CommandRequest(BaseModel):
57
57
  """Command request data for NATS messages."""
58
58
  name: str = Field(description="The command name (string) to send to the machine.")
59
59
  params: Dict[str, Any] = Field(default_factory=dict, description="The parameters to send to the machine.")
60
+ kwargs: Dict[str, Any] = Field(default_factory=dict, description="Additional keyword arguments (e.g., channels in Biologic).")
60
61
  step_number: int = Field(description="Execution step number (integer). Used to track the progress of a command.")
61
62
  version: str = Field(default="1.0", description="Command version.")
63
+ machine_id: str = Field(description="Machine ID to send the command to.")
62
64
 
63
65
 
64
66
  class CommandResponse(BaseModel):
@@ -67,7 +69,7 @@ class CommandResponse(BaseModel):
67
69
  completed_at: str = Field(default_factory=_get_current_timestamp, description="ISO format timestamp (auto-set on creation)")
68
70
  code: Optional[CommandResponseCode] = Field(default=None, description="Error code")
69
71
  message: Optional[str] = Field(default=None, description="Error message (human-readable description)")
70
- data: Optional[Dict[str, Any]] = Field(default=None, description="Optional output data from the command handler")
72
+ data: Optional[Dict[Any, Any]] = Field(default=None, description="Optional output data from the command handler")
71
73
 
72
74
  class MessageHeader(BaseModel):
73
75
  """Header for NATS messages."""
@@ -78,6 +80,7 @@ class MessageHeader(BaseModel):
78
80
  machine_id: str = Field(description="Machine ID")
79
81
  run_id: Optional[str] = Field(default=None, description="Unique identifier (uuid) for the run/workflow")
80
82
  timestamp: str = Field(default_factory=_get_current_timestamp, description="ISO format timestamp (auto-set on creation)")
83
+
81
84
  class NATSMessage(BaseModel):
82
85
  """
83
86
  Complete NATS message structure.
@@ -0,0 +1,388 @@
1
+ """
2
+ Reusable NATS JetStream subscriber for services that need to consume messages.
3
+
4
+ Provides a base class for subscribing to NATS streams with durable consumers,
5
+ automatic reconnection, and message handling callbacks.
6
+
7
+ This implements a push consumer pattern where NATS JetStream automatically
8
+ delivers messages to registered callbacks as they arrive, rather than requiring
9
+ the client to explicitly fetch/pull messages.
10
+ """
11
+ import asyncio
12
+ import logging
13
+ from typing import Optional, Callable, Awaitable, List, Any
14
+ from abc import abstractmethod
15
+ import nats
16
+ from nats.js.client import JetStreamContext
17
+ from nats.aio.msg import Msg
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class StreamSubscriber:
    """
    Base class for subscribing to NATS JetStream streams using push consumers.

    With a push consumer the server delivers messages to a registered
    callback as they arrive; the client never fetches/pulls explicitly.

    The class owns the NATS connection, keeps track of the subscriptions it
    created, and routes every delivered message either to a per-subscription
    callback or to :meth:`handle_message`.

    ``handle_message``, ``on_start`` and ``on_stop`` are decorated with
    ``@abstractmethod`` but this class does not derive from ``abc.ABC``, so
    the decorator is not enforced: the class can be instantiated directly and
    the default implementations below are what runs unless overridden.

    Example:
        ```python
        class MyService(StreamSubscriber):
            async def handle_message(self, msg: Msg, stream: str, subject: str):
                # Process message
                data = json.loads(msg.data.decode())
                # ... your logic ...
                await msg.ack()

        service = MyService(servers=["nats://localhost:4222"])
        await service.connect()
        await service.subscribe("STREAM_NAME", "puda.*.cmd.response.queue", "my_consumer")
        await service.run()
        ```
    """

    def __init__(
        self,
        servers: List[str],
        connect_timeout: int = 10,
        reconnect_time_wait: int = 2,
        max_reconnect_attempts: int = -1
    ):
        """
        Initialize the stream subscriber. No network I/O happens here.

        Args:
            servers: List of NATS server URLs (e.g., ["nats://localhost:4222"])
            connect_timeout: Timeout for initial connection in seconds
            reconnect_time_wait: Wait time between reconnection attempts in seconds
            max_reconnect_attempts: Maximum reconnection attempts (-1 for unlimited)

        Raises:
            ValueError: If ``servers`` is empty.
        """
        if not servers:
            raise ValueError("servers must be a non-empty list")

        self.servers = servers
        self.connect_timeout = connect_timeout
        self.reconnect_time_wait = reconnect_time_wait
        self.max_reconnect_attempts = max_reconnect_attempts

        self.nc: Optional[nats.NATS] = None
        self.js: Optional[JetStreamContext] = None
        self._subscriptions: List[Any] = []
        self._is_connected = False
        self._should_run = True

    def _has_live_client(self) -> bool:
        """Return True while a NATS client exists that has not been closed."""
        return self.nc is not None and not self.nc.is_closed

    async def connect(self) -> bool:
        """
        Connect to NATS servers.

        If a client already exists and has not been closed, it is reused
        instead of opening a second connection: such a client is either
        connected or busy with its own reconnect attempts, and replacing it
        would leave the old connection running in the background.

        Returns:
            True if connected successfully, False otherwise
        """
        if self._is_connected:
            return True

        if self._has_live_client():
            if not self.nc.is_connected:
                # The client is still reconnecting on its own;
                # _reconnected_callback restores the state once it succeeds.
                return False
            if self.js is None:
                self.js = self.nc.jetstream()
            self._is_connected = True
            return True

        try:
            self.nc = await nats.connect(
                servers=self.servers,
                connect_timeout=self.connect_timeout,
                reconnect_time_wait=self.reconnect_time_wait,
                max_reconnect_attempts=self.max_reconnect_attempts,
                error_cb=self._error_callback,
                disconnected_cb=self._disconnected_callback,
                reconnected_cb=self._reconnected_callback,
                closed_cb=self._closed_callback
            )
            self.js = self.nc.jetstream()
            self._is_connected = True
            logger.info("Connected to NATS servers: %s", self.servers)
            return True
        except Exception as e:
            logger.error("Failed to connect to NATS: %s", e)
            self._is_connected = False
            return False

    async def disconnect(self):
        """
        Disconnect from NATS and cleanup subscriptions.

        Also clears ``_should_run`` so that a running :meth:`run` loop stops.
        Unsubscribe failures are logged at debug level and otherwise ignored.
        The connection state is reset even if closing the client raises.
        """
        self._should_run = False

        # Unsubscribe from all streams (best effort)
        for sub in self._subscriptions:
            try:
                await sub.unsubscribe()
            except Exception as e:
                logger.debug("Error unsubscribing: %s", e)
        self._subscriptions.clear()

        # Close NATS connection; never keep a half-closed client around
        try:
            if self.nc:
                await self.nc.close()
        finally:
            self.nc = None
            self.js = None
            self._is_connected = False

        logger.info("Disconnected from NATS")

    async def _push_subscribe(
        self,
        stream: str,
        subject: str,
        durable: Optional[str],
        cb: Callable[[Msg], Awaitable[None]]
    ):
        """Create one push subscription; a falsy ``durable`` means ephemeral."""
        return await self.js.subscribe(
            subject,
            stream=stream,
            durable=durable or None,
            cb=cb
        )

    async def _recover_bound_consumer(
        self,
        stream: str,
        subject: str,
        durable: str,
        cb: Callable[[Msg], Awaitable[None]]
    ):
        """Delete an already-bound durable consumer and subscribe once more."""
        try:
            # Try to delete the consumer (may fail if actively bound)
            await self.js.delete_consumer(stream, durable)
            logger.info("Deleted consumer '%s' for stream '%s'", durable, stream)
            # Retry subscription after deletion
            return await self._push_subscribe(stream, subject, durable, cb)
        except Exception as retry_error:
            retry_error_msg = str(retry_error).lower()
            if "bound" in retry_error_msg or "in use" in retry_error_msg:
                logger.error(
                    "Consumer '%s' for stream '%s' cannot be deleted because it's still bound. "
                    "This typically means the previous service instance is still running or "
                    "the subscription hasn't timed out yet. Solutions:\n"
                    " 1. Wait a few seconds and restart the service\n"
                    " 2. Manually delete the consumer: nats consumer rm %s %s\n"
                    " 3. Restart the NATS server\n"
                    " 4. Use a different durable consumer name",
                    durable, stream, stream, durable
                )
            else:
                logger.error(
                    "Failed to delete consumer '%s' for stream '%s': %s",
                    durable, stream, retry_error
                )
            raise

    async def subscribe(
        self,
        stream: str,
        subject: str,
        durable: Optional[str] = None,
        callback: Optional[Callable[[Msg, str, str], Awaitable[None]]] = None
    ):
        """
        Subscribe to a NATS JetStream stream using push consumer pattern.

        Messages are delivered to ``callback`` as they arrive. Exceptions
        raised by the callback are logged and not propagated.

        If the server reports that the durable consumer is already bound, the
        consumer is deleted and the subscription is attempted once more.

        Args:
            stream: Name of the JetStream stream
            subject: Subject pattern to subscribe to (supports wildcards)
            durable: Optional durable consumer name (for persistent subscriptions)
            callback: Optional async callback function(msg, stream, subject) -> None
                If not provided, calls handle_message() method

        Raises:
            RuntimeError: If not connected to NATS
            Exception: Whatever the NATS client raised when the subscription
                (or the cleanup retry) failed; it is logged and re-raised.
        """
        if not self._is_connected or not self.js:
            raise RuntimeError("Not connected to NATS. Call connect() first.")

        # Use provided callback or default to handle_message method
        handler = callback if callback is not None else self.handle_message

        # The wrapper is itself a coroutine function, so it is handed to the
        # client directly rather than through a lambda that spawns
        # unreferenced tasks.
        async def message_wrapper(msg: Msg) -> None:
            try:
                await handler(msg, stream, subject)
            except Exception as e:
                logger.error(
                    "Error in message callback for stream=%s, subject=%s: %s",
                    stream, subject, e, exc_info=True
                )
                # Don't ack on error - let the caller decide
                # This allows for retry logic in the handler
                # NOTE(review): js.subscribe is called without manual_ack, and
                # nats-py appears to auto-ack once the callback returns in
                # that mode - confirm this matches the intent above.

        try:
            sub = await self._push_subscribe(stream, subject, durable, message_wrapper)
        except Exception as e:
            # Anything other than the "already bound" case is fatal
            if "consumer is already bound" not in str(e).lower():
                logger.error(
                    "Failed to subscribe to stream=%s, subject=%s: %s",
                    stream, subject, e
                )
                raise

            logger.warning(
                "Consumer '%s' for stream '%s' is already bound. "
                "This usually happens when the service didn't shut down cleanly. "
                "Attempting to delete the consumer and retry...",
                durable, stream
            )
            if not durable:
                raise

            sub = await self._recover_bound_consumer(
                stream, subject, durable, message_wrapper
            )
            self._subscriptions.append(sub)
            logger.info(
                "Successfully subscribed after consumer cleanup: stream=%s, subject=%s, durable=%s",
                stream, subject, durable
            )
            return

        self._subscriptions.append(sub)
        logger.info(
            "Subscribed to stream=%s, subject=%s, durable=%s",
            stream, subject, durable or "ephemeral"
        )

    @abstractmethod
    async def handle_message(self, msg: Msg, stream: str, subject: str):
        """
        Handle an incoming message pushed by NATS JetStream. Override this method in subclasses.

        Used for every subscription created without an explicit callback.

        Default implementation logs and acks the message.
        Subclasses should implement their own message processing logic.

        Args:
            msg: NATS message object
            stream: Stream name that was passed to subscribe()
            subject: Subject pattern that was passed to subscribe()
        """
        logger.debug(
            "Received message from stream=%s, subject=%s, data_size=%d",
            stream, subject, len(msg.data)
        )
        # Default: ack the message
        await msg.ack()

    async def _error_callback(self, error: Exception):
        """Callback for NATS errors."""
        if error:
            logger.error("NATS error: %s", error, exc_info=True)
        else:
            logger.error("NATS error: Unknown error (error object is None)")

    async def _disconnected_callback(self):
        """Callback when disconnected from NATS."""
        logger.warning("Disconnected from NATS servers")
        self._is_connected = False

    async def _reconnected_callback(self):
        """
        Callback when reconnected to NATS.

        Refreshes the JetStream context, forgets the tracked subscriptions and
        delegates to :meth:`_resubscribe_all`.
        """
        logger.info("Reconnected to NATS servers")
        self._is_connected = True
        if self.nc:
            self.js = self.nc.jetstream()
        # Clear old subscriptions as they're no longer valid after reconnection
        # NOTE(review): nats-py is believed to replay subscriptions itself on
        # reconnect - confirm whether dropping them here is really wanted.
        self._subscriptions.clear()
        # Re-subscribe to all streams
        await self._resubscribe_all()

    async def _closed_callback(self):
        """Callback when connection is closed."""
        logger.info("NATS connection closed")
        self._is_connected = False

    async def _resubscribe_all(self):
        """
        Re-subscribe to all streams after reconnection.

        Override this method in subclasses to restore subscriptions.
        The default implementation does nothing - subclasses should track
        their subscriptions and re-subscribe here.
        """
        logger.debug("Reconnection detected, but no subscriptions to restore")

    async def _monitor_connection(self, health_check_interval: float):
        """Poll the connection flag until ``_should_run`` is cleared."""
        while self._should_run:
            await asyncio.sleep(health_check_interval)

            if self._is_connected:
                continue

            if self._has_live_client():
                # The client handles its own reconnects; only re-sync the flag
                # here and leave re-subscribing to _reconnected_callback.
                await self.connect()
                continue

            logger.warning("Connection lost, attempting to reconnect...")
            if await self.connect():
                # Clear old subscriptions as they're no longer valid after reconnection
                self._subscriptions.clear()
                await self._resubscribe_all()

    async def run(self, health_check_interval: float = 1.0):
        """
        Run the subscriber service with connection health monitoring.

        This method will:
        1. Connect to NATS (retrying every 5 seconds until it succeeds)
        2. Call on_start() hook for subclasses to set up subscriptions
        3. Monitor connection health and reconnect if needed
        4. Call on_stop() hook on shutdown
        5. Always disconnect, even if on_start() or on_stop() raises

        Exceptions raised inside the monitoring loop are logged and end the
        loop; exceptions from on_start() / on_stop() propagate to the caller
        after the connection has been closed.

        Args:
            health_check_interval: Interval in seconds to check connection health
        """
        # Connect to NATS with retry logic
        while self._should_run:
            if await self.connect():
                break
            logger.warning("Failed to connect to NATS, retrying in 5 seconds...")
            await asyncio.sleep(5)

        try:
            # Call on_start hook for subclasses to set up subscriptions
            await self.on_start()

            logger.info("Stream subscriber service started")

            # Main loop with health monitoring
            try:
                await self._monitor_connection(health_check_interval)
            except KeyboardInterrupt:
                logger.info("Received KeyboardInterrupt, shutting down...")
            except Exception as e:
                logger.error("Unexpected error in main loop: %s", e, exc_info=True)
            finally:
                await self.on_stop()
        finally:
            # Outer finally: the connection must not outlive a failing hook
            await self.disconnect()

    @abstractmethod
    async def on_start(self):
        """
        Hook called when the service starts. Override in subclasses to set up subscriptions.

        Example:
            ```python
            async def on_start(self):
                await self.subscribe("STREAM_NAME", "puda.*.cmd.response.queue", "my_consumer")
                await self.subscribe("STREAM_NAME", "puda.*.cmd.response.immediate", "my_consumer2")
            ```
        """
        pass

    @abstractmethod
    async def on_stop(self):
        """
        Hook called when the service stops. Override in subclasses for cleanup.
        """
        pass

    # ==================== Context Manager ====================

    async def __aenter__(self):
        """
        Async context manager entry.

        The result of connect() is not checked here.
        NOTE(review): a failed connect() only returns False, so the context
        may be entered without a connection - confirm that this is intended.
        """
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()
        return False  # Don't suppress exceptions
388
+
@@ -1,5 +0,0 @@
1
- from .machine_client import MachineClient
2
- from .execution_state import ExecutionState
3
- from .command_service import CommandService
4
-
5
- __all__ = ["MachineClient", "ExecutionState", "CommandService"]