flo-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flo/worker.py ADDED
@@ -0,0 +1,421 @@
1
+ """Flo High-Level Worker API
2
+
3
+ Provides an easy-to-use Worker class for executing actions.
4
+
5
+ Example:
6
+ from flo import Worker, ActionContext
7
+
8
+ async def process_order(ctx: ActionContext) -> bytes:
9
+ order = ctx.json()
10
+ # Process the order...
11
+ return ctx.to_bytes({"status": "completed"})
12
+
13
+ async def main():
14
+ worker = Worker(
15
+ endpoint="localhost:3000",
16
+ namespace="myapp",
17
+ )
18
+ worker.action("process-order")(process_order)
19
+
20
+ await worker.start()
21
+ """
22
+
23
+ import asyncio
24
+ import json
25
+ import logging
26
+ import secrets
27
+ import socket
28
+ from collections.abc import Awaitable, Callable
29
+ from dataclasses import dataclass, field
30
+ from typing import Any
31
+
32
+ from .client import FloClient
33
+ from .types import ActionType, TaskAssignment, WorkerAwaitOptions, WorkerTouchOptions
34
+
35
# Module-wide logger; log records from this module are emitted under "flo.worker".
logger: logging.Logger = logging.getLogger("flo.worker")


# Handler signature accepted by Worker.action() / Worker.register_action():
# an async callable that receives an ActionContext and resolves to the result
# bytes. The context type is a string forward reference because the class is
# defined further down in this module.
ActionHandler = Callable[["ActionContext"], Awaitable[bytes]]
40
+
41
+
42
@dataclass
class WorkerConfig:
    """Configuration for a Flo worker.

    Attributes:
        endpoint: Server endpoint in "host:port" format.
        namespace: Namespace the worker operates in.
        worker_id: Identifier reported to the server. ``Worker`` generates one
            when none (or an empty string) is supplied.
        concurrency: Maximum number of actions executed at once; must be >= 1.
        action_timeout: Per-action handler timeout in seconds.
        block_ms: Long-poll timeout when awaiting a task, in milliseconds.
        debug: Enable debug logging.

    Raises:
        ValueError: If ``concurrency`` is less than 1.
    """

    endpoint: str
    namespace: str = "default"
    worker_id: str = ""
    concurrency: int = 10
    action_timeout: float = 300.0  # 5 minutes
    block_ms: int = 30000
    debug: bool = False

    def __post_init__(self) -> None:
        # The worker gates polling on asyncio.Semaphore(concurrency). With 0,
        # the poll loop would block forever on acquire() without any error.
        # Negative values would only fail later, deep inside Worker.start().
        # Reject both up front.
        if self.concurrency < 1:
            raise ValueError(
                f"concurrency must be >= 1, got {self.concurrency!r}"
            )
53
+
54
+
55
@dataclass
class ActionContext:
    """Per-task context handed to an action handler.

    Bundles the task metadata delivered by the server with helpers to decode
    the input payload, encode a JSON result, extend the task lease and query
    cooperative cancellation.
    """

    task_id: str
    action_name: str
    payload: bytes
    attempt: int
    created_at: int
    namespace: str
    # Back-reference used by touch(); hidden from repr to keep it readable.
    _worker: "Worker" = field(repr=False)
    # Flag read by `cancelled` / check_cancelled().
    # NOTE(review): nothing in this module ever sets this event, so
    # `cancelled` stays False unless external code sets it; confirm intended.
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    def input(self) -> bytes:
        """Return the task payload exactly as received."""
        return self.payload

    def json(self) -> Any:
        """Decode the payload as UTF-8 JSON and return the resulting value.

        Raises:
            ValueError: If the payload is empty. Invalid UTF-8 or invalid JSON
                also surface as ``ValueError`` subclasses.
        """
        raw = self.payload
        if raw:
            return json.loads(raw.decode("utf-8"))
        raise ValueError("No input data")

    def into(self, cls: type) -> Any:
        """Decode the JSON payload and build an instance of ``cls`` from it.

        A JSON object is unpacked as keyword arguments (``cls(**obj)``). Any
        other JSON value is passed as the single positional argument.

        Args:
            cls: The class (or any callable) to construct.

        Returns:
            Whatever ``cls`` returns for the decoded data.
        """
        decoded = self.json()
        return cls(**decoded) if isinstance(decoded, dict) else cls(decoded)

    @staticmethod
    def to_bytes(value: Any) -> bytes:
        """Encode ``value`` as JSON text in UTF-8 bytes."""
        text = json.dumps(value)
        return text.encode("utf-8")

    async def touch(self, extend_ms: int = 30000) -> None:
        """Ask the server to extend this task's lease.

        Call this periodically from long-running handlers to avoid timeouts.

        Args:
            extend_ms: Lease extension in milliseconds.
        """
        owner = self._worker
        await owner._touch_task(self.task_id, extend_ms)

    @property
    def cancelled(self) -> bool:
        """True once the cancellation event for this task has been set."""
        return self._cancel_event.is_set()

    async def check_cancelled(self) -> None:
        """Raise ``asyncio.CancelledError`` if the task has been cancelled."""
        if self.cancelled:
            raise asyncio.CancelledError("Task was cancelled")
120
+
121
+
122
class Worker:
    """High-level Flo worker for executing actions.

    The worker connects to the Flo server and registers its action handlers
    and itself. It then long-polls for task assignments and runs each one in a
    background asyncio task. At most ``concurrency`` tasks run at once.

    Example:
        worker = Worker(endpoint="localhost:3000", namespace="myapp")

        @worker.action("process-order")
        async def process_order(ctx: ActionContext) -> bytes:
            order = ctx.json()
            # Process the order...
            return ctx.to_bytes({"status": "completed"})

        await worker.start()
    """

    def __init__(
        self,
        endpoint: str,
        *,
        namespace: str = "default",
        worker_id: str | None = None,
        concurrency: int = 10,
        action_timeout: float = 300.0,
        block_ms: int = 30000,
        debug: bool = False,
    ):
        """Initialize a Flo worker.

        The server connection is opened by ``start()``, not here.

        Args:
            endpoint: Server endpoint in "host:port" format.
            namespace: Namespace for operations.
            worker_id: Unique worker identifier (auto-generated if not provided
                or empty).
            concurrency: Maximum number of concurrent actions.
            action_timeout: Timeout for action handlers in seconds.
            block_ms: Timeout for blocking dequeue in milliseconds.
            debug: Enable debug logging.
        """
        self.config = WorkerConfig(
            endpoint=endpoint,
            namespace=namespace,
            # `or` also replaces an empty-string id with a generated one.
            worker_id=worker_id or self._generate_worker_id(),
            concurrency=concurrency,
            action_timeout=action_timeout,
            block_ms=block_ms,
            debug=debug,
        )

        self._client: FloClient | None = None  # created by start(), cleared on exit
        self._handlers: dict[str, ActionHandler] = {}
        self._running = False
        self._stop_event = asyncio.Event()
        # Strong references to in-flight tasks. The event loop only keeps weak
        # references, so an unreferenced task could be garbage-collected.
        self._tasks: set[asyncio.Task[None]] = set()
        self._semaphore: asyncio.Semaphore | None = None  # created by start()

        if debug:
            # NOTE(review): basicConfig configures the *root* logger, which is
            # a process-wide side effect; consider logger.setLevel() instead.
            logging.basicConfig(level=logging.DEBUG)

    @staticmethod
    def _generate_worker_id() -> str:
        """Generate a worker ID of the form "<hostname>-<8 random hex chars>"."""
        try:
            hostname = socket.gethostname()
        except Exception:
            hostname = "unknown"
        return f"{hostname}-{secrets.token_hex(4)}"

    def action(self, name: str) -> Callable[[ActionHandler], ActionHandler]:
        """Decorator to register an action handler.

        Args:
            name: The action name to register.

        Returns:
            Decorator that registers the handler and returns it unchanged.

        Example:
            @worker.action("process-order")
            async def process_order(ctx: ActionContext) -> bytes:
                return ctx.to_bytes({"status": "ok"})
        """

        def decorator(handler: ActionHandler) -> ActionHandler:
            self.register_action(name, handler)
            return handler

        return decorator

    def register_action(self, name: str, handler: ActionHandler) -> None:
        """Register an action handler.

        Args:
            name: The action name.
            handler: Async function that handles the action.

        Raises:
            ValueError: If action is already registered.
        """
        if name in self._handlers:
            raise ValueError(f"Action '{name}' is already registered")
        self._handlers[name] = handler
        logger.info("Registered action: %s", name)

    async def start(self) -> None:
        """Start the worker and begin processing actions.

        The method proceeds in four steps:

        1. Connect to the server.
        2. Register every handler's action with the server.
        3. Register this worker for those actions.
        4. Run the polling loop.

        It blocks until stop() is called, the awaiting task is cancelled, or an
        error occurs. On exit it always waits for in-flight tasks and closes
        the client.

        Raises:
            ValueError: If no handlers are registered.
            ConnectionError: If connection to server fails.
            asyncio.CancelledError: If the awaiting task is cancelled. Cleanup
                has already run when this propagates.
        """
        if not self._handlers:
            raise ValueError("No action handlers registered")

        logger.info(
            "Starting Flo worker (id=%s, namespace=%s, concurrency=%s)",
            self.config.worker_id,
            self.config.namespace,
            self.config.concurrency,
        )

        # Connect to server
        self._client = FloClient(
            self.config.endpoint,
            namespace=self.config.namespace,
            debug=self.config.debug,
        )
        await self._client.connect()

        try:
            # Register actions with the server
            action_names = list(self._handlers.keys())
            for action_name in action_names:
                await self._client.action.register(action_name, ActionType.USER)
                logger.debug("Registered action with server: %s", action_name)

            # Register worker
            await self._client.worker.register(
                self.config.worker_id,
                action_names,
            )
            logger.info("Worker registered with %s actions", len(action_names))

            # Initialize concurrency control
            self._semaphore = asyncio.Semaphore(self.config.concurrency)
            self._running = True
            self._stop_event.clear()

            # Main polling loop
            await self._poll_loop(action_names)

        finally:
            # Drain in-flight tasks so their results are reported before the
            # client goes away.
            if self._tasks:
                logger.info("Waiting for %s tasks to complete...", len(self._tasks))
                await asyncio.gather(*self._tasks, return_exceptions=True)

            await self._client.close()
            self._client = None
            self._running = False
            logger.info("Worker stopped")

    async def _poll_loop(self, action_names: list[str]) -> None:
        """Long-poll the server for tasks until the worker is stopped.

        A concurrency slot is acquired *before* asking the server for work, so
        the worker never leases more tasks than it can run. Once a task is
        spawned, ``_execute_task`` owns that slot and releases it.

        Cancellation is propagated rather than swallowed, after returning any
        slot still held.
        """
        assert self._client is not None
        assert self._semaphore is not None
        while self._running and not self._stop_event.is_set():
            # Wait for a free slot. A CancelledError raised here propagates
            # directly because no slot has been taken yet.
            await self._semaphore.acquire()

            try:
                # stop() may have been requested while we waited for the slot.
                if self._stop_event.is_set():
                    self._semaphore.release()
                    break

                # Await task from server
                result = await self._client.worker.await_task(
                    self.config.worker_id,
                    action_names,
                    WorkerAwaitOptions(block_ms=self.config.block_ms),
                )

                if result.task is None:
                    # Long-poll expired with no work; give the slot back.
                    self._semaphore.release()
                    continue

                # Execute in the background. _execute_task releases the slot
                # in its finally clause.
                task = asyncio.create_task(self._execute_task(result.task))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)

            except asyncio.CancelledError:
                # Only await_task can be interrupted here, and it runs while we
                # hold a slot, so return it. Re-raise instead of breaking:
                # swallowing cancellation would make start() return normally
                # and defeat asyncio.timeout(), TaskGroup and Ctrl-C handling.
                self._semaphore.release()
                raise
            except Exception as e:
                self._semaphore.release()
                logger.error("Await error: %s, retrying...", e)
                await asyncio.sleep(1)

    async def _execute_task(self, task: TaskAssignment) -> None:
        """Run one assigned task and report its outcome to the server.

        Looks up the handler for ``task.task_type`` and runs it under
        ``config.action_timeout``. Reports ``complete`` with the handler's
        bytes, or ``fail`` on a missing handler, timeout, cancellation or
        exception. If ``complete`` itself raises, the task is reported as
        failed. Always releases the concurrency slot taken by ``_poll_loop``.
        """
        assert self._client is not None
        assert self._semaphore is not None
        try:
            logger.info(
                "Executing action: %s (task=%s, attempt=%s)",
                task.task_type,
                task.task_id,
                task.attempt,
            )

            # Get handler
            handler = self._handlers.get(task.task_type)
            if handler is None:
                logger.error("No handler registered for action: %s", task.task_type)
                await self._client.worker.fail(
                    self.config.worker_id,
                    task.task_id,
                    f"No handler for: {task.task_type}",
                )
                return

            # Create action context
            ctx = ActionContext(
                task_id=task.task_id,
                action_name=task.task_type,
                payload=task.payload,
                attempt=task.attempt,
                created_at=task.created_at,
                namespace=self.config.namespace,
                _worker=self,
            )

            # Execute with timeout
            try:
                result = await asyncio.wait_for(
                    handler(ctx),
                    timeout=self.config.action_timeout,
                )

                # Success - complete the task
                await self._client.worker.complete(
                    self.config.worker_id,
                    task.task_id,
                    result,
                )
                logger.info("Action completed: %s", task.task_type)

            except asyncio.TimeoutError:
                logger.error("Action timed out: %s", task.task_type)
                await self._client.worker.fail(
                    self.config.worker_id,
                    task.task_id,
                    "Action timed out",
                )

            except asyncio.CancelledError:
                # Raised by ctx.check_cancelled() in the handler, or when this
                # task itself is cancelled; either way, report it as a failure.
                logger.warning("Action cancelled: %s", task.task_type)
                await self._client.worker.fail(
                    self.config.worker_id,
                    task.task_id,
                    "Action cancelled",
                )

            except Exception as e:
                # logger.exception keeps the handler's traceback for debugging.
                logger.exception("Action failed: %s - %s", task.task_type, e)
                await self._client.worker.fail(
                    self.config.worker_id,
                    task.task_id,
                    str(e),
                )

        except Exception as e:
            # Reporting (complete/fail) itself failed; log with traceback.
            logger.exception("Failed to report task result: %s", e)

        finally:
            if self._semaphore is not None:
                self._semaphore.release()

    async def _touch_task(self, task_id: str, extend_ms: int) -> None:
        """Extend the lease on a task (used by ActionContext.touch).

        Raises:
            RuntimeError: If the worker has no active client.
        """
        if self._client is None:
            raise RuntimeError("Worker not connected")
        await self._client.worker.touch(
            self.config.worker_id,
            task_id,
            WorkerTouchOptions(extend_ms=extend_ms),
        )

    def stop(self) -> None:
        """Signal the worker to stop.

        Sets the stop flags checked by the polling loop. The loop exits on its
        next check, which may happen only after a pending long-poll
        (up to ``block_ms``) returns. start() then drains in-flight tasks.
        """
        logger.info("Stopping worker...")
        self._running = False
        self._stop_event.set()

    async def close(self) -> None:
        """Stop the worker and close its client immediately.

        This does not wait for in-flight tasks. For a graceful drain, call
        stop() and let start() return; start() waits for running tasks before
        closing the client.

        NOTE(review): if start() is still running, its finally clause closes
        the client again. Confirm FloClient.close() tolerates a second call.
        """
        self.stop()
        if self._client:
            await self._client.close()