flo-python 0.1.0.dev2__py3-none-any.whl → 0.1.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flo/__init__.py +80 -8
- flo/actions.py +44 -15
- flo/client.py +141 -18
- flo/exceptions.py +21 -0
- flo/kv.py +6 -6
- flo/processing.py +341 -0
- flo/streams.py +17 -16
- flo/types.py +440 -190
- flo/wire.py +107 -49
- flo/worker.py +641 -45
- flo/workflows.py +463 -0
- {flo_python-0.1.0.dev2.dist-info → flo_python-0.1.0.dev4.dist-info}/METADATA +29 -4
- flo_python-0.1.0.dev4.dist-info/RECORD +16 -0
- {flo_python-0.1.0.dev2.dist-info → flo_python-0.1.0.dev4.dist-info}/WHEEL +1 -1
- flo_python-0.1.0.dev2.dist-info/RECORD +0 -14
- {flo_python-0.1.0.dev2.dist-info → flo_python-0.1.0.dev4.dist-info}/licenses/LICENSE +0 -0
flo/worker.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
"""Flo High-Level Worker API
|
|
2
2
|
|
|
3
|
-
Provides
|
|
3
|
+
Provides ActionWorker for executing actions and StreamWorker for
|
|
4
|
+
processing stream records via consumer groups.
|
|
5
|
+
|
|
6
|
+
Example::
|
|
4
7
|
|
|
5
|
-
Example:
|
|
6
8
|
from flo import FloClient, ActionContext
|
|
7
9
|
|
|
8
10
|
async def process_order(ctx: ActionContext) -> bytes:
|
|
@@ -12,12 +14,14 @@ Example:
|
|
|
12
14
|
|
|
13
15
|
async def main():
|
|
14
16
|
async with FloClient("localhost:3000", namespace="myapp") as client:
|
|
15
|
-
worker = client.
|
|
16
|
-
worker.
|
|
17
|
-
|
|
17
|
+
worker = client.new_action_worker(concurrency=5)
|
|
18
|
+
worker.register_action("process-order", process_order)
|
|
19
|
+
async with worker:
|
|
20
|
+
await worker.start()
|
|
18
21
|
"""
|
|
19
22
|
|
|
20
23
|
import asyncio
|
|
24
|
+
import contextlib
|
|
21
25
|
import json
|
|
22
26
|
import logging
|
|
23
27
|
import secrets
|
|
@@ -27,18 +31,46 @@ from dataclasses import dataclass, field
|
|
|
27
31
|
from typing import Any
|
|
28
32
|
|
|
29
33
|
from .client import FloClient
|
|
30
|
-
from .
|
|
34
|
+
from .exceptions import NonRetryableError, is_connection_error
|
|
35
|
+
from .types import (
|
|
36
|
+
ActionType,
|
|
37
|
+
StreamGroupAckOptions,
|
|
38
|
+
StreamGroupNackOptions,
|
|
39
|
+
StreamGroupReadOptions,
|
|
40
|
+
StreamID,
|
|
41
|
+
StreamRecord,
|
|
42
|
+
TaskAssignment,
|
|
43
|
+
WorkerAwaitOptions,
|
|
44
|
+
WorkerTouchOptions,
|
|
45
|
+
)
|
|
31
46
|
|
|
32
47
|
logger = logging.getLogger("flo.worker")
|
|
33
48
|
|
|
34
49
|
|
|
35
|
-
|
|
36
|
-
|
|
50
|
+
class ActionResult:
|
|
51
|
+
"""Represents the result of an action with a named outcome.
|
|
52
|
+
|
|
53
|
+
Use ``ActionContext.result()`` to create instances.
|
|
54
|
+
|
|
55
|
+
Attributes:
|
|
56
|
+
outcome: Named outcome (e.g. "approved", "rejected").
|
|
57
|
+
data: Result data as bytes.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
__slots__ = ("outcome", "data")
|
|
61
|
+
|
|
62
|
+
def __init__(self, outcome: str, data: bytes):
|
|
63
|
+
self.outcome = outcome
|
|
64
|
+
self.data = data
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Type alias for action handlers — can return bytes, dict, or ActionResult
|
|
68
|
+
ActionHandler = Callable[["ActionContext"], Awaitable[bytes | dict[str, Any] | ActionResult]]
|
|
37
69
|
|
|
38
70
|
|
|
39
71
|
@dataclass
|
|
40
|
-
class
|
|
41
|
-
"""Configuration for a Flo worker.
|
|
72
|
+
class ActionWorkerOptions:
|
|
73
|
+
"""Configuration for a Flo action worker.
|
|
42
74
|
|
|
43
75
|
Endpoint and namespace are inherited from the parent FloClient.
|
|
44
76
|
"""
|
|
@@ -63,7 +95,7 @@ class ActionContext:
|
|
|
63
95
|
attempt: int
|
|
64
96
|
created_at: int
|
|
65
97
|
namespace: str
|
|
66
|
-
_worker: "
|
|
98
|
+
_worker: "ActionWorker" = field(repr=False)
|
|
67
99
|
_cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
|
68
100
|
|
|
69
101
|
def input(self) -> bytes:
|
|
@@ -103,7 +135,7 @@ class ActionContext:
|
|
|
103
135
|
Args:
|
|
104
136
|
extend_ms: How long to extend the lease in milliseconds.
|
|
105
137
|
"""
|
|
106
|
-
await self._worker._touch_task(self.task_id, extend_ms)
|
|
138
|
+
await self._worker._touch_task(self.action_name, self.task_id, extend_ms)
|
|
107
139
|
|
|
108
140
|
@property
|
|
109
141
|
def cancelled(self) -> bool:
|
|
@@ -115,15 +147,33 @@ class ActionContext:
|
|
|
115
147
|
if self._cancel_event.is_set():
|
|
116
148
|
raise asyncio.CancelledError("Task was cancelled")
|
|
117
149
|
|
|
150
|
+
def result(self, outcome: str, value: Any = None) -> ActionResult:
    """Create an ActionResult with a named outcome.

    Args:
        outcome: Named outcome (e.g. "approved", "rejected").
        value: Result value.
            - ``None`` yields an empty payload.
            - ``bytes`` is passed through unchanged.
            - ``bytearray`` / ``memoryview`` are copied into ``bytes``.
            - Anything else (dict, list, str, numbers, ...) is JSON-encoded
              as UTF-8.

    Returns:
        ActionResult to return from the handler.

    Raises:
        TypeError: If ``value`` is not bytes-like and not JSON-serializable.
    """
    if value is None:
        data = b""
    elif isinstance(value, bytes):
        data = value
    elif isinstance(value, (bytearray, memoryview)):
        # These are raw binary payloads too, but json.dumps cannot serialize
        # them, so they previously raised TypeError instead of passing through.
        data = bytes(value)
    else:
        data = json.dumps(value).encode("utf-8")
    return ActionResult(outcome=outcome, data=data)
|
|
118
167
|
|
|
119
|
-
|
|
168
|
+
|
|
169
|
+
class ActionWorker:
|
|
120
170
|
"""High-level Flo worker for executing actions.
|
|
121
171
|
|
|
122
|
-
Created from a FloClient via ``client.
|
|
172
|
+
Created from a FloClient via ``client.new_action_worker()``.
|
|
123
173
|
|
|
124
174
|
Example:
|
|
125
175
|
async with FloClient("localhost:3000", namespace="myapp") as client:
|
|
126
|
-
worker = client.
|
|
176
|
+
worker = client.new_action_worker(concurrency=5)
|
|
127
177
|
|
|
128
178
|
@worker.action("process-order")
|
|
129
179
|
async def process_order(ctx: ActionContext) -> bytes:
|
|
@@ -152,7 +202,7 @@ class Worker:
|
|
|
152
202
|
block_ms: Timeout for blocking dequeue in milliseconds.
|
|
153
203
|
"""
|
|
154
204
|
self._parent_client = parent_client
|
|
155
|
-
self.config =
|
|
205
|
+
self.config = ActionWorkerOptions(
|
|
156
206
|
worker_id=worker_id or self._generate_worker_id(),
|
|
157
207
|
concurrency=concurrency,
|
|
158
208
|
action_timeout=action_timeout,
|
|
@@ -160,11 +210,13 @@ class Worker:
|
|
|
160
210
|
)
|
|
161
211
|
|
|
162
212
|
self._client: FloClient | None = None
|
|
213
|
+
self._result_client: FloClient | None = None
|
|
163
214
|
self._handlers: dict[str, ActionHandler] = {}
|
|
164
215
|
self._running = False
|
|
165
216
|
self._stop_event = asyncio.Event()
|
|
166
217
|
self._tasks: set[asyncio.Task[None]] = set()
|
|
167
218
|
self._semaphore: asyncio.Semaphore | None = None
|
|
219
|
+
self._heartbeat_task: asyncio.Task[None] | None = None
|
|
168
220
|
|
|
169
221
|
@staticmethod
|
|
170
222
|
def _generate_worker_id() -> str:
|
|
@@ -228,14 +280,32 @@ class Worker:
|
|
|
228
280
|
f"namespace={self._parent_client.namespace}, concurrency={self.config.concurrency})"
|
|
229
281
|
)
|
|
230
282
|
|
|
231
|
-
# Create a dedicated connection using the parent client's endpoint and namespace
|
|
283
|
+
# Create a dedicated connection using the parent client's endpoint and namespace.
|
|
284
|
+
# Timeout must accommodate block_ms + action_timeout so blocking reads
|
|
285
|
+
# (ACTION_AWAIT with block_ms) don't get killed by socket-level timeout.
|
|
286
|
+
worker_timeout_ms = max(
|
|
287
|
+
self.config.block_ms + 5000,
|
|
288
|
+
int(self.config.action_timeout * 1000),
|
|
289
|
+
)
|
|
232
290
|
self._client = FloClient(
|
|
233
291
|
self._parent_client._endpoint,
|
|
234
292
|
namespace=self._parent_client.namespace,
|
|
235
293
|
debug=self._parent_client._debug,
|
|
294
|
+
timeout_ms=worker_timeout_ms,
|
|
236
295
|
)
|
|
237
296
|
await self._client.connect()
|
|
238
297
|
|
|
298
|
+
# Create a second connection for sending Complete/Fail results.
|
|
299
|
+
# The polling connection holds its lock during blocking Await calls
|
|
300
|
+
# (up to block_ms), so a separate connection prevents contention.
|
|
301
|
+
self._result_client = FloClient(
|
|
302
|
+
self._parent_client._endpoint,
|
|
303
|
+
namespace=self._parent_client.namespace,
|
|
304
|
+
debug=self._parent_client._debug,
|
|
305
|
+
timeout_ms=worker_timeout_ms,
|
|
306
|
+
)
|
|
307
|
+
await self._result_client.connect()
|
|
308
|
+
|
|
239
309
|
try:
|
|
240
310
|
# Register actions with the server
|
|
241
311
|
action_names = list(self._handlers.keys())
|
|
@@ -244,9 +314,15 @@ class Worker:
|
|
|
244
314
|
logger.debug(f"Registered action with server: {action_name}")
|
|
245
315
|
|
|
246
316
|
# Register worker
|
|
317
|
+
from .types import WorkerRegisterOptions
|
|
318
|
+
|
|
247
319
|
await self._client.worker.register(
|
|
248
320
|
self.config.worker_id,
|
|
249
321
|
action_names,
|
|
322
|
+
WorkerRegisterOptions(
|
|
323
|
+
concurrency=self.config.concurrency,
|
|
324
|
+
machine_id=socket.gethostname(),
|
|
325
|
+
),
|
|
250
326
|
)
|
|
251
327
|
logger.info(f"Worker registered with {len(action_names)} actions")
|
|
252
328
|
|
|
@@ -255,20 +331,52 @@ class Worker:
|
|
|
255
331
|
self._running = True
|
|
256
332
|
self._stop_event.clear()
|
|
257
333
|
|
|
334
|
+
# Start heartbeat loop
|
|
335
|
+
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
|
336
|
+
|
|
258
337
|
# Main polling loop
|
|
259
338
|
await self._poll_loop(action_names)
|
|
260
339
|
|
|
261
340
|
finally:
|
|
341
|
+
# Cancel heartbeat
|
|
342
|
+
if self._heartbeat_task is not None:
|
|
343
|
+
self._heartbeat_task.cancel()
|
|
344
|
+
with contextlib.suppress(asyncio.CancelledError):
|
|
345
|
+
await self._heartbeat_task
|
|
346
|
+
self._heartbeat_task = None
|
|
347
|
+
|
|
262
348
|
# Wait for running tasks
|
|
263
349
|
if self._tasks:
|
|
264
350
|
logger.info(f"Waiting for {len(self._tasks)} tasks to complete...")
|
|
265
351
|
await asyncio.gather(*self._tasks, return_exceptions=True)
|
|
266
352
|
|
|
353
|
+
if self._result_client:
|
|
354
|
+
await self._result_client.close()
|
|
355
|
+
self._result_client = None
|
|
267
356
|
await self._client.close()
|
|
268
357
|
self._client = None
|
|
269
358
|
self._running = False
|
|
270
359
|
logger.info("Worker stopped")
|
|
271
360
|
|
|
361
|
+
async def _heartbeat_loop(self, interval: float = 30.0) -> None:
    """Send periodic heartbeats to keep the worker registration alive.

    Runs until the worker stops or this task is cancelled (``start()``
    cancels it during shutdown). A failed heartbeat is logged and retried
    on the next tick.

    Args:
        interval: Seconds between heartbeats.
    """
    assert self._client is not None
    while self._running and not self._stop_event.is_set():
        try:
            # Sleep for one interval, but wake immediately if stop() is
            # signalled instead of waiting out the full period.
            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(self._stop_event.wait(), timeout=interval)
            if not self._running or self._stop_event.is_set():
                break
            current_load = len(self._tasks)
            # The polling connection holds its lock during blocking Await
            # calls (up to block_ms). A heartbeat sent on it can be delayed
            # by that long, so prefer the dedicated result connection.
            client = self._result_client or self._client
            await client.worker.heartbeat(
                self.config.worker_id,
                current_load=current_load,
            )
            logger.debug(f"Heartbeat sent (load={current_load})")
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.warning(f"Heartbeat failed: {e}")
|
|
379
|
+
|
|
272
380
|
async def _poll_loop(self, action_names: list[str]) -> None:
|
|
273
381
|
"""Main polling loop for tasks."""
|
|
274
382
|
assert self._client is not None
|
|
@@ -304,13 +412,61 @@ class Worker:
|
|
|
304
412
|
break
|
|
305
413
|
except Exception as e:
|
|
306
414
|
self._semaphore.release()
|
|
307
|
-
|
|
308
|
-
|
|
415
|
+
if is_connection_error(e):
|
|
416
|
+
logger.warning("Connection lost, reconnecting...")
|
|
417
|
+
try:
|
|
418
|
+
await self._client.reconnect()
|
|
419
|
+
# Also reconnect result client
|
|
420
|
+
if self._result_client is not None:
|
|
421
|
+
try:
|
|
422
|
+
await self._result_client.reconnect()
|
|
423
|
+
except Exception as rc_err:
|
|
424
|
+
logger.warning(f"Failed to reconnect result client: {rc_err}")
|
|
425
|
+
# Re-register worker after reconnect
|
|
426
|
+
try:
|
|
427
|
+
from .types import WorkerRegisterOptions
|
|
428
|
+
|
|
429
|
+
await self._client.worker.register(
|
|
430
|
+
self.config.worker_id,
|
|
431
|
+
action_names,
|
|
432
|
+
WorkerRegisterOptions(
|
|
433
|
+
concurrency=self.config.concurrency,
|
|
434
|
+
machine_id=socket.gethostname(),
|
|
435
|
+
),
|
|
436
|
+
)
|
|
437
|
+
except Exception as reg_err:
|
|
438
|
+
logger.warning(f"Failed to re-register worker: {reg_err}")
|
|
439
|
+
logger.info("Reconnected, resuming work")
|
|
440
|
+
except Exception as recon_err:
|
|
441
|
+
logger.error(f"Reconnect failed: {recon_err}, retrying...")
|
|
442
|
+
await asyncio.sleep(1)
|
|
443
|
+
else:
|
|
444
|
+
logger.error(f"Await error: {e}, retrying...")
|
|
445
|
+
await asyncio.sleep(1)
|
|
446
|
+
|
|
447
|
+
async def _send_with_retry(
    self, op: str, fn: Callable[[], Awaitable[None]], max_attempts: int = 3
) -> None:
    """Send a result (Complete/Fail), reconnecting the result client on connection errors.

    Args:
        op: Operation name used in log and error messages (e.g. "complete").
        fn: Zero-argument coroutine factory that performs the send. It is
            called once per attempt.
        max_attempts: Total number of send attempts before giving up.

    Raises:
        Exception: A non-connection error from ``fn`` is re-raised unchanged.
            So is any error raised while the worker is stopping.
        RuntimeError: Every attempt failed with a connection error. The last
            connection error is chained as ``__cause__``.
    """
    last_err: Exception | None = None
    for attempt in range(1, max_attempts + 1):
        try:
            await fn()
            return
        except Exception as e:
            if not is_connection_error(e) or not self._running:
                raise
            last_err = e
            if attempt == max_attempts:
                # No attempt follows, so reconnecting now would be wasted work.
                break
            logger.warning(
                f"Connection lost while sending {op} result "
                f"(attempt {attempt}/{max_attempts}), reconnecting..."
            )
            if self._result_client is not None:
                try:
                    await self._result_client.reconnect()
                except Exception as rc_err:
                    # Keep going: the next attempt fails fast with a connection
                    # error and triggers another reconnect, until attempts run out.
                    logger.warning(f"Failed to reconnect result client: {rc_err}")
    raise RuntimeError(f"Failed to send {op} result after {max_attempts} attempts") from last_err
|
|
309
464
|
|
|
310
465
|
async def _execute_task(self, task: TaskAssignment) -> None:
|
|
311
466
|
"""Execute a task with error handling."""
|
|
312
467
|
assert self._client is not None
|
|
313
468
|
assert self._semaphore is not None
|
|
469
|
+
rc = self._result_client or self._client
|
|
314
470
|
try:
|
|
315
471
|
logger.info(
|
|
316
472
|
f"Executing action: {task.task_type} (task={task.task_id}, attempt={task.attempt})"
|
|
@@ -320,10 +476,14 @@ class Worker:
|
|
|
320
476
|
handler = self._handlers.get(task.task_type)
|
|
321
477
|
if handler is None:
|
|
322
478
|
logger.error(f"No handler registered for action: {task.task_type}")
|
|
323
|
-
await self.
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
479
|
+
await self._send_with_retry(
|
|
480
|
+
"fail",
|
|
481
|
+
lambda: rc.worker.fail(
|
|
482
|
+
self.config.worker_id,
|
|
483
|
+
task.task_type,
|
|
484
|
+
task.task_id,
|
|
485
|
+
f"No handler for: {task.task_type}",
|
|
486
|
+
),
|
|
327
487
|
)
|
|
328
488
|
return
|
|
329
489
|
|
|
@@ -345,36 +505,85 @@ class Worker:
|
|
|
345
505
|
timeout=self.config.action_timeout,
|
|
346
506
|
)
|
|
347
507
|
|
|
348
|
-
#
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
508
|
+
# 3-way dispatch based on result type
|
|
509
|
+
from .types import WorkerCompleteOptions
|
|
510
|
+
|
|
511
|
+
if isinstance(result, ActionResult):
|
|
512
|
+
# Named outcome
|
|
513
|
+
await self._send_with_retry(
|
|
514
|
+
"complete",
|
|
515
|
+
lambda: rc.worker.complete(
|
|
516
|
+
self.config.worker_id,
|
|
517
|
+
task.task_type,
|
|
518
|
+
task.task_id,
|
|
519
|
+
result.data,
|
|
520
|
+
WorkerCompleteOptions(outcome=result.outcome),
|
|
521
|
+
),
|
|
522
|
+
)
|
|
523
|
+
elif isinstance(result, dict):
|
|
524
|
+
# Plain dict → JSON serialize
|
|
525
|
+
result_bytes = json.dumps(result).encode("utf-8")
|
|
526
|
+
await self._send_with_retry(
|
|
527
|
+
"complete",
|
|
528
|
+
lambda: rc.worker.complete(
|
|
529
|
+
self.config.worker_id,
|
|
530
|
+
task.task_type,
|
|
531
|
+
task.task_id,
|
|
532
|
+
result_bytes,
|
|
533
|
+
),
|
|
534
|
+
)
|
|
535
|
+
else:
|
|
536
|
+
# bytes or other → pass through
|
|
537
|
+
await self._send_with_retry(
|
|
538
|
+
"complete",
|
|
539
|
+
lambda: rc.worker.complete(
|
|
540
|
+
self.config.worker_id,
|
|
541
|
+
task.task_type,
|
|
542
|
+
task.task_id,
|
|
543
|
+
result if isinstance(result, bytes) else b"",
|
|
544
|
+
),
|
|
545
|
+
)
|
|
354
546
|
logger.info(f"Action completed: {task.task_type}")
|
|
355
547
|
|
|
356
548
|
except asyncio.TimeoutError:
|
|
357
549
|
logger.error(f"Action timed out: {task.task_type}")
|
|
358
|
-
await self.
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
550
|
+
await self._send_with_retry(
|
|
551
|
+
"fail",
|
|
552
|
+
lambda: rc.worker.fail(
|
|
553
|
+
self.config.worker_id,
|
|
554
|
+
task.task_type,
|
|
555
|
+
task.task_id,
|
|
556
|
+
"Action timed out",
|
|
557
|
+
),
|
|
362
558
|
)
|
|
363
559
|
|
|
364
560
|
except asyncio.CancelledError:
|
|
365
561
|
logger.warning(f"Action cancelled: {task.task_type}")
|
|
366
|
-
await self.
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
562
|
+
await self._send_with_retry(
|
|
563
|
+
"fail",
|
|
564
|
+
lambda: rc.worker.fail(
|
|
565
|
+
self.config.worker_id,
|
|
566
|
+
task.task_type,
|
|
567
|
+
task.task_id,
|
|
568
|
+
"Action cancelled",
|
|
569
|
+
),
|
|
370
570
|
)
|
|
371
571
|
|
|
372
|
-
except Exception as
|
|
373
|
-
logger.error(f"Action failed: {task.task_type} - {
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
572
|
+
except Exception as exc:
|
|
573
|
+
logger.error(f"Action failed: {task.task_type} - {exc}")
|
|
574
|
+
from .types import WorkerFailOptions
|
|
575
|
+
|
|
576
|
+
retry = not isinstance(exc, NonRetryableError)
|
|
577
|
+
err_msg = str(exc)
|
|
578
|
+
await self._send_with_retry(
|
|
579
|
+
"fail",
|
|
580
|
+
lambda: rc.worker.fail(
|
|
581
|
+
self.config.worker_id,
|
|
582
|
+
task.task_type,
|
|
583
|
+
task.task_id,
|
|
584
|
+
err_msg,
|
|
585
|
+
WorkerFailOptions(retry=retry),
|
|
586
|
+
),
|
|
378
587
|
)
|
|
379
588
|
|
|
380
589
|
except Exception as e:
|
|
@@ -384,12 +593,13 @@ class Worker:
|
|
|
384
593
|
if self._semaphore is not None:
|
|
385
594
|
self._semaphore.release()
|
|
386
595
|
|
|
387
|
-
async def _touch_task(self, task_id: str, extend_ms: int) -> None:
|
|
596
|
+
async def _touch_task(self, action_name: str, task_id: str, extend_ms: int) -> None:
    """Extend the lease on an in-flight task (internal; called by ActionContext).

    Args:
        action_name: Action the task belongs to.
        task_id: Task whose lease should be extended.
        extend_ms: How long to extend the lease, in milliseconds.

    Raises:
        RuntimeError: If the worker is not connected.
    """
    if self._client is None:
        raise RuntimeError("Worker not connected")
    # The polling connection holds its lock during blocking Await calls (up
    # to block_ms). A touch sent over it could stall long enough for the
    # lease to lapse, so use the dedicated result connection when it exists.
    client = self._result_client or self._client
    await client.worker.touch(
        self.config.worker_id,
        action_name,
        task_id,
        WorkerTouchOptions(extend_ms=extend_ms),
    )
|
|
@@ -398,14 +608,400 @@ class Worker:
|
|
|
398
608
|
"""Signal the worker to stop.
|
|
399
609
|
|
|
400
610
|
This sets a flag that will cause the polling loop to exit
|
|
401
|
-
after the current iteration completes.
|
|
611
|
+
after the current iteration completes. Also interrupts in-flight
|
|
612
|
+
connections to unblock any blocking Await call immediately.
|
|
402
613
|
"""
|
|
403
614
|
logger.info("Stopping worker...")
|
|
404
615
|
self._running = False
|
|
405
616
|
self._stop_event.set()
|
|
617
|
+
# Interrupt connections to unblock any blocking Await
|
|
618
|
+
if self._client:
|
|
619
|
+
self._client.interrupt()
|
|
620
|
+
if self._result_client:
|
|
621
|
+
self._result_client.interrupt()
|
|
406
622
|
|
|
407
623
|
async def close(self) -> None:
|
|
408
624
|
"""Stop and close the worker."""
|
|
409
625
|
self.stop()
|
|
626
|
+
if self._result_client:
|
|
627
|
+
await self._result_client.close()
|
|
628
|
+
if self._client:
|
|
629
|
+
await self._client.close()
|
|
630
|
+
|
|
631
|
+
async def __aenter__(self) -> "ActionWorker":
|
|
632
|
+
"""Async context manager entry."""
|
|
633
|
+
return self
|
|
634
|
+
|
|
635
|
+
async def __aexit__(
|
|
636
|
+
self,
|
|
637
|
+
exc_type: type[BaseException] | None,
|
|
638
|
+
exc_val: BaseException | None,
|
|
639
|
+
exc_tb: object,
|
|
640
|
+
) -> None:
|
|
641
|
+
"""Async context manager exit — closes the worker."""
|
|
642
|
+
await self.close()
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
# =============================================================================
# Stream Worker
# =============================================================================

# Signature of a stream record handler: it receives a StreamContext and returns
# nothing. StreamWorker acks the record when the handler returns normally and
# nacks it when the handler raises.
StreamRecordHandler = Callable[
    ["StreamContext"],
    Awaitable[None],
]
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
@dataclass
class StreamWorkerOptions:
    """Settings for a Flo stream worker.

    The server endpoint and namespace are not configured here; they are
    taken from the parent FloClient.
    """

    # Stream to consume from (the only required field).
    stream: str
    # Consumer group name.
    group: str = ""
    # Consumer name within the group.
    consumer: str = ""
    # Identifier for this worker.
    worker_id: str = ""
    # Maximum number of records handled at the same time.
    concurrency: int = 10
    # Maximum records requested per group_read call.
    batch_size: int = 10
    # How long a group_read may block waiting for records, in milliseconds.
    block_ms: int = 30000
    # Per-record handler timeout in seconds (five minutes).
    message_timeout: float = 5 * 60.0
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@dataclass
class StreamContext:
    """Per-record context handed to a stream record handler.

    Wraps the delivered StreamRecord together with where it came from
    (namespace, stream, group, consumer) and offers small decoding helpers.
    """

    record: StreamRecord
    namespace: str
    stream: str
    group: str
    consumer: str

    @property
    def stream_id(self) -> StreamID:
        """The StreamID of the wrapped record."""
        return self.record.id

    @property
    def payload(self) -> bytes:
        """The record payload, undecoded."""
        return self.record.payload

    def json(self) -> Any:
        """Decode the payload as UTF-8 JSON.

        Raises:
            ValueError: If the payload is empty, is not valid UTF-8, or is
                not valid JSON.
        """
        raw = self.record.payload
        if not raw:
            raise ValueError("No payload data")
        text = raw.decode("utf-8")
        return json.loads(text)

    def into(self, cls: type) -> Any:
        """Decode the payload as JSON and build ``cls`` from it.

        A JSON object is expanded into keyword arguments. Any other JSON value
        is passed as the single positional argument.
        """
        parsed = self.json()
        return cls(**parsed) if isinstance(parsed, dict) else cls(parsed)

    @property
    def headers(self) -> dict[str, str]:
        """Record headers, or an empty dict when the record has none."""
        return self.record.headers or {}
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
class StreamWorker:
    """High-level Flo worker for processing stream records via consumer groups.

    Polls a consumer group with ``group_read()``, dispatches records to the
    handler, and auto-acks on success or auto-nacks on error.

    Created from a FloClient via ``client.new_stream_worker()``.

    Example:
        async with FloClient("localhost:3000", namespace="myapp") as client:
            async def process_event(ctx: StreamContext) -> None:
                event = ctx.json()
                print(f"Got event: {event}")
                # Return normally → auto-ack
                # Raise → auto-nack

            worker = client.new_stream_worker(
                stream="events",
                group="processors",
                consumer="worker-1",
                handler=process_event,
                concurrency=5,
            )
            await worker.start()
    """

    def __init__(
        self,
        parent_client: "FloClient",
        handler: StreamRecordHandler,
        *,
        stream: str,
        group: str,
        consumer: str | None = None,
        worker_id: str | None = None,
        concurrency: int = 10,
        batch_size: int = 10,
        block_ms: int = 30000,
        message_timeout: float = 300.0,
    ):
        """Create a stream worker. No connection is made until ``start()``.

        Args:
            parent_client: Client whose endpoint, namespace and debug flag are
                reused for this worker's own dedicated connection.
            handler: Coroutine called once per record.
            stream: Stream to consume.
            group: Consumer group to join.
            consumer: Consumer name. When omitted, a ``<hostname>-<hex>`` ID
                is generated.
            worker_id: Worker ID, generated the same way when omitted.
            concurrency: Maximum records handled at once.
            batch_size: Maximum records requested per ``group_read``.
            block_ms: How long ``group_read`` may block, in milliseconds.
            message_timeout: Per-record handler timeout, in seconds.
        """
        self._parent_client = parent_client
        self._handler = handler
        self.config = StreamWorkerOptions(
            stream=stream,
            group=group,
            consumer=consumer or self._generate_consumer_id(),
            worker_id=worker_id or self._generate_consumer_id(),
            concurrency=concurrency,
            batch_size=batch_size,
            block_ms=block_ms,
            message_timeout=message_timeout,
        )

        self._client: FloClient | None = None
        self._running = False
        self._stop_event = asyncio.Event()
        self._tasks: set[asyncio.Task[None]] = set()
        self._semaphore: asyncio.Semaphore | None = None

    @staticmethod
    def _generate_consumer_id() -> str:
        """Generate a unique consumer ID."""
        try:
            hostname = socket.gethostname()
        except Exception:
            hostname = "unknown"
        return f"{hostname}-{secrets.token_hex(4)}"

    async def start(self) -> None:
        """Start the stream worker and begin processing records.

        Joins the consumer group, then polls for records. Blocks until
        ``stop()`` is called. On exit it waits for in-flight records, leaves
        the group (best effort) and closes the dedicated connection.
        """
        logger.info(
            f"Starting stream worker (stream={self.config.stream}, "
            f"group={self.config.group}, consumer={self.config.consumer}, "
            f"concurrency={self.config.concurrency})"
        )

        # Timeout must accommodate block_ms + message_timeout so blocking reads
        # (group_read with block_ms) don't get killed by socket-level timeout.
        worker_timeout_ms = max(
            self.config.block_ms + 5000,
            int(self.config.message_timeout * 1000),
        )
        self._client = FloClient(
            self._parent_client._endpoint,
            namespace=self._parent_client.namespace,
            debug=self._parent_client._debug,
            timeout_ms=worker_timeout_ms,
        )
        await self._client.connect()

        try:
            # Join consumer group
            await self._client.stream.group_join(
                self.config.stream,
                self.config.group,
                self.config.consumer,
            )
            logger.info(f"Joined consumer group {self.config.group} on stream {self.config.stream}")

            self._semaphore = asyncio.Semaphore(self.config.concurrency)
            self._running = True
            self._stop_event.clear()

            await self._poll_loop()

        finally:
            # Wait for in-flight tasks
            if self._tasks:
                logger.info(f"Waiting for {len(self._tasks)} tasks to complete...")
                await asyncio.gather(*self._tasks, return_exceptions=True)

            # Leave consumer group
            if self._client:
                try:
                    await self._client.stream.group_leave(
                        self.config.stream,
                        self.config.group,
                        self.config.consumer,
                    )
                    logger.info(f"Left consumer group {self.config.group}")
                except Exception as e:
                    logger.warning(f"Failed to leave consumer group: {e}")

                await self._client.close()
                self._client = None

            self._running = False
            logger.info("Stream worker stopped")

    async def _poll_loop(self) -> None:
        """Main polling loop for stream records.

        One semaphore slot is held while reading, so the loop does not read
        when all slots are busy. After a batch arrives, that slot is released
        and each record acquires its own slot before its task is spawned.
        ``_process_record`` releases the slot when the record is done.
        """
        assert self._client is not None
        assert self._semaphore is not None

        while self._running and not self._stop_event.is_set():
            try:
                # Wait for a concurrency slot
                await self._semaphore.acquire()

                if self._stop_event.is_set():
                    self._semaphore.release()
                    break

                result = await self._client.stream.group_read(
                    self.config.stream,
                    self.config.group,
                    self.config.consumer,
                    StreamGroupReadOptions(
                        count=self.config.batch_size,
                        block_ms=self.config.block_ms,
                    ),
                )

                if not result.records:
                    self._semaphore.release()
                    continue

                # Release semaphore before dispatching — each task will
                # acquire its own slot via _process_record.
                self._semaphore.release()

                for record in result.records:
                    await self._semaphore.acquire()
                    if self._stop_event.is_set():
                        self._semaphore.release()
                        return
                    task = asyncio.create_task(self._process_record(record))
                    self._tasks.add(task)
                    task.add_done_callback(self._tasks.discard)

            except asyncio.CancelledError:
                break
            except Exception as e:
                self._semaphore.release()
                if not self._running or self._stop_event.is_set():
                    # stop() interrupts the connection to unblock group_read,
                    # so an error here is expected during shutdown. Do not
                    # reconnect or re-join the group while stopping.
                    break
                if is_connection_error(e):
                    logger.warning("Stream worker lost connection, reconnecting...")
                    try:
                        await self._handle_reconnect()
                        logger.info("Stream worker reconnected, resuming")
                    except Exception as recon_err:
                        logger.error(f"Stream worker reconnect failed: {recon_err}, retrying...")
                        await asyncio.sleep(1)
                else:
                    logger.error(f"Stream read error: {e}, retrying...")
                    await asyncio.sleep(1)

    async def _handle_reconnect(self) -> None:
        """Reconnect and re-join the consumer group."""
        assert self._client is not None
        await self._client.reconnect()
        await self._client.stream.group_join(
            self.config.stream,
            self.config.group,
            self.config.consumer,
        )

    async def _ack_with_retry(
        self,
        record_ids: list[StreamID],
        options: StreamGroupAckOptions,
        max_attempts: int = 3,
    ) -> None:
        """Ack records, reconnecting and retrying on connection errors.

        Args:
            record_ids: IDs of the records to acknowledge.
            options: Ack options (consumer name).
            max_attempts: Total number of ack attempts before giving up.

        Raises:
            Exception: A non-connection error is re-raised unchanged. So is
                any error raised while the worker is stopping.
            RuntimeError: Every attempt failed with a connection error; the
                last error is chained. Previously this case returned silently,
                so the ack was lost with no signal to the caller.
        """
        assert self._client is not None
        last_err: Exception | None = None
        for attempt in range(1, max_attempts + 1):
            try:
                await self._client.stream.group_ack(
                    self.config.stream,
                    self.config.group,
                    record_ids,
                    options,
                )
                return
            except Exception as e:
                if not is_connection_error(e) or not self._running:
                    raise
                last_err = e
                if attempt == max_attempts:
                    # No attempt follows, so skip the reconnect.
                    break
                logger.warning(
                    "Connection lost while acking "
                    f"(attempt {attempt}/{max_attempts}), reconnecting..."
                )
                await self._handle_reconnect()
        raise RuntimeError(f"Failed to ack records after {max_attempts} attempts") from last_err

    async def _process_record(self, record: StreamRecord) -> None:
        """Process a single record: call handler, then ack or nack.

        The record is nacked only when the handler fails or times out. If the
        handler succeeds but the ack fails, the failure is logged and the record
        is not nacked, because a nack would report a successfully handled record
        as failed. Always releases the semaphore slot acquired by
        ``_poll_loop``.
        """
        assert self._client is not None
        assert self._semaphore is not None
        try:
            ctx = StreamContext(
                record=record,
                namespace=self._parent_client.namespace,
                stream=self.config.stream,
                group=self.config.group,
                consumer=self.config.consumer,
            )

            try:
                await asyncio.wait_for(
                    self._handler(ctx),
                    timeout=self.config.message_timeout,
                )
            except Exception as e:
                logger.error(
                    f"Record processing failed (stream={self.config.stream}, id={record.id}): {e}"
                )
                try:
                    await self._client.stream.group_nack(
                        self.config.stream,
                        self.config.group,
                        [record.id],
                        StreamGroupNackOptions(consumer=self.config.consumer),
                    )
                except Exception as nack_err:
                    logger.error(f"Failed to nack record: {nack_err}")
                return

            # Success — ack with retry
            try:
                await self._ack_with_retry(
                    [record.id],
                    StreamGroupAckOptions(consumer=self.config.consumer),
                )
            except Exception as ack_err:
                logger.error(
                    f"Failed to ack record (stream={self.config.stream}, id={record.id}): {ack_err}"
                )

        except Exception as e:
            logger.error(f"Failed to process record: {e}")

        finally:
            self._semaphore.release()

    def stop(self) -> None:
        """Signal the stream worker to stop.

        Clears the running flag, sets the stop event and interrupts the
        connection so a blocking ``group_read`` returns promptly.
        """
        logger.info("Stopping stream worker...")
        self._running = False
        self._stop_event.set()
        if self._client:
            self._client.interrupt()

    async def close(self) -> None:
        """Stop and close the stream worker."""
        self.stop()
        if self._client:
            await self._client.close()

    async def __aenter__(self) -> "StreamWorker":
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Async context manager exit — closes the stream worker."""
        await self.close()
|