flo-python 0.1.0.dev2__py3-none-any.whl → 0.1.0.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
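In short, 0.1.0.dev4 replaces the single Worker class with ActionWorker (action execution, named outcomes via ActionResult, a heartbeat loop, and reconnect/retry handling) and adds a new StreamWorker that processes stream records through consumer groups; both gain async context-manager support. The sketch below is assembled from the updated docstrings in this diff and is only illustrative; it assumes a Flo server at localhost:3000 and that FloClient.new_action_worker() and ActionContext.result() behave as those docstrings describe.

    import asyncio

    from flo import FloClient, ActionContext

    async def process_order(ctx: ActionContext):
        payload = ctx.input()                         # raw task input as bytes
        # ... do the work ...
        return ctx.result("approved", {"ok": True})   # named outcome, new in this release

    async def main():
        async with FloClient("localhost:3000", namespace="myapp") as client:
            worker = client.new_action_worker(concurrency=5)
            worker.register_action("process-order", process_order)
            async with worker:                        # context-manager support, new in this release
                await worker.start()

    asyncio.run(main())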
flo/worker.py CHANGED
@@ -1,8 +1,10 @@
 """Flo High-Level Worker API

-Provides an easy-to-use Worker class for executing actions.
+Provides ActionWorker for executing actions and StreamWorker for
+processing stream records via consumer groups.
+
+Example::

-Example:
     from flo import FloClient, ActionContext

     async def process_order(ctx: ActionContext) -> bytes:
@@ -12,12 +14,14 @@ Example:

     async def main():
         async with FloClient("localhost:3000", namespace="myapp") as client:
-            worker = client.new_worker(concurrency=5)
-            worker.action("process-order")(process_order)
-            await worker.start()
+            worker = client.new_action_worker(concurrency=5)
+            worker.register_action("process-order", process_order)
+            async with worker:
+                await worker.start()
 """

 import asyncio
+import contextlib
 import json
 import logging
 import secrets
@@ -27,18 +31,46 @@ from dataclasses import dataclass, field
 from typing import Any

 from .client import FloClient
-from .types import ActionType, TaskAssignment, WorkerAwaitOptions, WorkerTouchOptions
+from .exceptions import NonRetryableError, is_connection_error
+from .types import (
+    ActionType,
+    StreamGroupAckOptions,
+    StreamGroupNackOptions,
+    StreamGroupReadOptions,
+    StreamID,
+    StreamRecord,
+    TaskAssignment,
+    WorkerAwaitOptions,
+    WorkerTouchOptions,
+)

 logger = logging.getLogger("flo.worker")


-# Type alias for action handlers
-ActionHandler = Callable[["ActionContext"], Awaitable[bytes]]
+class ActionResult:
+    """Represents the result of an action with a named outcome.
+
+    Use ``ActionContext.result()`` to create instances.
+
+    Attributes:
+        outcome: Named outcome (e.g. "approved", "rejected").
+        data: Result data as bytes.
+    """
+
+    __slots__ = ("outcome", "data")
+
+    def __init__(self, outcome: str, data: bytes):
+        self.outcome = outcome
+        self.data = data
+
+
+# Type alias for action handlers — can return bytes, dict, or ActionResult
+ActionHandler = Callable[["ActionContext"], Awaitable[bytes | dict[str, Any] | ActionResult]]


 @dataclass
-class WorkerOptions:
-    """Configuration for a Flo worker.
+class ActionWorkerOptions:
+    """Configuration for a Flo action worker.

     Endpoint and namespace are inherited from the parent FloClient.
     """
@@ -63,7 +95,7 @@ class ActionContext:
     attempt: int
     created_at: int
     namespace: str
-    _worker: "Worker" = field(repr=False)
+    _worker: "ActionWorker" = field(repr=False)
     _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

     def input(self) -> bytes:
@@ -103,7 +135,7 @@ class ActionContext:
         Args:
             extend_ms: How long to extend the lease in milliseconds.
         """
-        await self._worker._touch_task(self.task_id, extend_ms)
+        await self._worker._touch_task(self.action_name, self.task_id, extend_ms)

     @property
     def cancelled(self) -> bool:
@@ -115,15 +147,33 @@ class ActionContext:
         if self._cancel_event.is_set():
             raise asyncio.CancelledError("Task was cancelled")

+    def result(self, outcome: str, value: Any = None) -> ActionResult:
+        """Create an ActionResult with a named outcome.
+
+        Args:
+            outcome: Named outcome (e.g. "approved", "rejected").
+            value: Result value — dict/list is JSON-encoded, bytes passed through.
+
+        Returns:
+            ActionResult to return from the handler.
+        """
+        if isinstance(value, bytes):
+            data = value
+        elif value is not None:
+            data = json.dumps(value).encode("utf-8")
+        else:
+            data = b""
+        return ActionResult(outcome=outcome, data=data)

-class Worker:
+
+class ActionWorker:
     """High-level Flo worker for executing actions.

-    Created from a FloClient via ``client.new_worker()``.
+    Created from a FloClient via ``client.new_action_worker()``.

     Example:
         async with FloClient("localhost:3000", namespace="myapp") as client:
-            worker = client.new_worker(concurrency=5)
+            worker = client.new_action_worker(concurrency=5)

             @worker.action("process-order")
             async def process_order(ctx: ActionContext) -> bytes:
@@ -152,7 +202,7 @@ class Worker:
             block_ms: Timeout for blocking dequeue in milliseconds.
         """
         self._parent_client = parent_client
-        self.config = WorkerOptions(
+        self.config = ActionWorkerOptions(
             worker_id=worker_id or self._generate_worker_id(),
             concurrency=concurrency,
             action_timeout=action_timeout,
@@ -160,11 +210,13 @@ class Worker:
         )

         self._client: FloClient | None = None
+        self._result_client: FloClient | None = None
         self._handlers: dict[str, ActionHandler] = {}
         self._running = False
         self._stop_event = asyncio.Event()
         self._tasks: set[asyncio.Task[None]] = set()
         self._semaphore: asyncio.Semaphore | None = None
+        self._heartbeat_task: asyncio.Task[None] | None = None

     @staticmethod
     def _generate_worker_id() -> str:
@@ -228,14 +280,32 @@ class Worker:
             f"namespace={self._parent_client.namespace}, concurrency={self.config.concurrency})"
         )

-        # Create a dedicated connection using the parent client's endpoint and namespace
+        # Create a dedicated connection using the parent client's endpoint and namespace.
+        # Timeout must accommodate block_ms + action_timeout so blocking reads
+        # (ACTION_AWAIT with block_ms) don't get killed by socket-level timeout.
+        worker_timeout_ms = max(
+            self.config.block_ms + 5000,
+            int(self.config.action_timeout * 1000),
+        )
         self._client = FloClient(
             self._parent_client._endpoint,
             namespace=self._parent_client.namespace,
             debug=self._parent_client._debug,
+            timeout_ms=worker_timeout_ms,
         )
         await self._client.connect()

+        # Create a second connection for sending Complete/Fail results.
+        # The polling connection holds its lock during blocking Await calls
+        # (up to block_ms), so a separate connection prevents contention.
+        self._result_client = FloClient(
+            self._parent_client._endpoint,
+            namespace=self._parent_client.namespace,
+            debug=self._parent_client._debug,
+            timeout_ms=worker_timeout_ms,
+        )
+        await self._result_client.connect()
+
         try:
             # Register actions with the server
             action_names = list(self._handlers.keys())
@@ -244,9 +314,15 @@ class Worker:
                 logger.debug(f"Registered action with server: {action_name}")

             # Register worker
+            from .types import WorkerRegisterOptions
+
             await self._client.worker.register(
                 self.config.worker_id,
                 action_names,
+                WorkerRegisterOptions(
+                    concurrency=self.config.concurrency,
+                    machine_id=socket.gethostname(),
+                ),
             )
             logger.info(f"Worker registered with {len(action_names)} actions")

@@ -255,20 +331,52 @@ class Worker:
             self._running = True
             self._stop_event.clear()

+            # Start heartbeat loop
+            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+
             # Main polling loop
             await self._poll_loop(action_names)

         finally:
+            # Cancel heartbeat
+            if self._heartbeat_task is not None:
+                self._heartbeat_task.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await self._heartbeat_task
+                self._heartbeat_task = None
+
             # Wait for running tasks
             if self._tasks:
                 logger.info(f"Waiting for {len(self._tasks)} tasks to complete...")
                 await asyncio.gather(*self._tasks, return_exceptions=True)

+            if self._result_client:
+                await self._result_client.close()
+                self._result_client = None
             await self._client.close()
             self._client = None
             self._running = False
             logger.info("Worker stopped")

+    async def _heartbeat_loop(self) -> None:
+        """Send periodic heartbeats to keep the worker registration alive."""
+        assert self._client is not None
+        while self._running and not self._stop_event.is_set():
+            try:
+                await asyncio.sleep(30)
+                if not self._running:
+                    break
+                current_load = len(self._tasks)
+                await self._client.worker.heartbeat(
+                    self.config.worker_id,
+                    current_load=current_load,
+                )
+                logger.debug(f"Heartbeat sent (load={current_load})")
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.warning(f"Heartbeat failed: {e}")
+
     async def _poll_loop(self, action_names: list[str]) -> None:
         """Main polling loop for tasks."""
         assert self._client is not None
@@ -304,13 +412,61 @@ class Worker:
                 break
             except Exception as e:
                 self._semaphore.release()
-                logger.error(f"Await error: {e}, retrying...")
-                await asyncio.sleep(1)
+                if is_connection_error(e):
+                    logger.warning("Connection lost, reconnecting...")
+                    try:
+                        await self._client.reconnect()
+                        # Also reconnect result client
+                        if self._result_client is not None:
+                            try:
+                                await self._result_client.reconnect()
+                            except Exception as rc_err:
+                                logger.warning(f"Failed to reconnect result client: {rc_err}")
+                        # Re-register worker after reconnect
+                        try:
+                            from .types import WorkerRegisterOptions
+
+                            await self._client.worker.register(
+                                self.config.worker_id,
+                                action_names,
+                                WorkerRegisterOptions(
+                                    concurrency=self.config.concurrency,
+                                    machine_id=socket.gethostname(),
+                                ),
+                            )
+                        except Exception as reg_err:
+                            logger.warning(f"Failed to re-register worker: {reg_err}")
+                        logger.info("Reconnected, resuming work")
+                    except Exception as recon_err:
+                        logger.error(f"Reconnect failed: {recon_err}, retrying...")
+                        await asyncio.sleep(1)
+                else:
+                    logger.error(f"Await error: {e}, retrying...")
+                    await asyncio.sleep(1)
+
+    async def _send_with_retry(self, op: str, fn: Callable[[], Awaitable[None]]) -> None:
+        """Attempt to send a result (Complete/Fail), reconnecting on connection error."""
+        max_attempts = 3
+        for attempt in range(1, max_attempts + 1):
+            try:
+                await fn()
+                return
+            except Exception as e:
+                if not is_connection_error(e) or not self._running:
+                    raise
+                logger.warning(
+                    f"Connection lost while sending {op} result "
+                    f"(attempt {attempt}/{max_attempts}), reconnecting..."
+                )
+                if self._result_client is not None:
+                    await self._result_client.reconnect()
+        raise RuntimeError(f"Failed to send {op} result after {max_attempts} attempts")

     async def _execute_task(self, task: TaskAssignment) -> None:
         """Execute a task with error handling."""
         assert self._client is not None
         assert self._semaphore is not None
+        rc = self._result_client or self._client
         try:
             logger.info(
                 f"Executing action: {task.task_type} (task={task.task_id}, attempt={task.attempt})"
@@ -320,10 +476,14 @@ class Worker:
             handler = self._handlers.get(task.task_type)
             if handler is None:
                 logger.error(f"No handler registered for action: {task.task_type}")
-                await self._client.worker.fail(
-                    self.config.worker_id,
-                    task.task_id,
-                    f"No handler for: {task.task_type}",
+                await self._send_with_retry(
+                    "fail",
+                    lambda: rc.worker.fail(
+                        self.config.worker_id,
+                        task.task_type,
+                        task.task_id,
+                        f"No handler for: {task.task_type}",
+                    ),
                 )
                 return

@@ -345,36 +505,85 @@ class Worker:
                     timeout=self.config.action_timeout,
                 )

-                # Success - complete the task
-                await self._client.worker.complete(
-                    self.config.worker_id,
-                    task.task_id,
-                    result,
-                )
+                # 3-way dispatch based on result type
+                from .types import WorkerCompleteOptions
+
+                if isinstance(result, ActionResult):
+                    # Named outcome
+                    await self._send_with_retry(
+                        "complete",
+                        lambda: rc.worker.complete(
+                            self.config.worker_id,
+                            task.task_type,
+                            task.task_id,
+                            result.data,
+                            WorkerCompleteOptions(outcome=result.outcome),
+                        ),
+                    )
+                elif isinstance(result, dict):
+                    # Plain dict → JSON serialize
+                    result_bytes = json.dumps(result).encode("utf-8")
+                    await self._send_with_retry(
+                        "complete",
+                        lambda: rc.worker.complete(
+                            self.config.worker_id,
+                            task.task_type,
+                            task.task_id,
+                            result_bytes,
+                        ),
+                    )
+                else:
+                    # bytes or other → pass through
+                    await self._send_with_retry(
+                        "complete",
+                        lambda: rc.worker.complete(
+                            self.config.worker_id,
+                            task.task_type,
+                            task.task_id,
+                            result if isinstance(result, bytes) else b"",
+                        ),
+                    )
                 logger.info(f"Action completed: {task.task_type}")

             except asyncio.TimeoutError:
                 logger.error(f"Action timed out: {task.task_type}")
-                await self._client.worker.fail(
-                    self.config.worker_id,
-                    task.task_id,
-                    "Action timed out",
+                await self._send_with_retry(
+                    "fail",
+                    lambda: rc.worker.fail(
+                        self.config.worker_id,
+                        task.task_type,
+                        task.task_id,
+                        "Action timed out",
+                    ),
                 )

             except asyncio.CancelledError:
                 logger.warning(f"Action cancelled: {task.task_type}")
-                await self._client.worker.fail(
-                    self.config.worker_id,
-                    task.task_id,
-                    "Action cancelled",
+                await self._send_with_retry(
+                    "fail",
+                    lambda: rc.worker.fail(
+                        self.config.worker_id,
+                        task.task_type,
+                        task.task_id,
+                        "Action cancelled",
+                    ),
                 )

-            except Exception as e:
-                logger.error(f"Action failed: {task.task_type} - {e}")
-                await self._client.worker.fail(
-                    self.config.worker_id,
-                    task.task_id,
-                    str(e),
+            except Exception as exc:
+                logger.error(f"Action failed: {task.task_type} - {exc}")
+                from .types import WorkerFailOptions
+
+                retry = not isinstance(exc, NonRetryableError)
+                err_msg = str(exc)
+                await self._send_with_retry(
+                    "fail",
+                    lambda: rc.worker.fail(
+                        self.config.worker_id,
+                        task.task_type,
+                        task.task_id,
+                        err_msg,
+                        WorkerFailOptions(retry=retry),
+                    ),
                 )

         except Exception as e:
@@ -384,12 +593,13 @@ class Worker:
             if self._semaphore is not None:
                 self._semaphore.release()

-    async def _touch_task(self, task_id: str, extend_ms: int) -> None:
+    async def _touch_task(self, action_name: str, task_id: str, extend_ms: int) -> None:
         """Extend lease on a task (internal method)."""
         if self._client is None:
             raise RuntimeError("Worker not connected")
         await self._client.worker.touch(
             self.config.worker_id,
+            action_name,
             task_id,
             WorkerTouchOptions(extend_ms=extend_ms),
         )
@@ -398,14 +608,400 @@ class Worker:
         """Signal the worker to stop.

         This sets a flag that will cause the polling loop to exit
-        after the current iteration completes.
+        after the current iteration completes. Also interrupts in-flight
+        connections to unblock any blocking Await call immediately.
         """
         logger.info("Stopping worker...")
         self._running = False
         self._stop_event.set()
+        # Interrupt connections to unblock any blocking Await
+        if self._client:
+            self._client.interrupt()
+        if self._result_client:
+            self._result_client.interrupt()

     async def close(self) -> None:
         """Stop and close the worker."""
         self.stop()
+        if self._result_client:
+            await self._result_client.close()
+        if self._client:
+            await self._client.close()
+
+    async def __aenter__(self) -> "ActionWorker":
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: object,
+    ) -> None:
+        """Async context manager exit — closes the worker."""
+        await self.close()
+
+
+# =============================================================================
+# Stream Worker
+# =============================================================================
+
+# Type alias for stream record handlers
+# Return normally to auto-ack, raise to auto-nack.
+StreamRecordHandler = Callable[["StreamContext"], Awaitable[None]]
+
+
+@dataclass
+class StreamWorkerOptions:
+    """Configuration for a Flo stream worker.
+
+    Endpoint and namespace are inherited from the parent FloClient.
+    """
+
+    stream: str
+    group: str = ""
+    consumer: str = ""
+    worker_id: str = ""
+    concurrency: int = 10
+    batch_size: int = 10
+    block_ms: int = 30000
+    message_timeout: float = 300.0  # 5 minutes
+
+
+@dataclass
+class StreamContext:
+    """Context passed to stream record handlers.
+
+    Provides access to record data and helper methods.
+    """
+
+    record: StreamRecord
+    namespace: str
+    stream: str
+    group: str
+    consumer: str
+
+    @property
+    def stream_id(self) -> StreamID:
+        """Get the record's StreamID."""
+        return self.record.id
+
+    @property
+    def payload(self) -> bytes:
+        """Get the raw record payload."""
+        return self.record.payload
+
+    def json(self) -> Any:
+        """Parse record payload as JSON."""
+        if not self.record.payload:
+            raise ValueError("No payload data")
+        return json.loads(self.record.payload.decode("utf-8"))
+
+    def into(self, cls: type) -> Any:
+        """Parse payload as JSON and instantiate the given class."""
+        data = self.json()
+        if isinstance(data, dict):
+            return cls(**data)
+        return cls(data)
+
+    @property
+    def headers(self) -> dict[str, str]:
+        """Get record headers."""
+        return self.record.headers if self.record.headers else {}
+
+
+class StreamWorker:
+    """High-level Flo worker for processing stream records via consumer groups.
+
+    Polls a consumer group with ``group_read()``, dispatches records to the
+    handler, and auto-acks on success or auto-nacks on error.
+
+    Created from a FloClient via ``client.new_stream_worker()``.
+
+    Example:
+        async with FloClient("localhost:3000", namespace="myapp") as client:
+            async def process_event(ctx: StreamContext) -> None:
+                event = ctx.json()
+                print(f"Got event: {event}")
+                # Return normally → auto-ack
+                # Raise → auto-nack
+
+            worker = client.new_stream_worker(
+                stream="events",
+                group="processors",
+                consumer="worker-1",
+                handler=process_event,
+                concurrency=5,
+            )
+            await worker.start()
+    """
+
+    def __init__(
+        self,
+        parent_client: "FloClient",
+        handler: StreamRecordHandler,
+        *,
+        stream: str,
+        group: str,
+        consumer: str | None = None,
+        worker_id: str | None = None,
+        concurrency: int = 10,
+        batch_size: int = 10,
+        block_ms: int = 30000,
+        message_timeout: float = 300.0,
+    ):
+        self._parent_client = parent_client
+        self._handler = handler
+        self.config = StreamWorkerOptions(
+            stream=stream,
+            group=group,
+            consumer=consumer or self._generate_consumer_id(),
+            worker_id=worker_id or self._generate_consumer_id(),
+            concurrency=concurrency,
+            batch_size=batch_size,
+            block_ms=block_ms,
+            message_timeout=message_timeout,
+        )
+
+        self._client: FloClient | None = None
+        self._running = False
+        self._stop_event = asyncio.Event()
+        self._tasks: set[asyncio.Task[None]] = set()
+        self._semaphore: asyncio.Semaphore | None = None
+
+    @staticmethod
+    def _generate_consumer_id() -> str:
+        """Generate a unique consumer ID."""
+        try:
+            hostname = socket.gethostname()
+        except Exception:
+            hostname = "unknown"
+        return f"{hostname}-{secrets.token_hex(4)}"
+
+    async def start(self) -> None:
+        """Start the stream worker and begin processing records.
+
+        Joins the consumer group, then polls for records. Blocks until
+        ``stop()`` is called.
+        """
+        logger.info(
+            f"Starting stream worker (stream={self.config.stream}, "
+            f"group={self.config.group}, consumer={self.config.consumer}, "
+            f"concurrency={self.config.concurrency})"
+        )
+
+        # Timeout must accommodate block_ms + message_timeout so blocking reads
+        # (group_read with block_ms) don't get killed by socket-level timeout.
+        worker_timeout_ms = max(
+            self.config.block_ms + 5000,
+            int(self.config.message_timeout * 1000),
+        )
+        self._client = FloClient(
+            self._parent_client._endpoint,
+            namespace=self._parent_client.namespace,
+            debug=self._parent_client._debug,
+            timeout_ms=worker_timeout_ms,
+        )
+        await self._client.connect()
+
+        try:
+            # Join consumer group
+            await self._client.stream.group_join(
+                self.config.stream,
+                self.config.group,
+                self.config.consumer,
+            )
+            logger.info(f"Joined consumer group {self.config.group} on stream {self.config.stream}")
+
+            self._semaphore = asyncio.Semaphore(self.config.concurrency)
+            self._running = True
+            self._stop_event.clear()
+
+            await self._poll_loop()
+
+        finally:
+            # Wait for in-flight tasks
+            if self._tasks:
+                logger.info(f"Waiting for {len(self._tasks)} tasks to complete...")
+                await asyncio.gather(*self._tasks, return_exceptions=True)
+
+            # Leave consumer group
+            if self._client:
+                try:
+                    await self._client.stream.group_leave(
+                        self.config.stream,
+                        self.config.group,
+                        self.config.consumer,
+                    )
+                    logger.info(f"Left consumer group {self.config.group}")
+                except Exception as e:
+                    logger.warning(f"Failed to leave consumer group: {e}")
+
+                await self._client.close()
+                self._client = None
+
+            self._running = False
+            logger.info("Stream worker stopped")
+
+    async def _poll_loop(self) -> None:
+        """Main polling loop for stream records."""
+        assert self._client is not None
+        assert self._semaphore is not None
+
+        while self._running and not self._stop_event.is_set():
+            try:
+                # Wait for a concurrency slot
+                await self._semaphore.acquire()
+
+                if self._stop_event.is_set():
+                    self._semaphore.release()
+                    break
+
+                result = await self._client.stream.group_read(
+                    self.config.stream,
+                    self.config.group,
+                    self.config.consumer,
+                    StreamGroupReadOptions(
+                        count=self.config.batch_size,
+                        block_ms=self.config.block_ms,
+                    ),
+                )
+
+                if not result.records:
+                    self._semaphore.release()
+                    continue
+
+                # Release semaphore before dispatching — each task will
+                # acquire its own slot via _process_record.
+                self._semaphore.release()
+
+                for record in result.records:
+                    await self._semaphore.acquire()
+                    if self._stop_event.is_set():
+                        self._semaphore.release()
+                        return
+                    task = asyncio.create_task(self._process_record(record))
+                    self._tasks.add(task)
+                    task.add_done_callback(self._tasks.discard)
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                self._semaphore.release()
+                if is_connection_error(e):
+                    logger.warning("Stream worker lost connection, reconnecting...")
+                    try:
+                        await self._handle_reconnect()
+                        logger.info("Stream worker reconnected, resuming")
+                    except Exception as recon_err:
+                        logger.error(f"Stream worker reconnect failed: {recon_err}, retrying...")
+                        await asyncio.sleep(1)
+                else:
+                    logger.error(f"Stream read error: {e}, retrying...")
+                    await asyncio.sleep(1)
+
+    async def _handle_reconnect(self) -> None:
+        """Reconnect and re-join the consumer group."""
+        assert self._client is not None
+        await self._client.reconnect()
+        await self._client.stream.group_join(
+            self.config.stream,
+            self.config.group,
+            self.config.consumer,
+        )
+
+    async def _ack_with_retry(
+        self, record_ids: list[StreamID], options: StreamGroupAckOptions
+    ) -> None:
+        """Ack with retry on connection error."""
+        max_attempts = 3
+        for attempt in range(1, max_attempts + 1):
+            try:
+                assert self._client is not None
+                await self._client.stream.group_ack(
+                    self.config.stream,
+                    self.config.group,
+                    record_ids,
+                    options,
+                )
+                return
+            except Exception as e:
+                if not is_connection_error(e) or not self._running:
+                    raise
+                logger.warning(
+                    "Connection lost while acking "
+                    f"(attempt {attempt}/{max_attempts}), reconnecting..."
+                )
+                await self._handle_reconnect()
+
+    async def _process_record(self, record: StreamRecord) -> None:
+        """Process a single record: call handler, then ack or nack."""
+        assert self._client is not None
+        assert self._semaphore is not None
+        try:
+            ctx = StreamContext(
+                record=record,
+                namespace=self._parent_client.namespace,
+                stream=self.config.stream,
+                group=self.config.group,
+                consumer=self.config.consumer,
+            )
+
+            try:
+                await asyncio.wait_for(
+                    self._handler(ctx),
+                    timeout=self.config.message_timeout,
+                )
+
+                # Success — ack with retry
+                await self._ack_with_retry(
+                    [record.id],
+                    StreamGroupAckOptions(consumer=self.config.consumer),
+                )
+
+            except Exception as e:
+                logger.error(
+                    f"Record processing failed (stream={self.config.stream}, id={record.id}): {e}"
+                )
+                try:
+                    await self._client.stream.group_nack(
+                        self.config.stream,
+                        self.config.group,
+                        [record.id],
+                        StreamGroupNackOptions(consumer=self.config.consumer),
+                    )
+                except Exception as nack_err:
+                    logger.error(f"Failed to nack record: {nack_err}")
+
+        except Exception as e:
+            logger.error(f"Failed to process record: {e}")
+
+        finally:
+            self._semaphore.release()
+
+    def stop(self) -> None:
+        """Signal the stream worker to stop."""
+        logger.info("Stopping stream worker...")
+        self._running = False
+        self._stop_event.set()
+        if self._client:
+            self._client.interrupt()
+
+    async def close(self) -> None:
+        """Stop and close the stream worker."""
+        self.stop()
         if self._client:
             await self._client.close()
+
+    async def __aenter__(self) -> "StreamWorker":
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: object,
+    ) -> None:
+        """Async context manager exit — closes the stream worker."""
+        await self.close()
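For the new StreamWorker, the usage pattern from its docstring looks like the following; again a minimal sketch, assuming FloClient.new_stream_worker() is exposed as that docstring describes and that a stream named "events" exists on the server.

    import asyncio

    from flo import FloClient

    async def process_event(ctx) -> None:             # ctx is the StreamContext defined in this diff
        event = ctx.json()                             # payload parsed as JSON
        print(f"Got event: {event}")
        # return normally to auto-ack; raise to auto-nack

    async def main():
        async with FloClient("localhost:3000", namespace="myapp") as client:
            worker = client.new_stream_worker(
                stream="events",
                group="processors",
                consumer="worker-1",
                handler=process_event,
                concurrency=5,
            )
            async with worker:                         # StreamWorker is also an async context manager now
                await worker.start()

    asyncio.run(main())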