codetether 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_server/__init__.py +29 -0
- a2a_server/a2a_agent_card.py +365 -0
- a2a_server/a2a_errors.py +1133 -0
- a2a_server/a2a_executor.py +926 -0
- a2a_server/a2a_router.py +1033 -0
- a2a_server/a2a_types.py +344 -0
- a2a_server/agent_card.py +408 -0
- a2a_server/agents_server.py +271 -0
- a2a_server/auth_api.py +349 -0
- a2a_server/billing_api.py +638 -0
- a2a_server/billing_service.py +712 -0
- a2a_server/billing_webhooks.py +501 -0
- a2a_server/config.py +96 -0
- a2a_server/database.py +2165 -0
- a2a_server/email_inbound.py +398 -0
- a2a_server/email_notifications.py +486 -0
- a2a_server/enhanced_agents.py +919 -0
- a2a_server/enhanced_server.py +160 -0
- a2a_server/hosted_worker.py +1049 -0
- a2a_server/integrated_agents_server.py +347 -0
- a2a_server/keycloak_auth.py +750 -0
- a2a_server/livekit_bridge.py +439 -0
- a2a_server/marketing_tools.py +1364 -0
- a2a_server/mcp_client.py +196 -0
- a2a_server/mcp_http_server.py +2256 -0
- a2a_server/mcp_server.py +191 -0
- a2a_server/message_broker.py +725 -0
- a2a_server/mock_mcp.py +273 -0
- a2a_server/models.py +494 -0
- a2a_server/monitor_api.py +5904 -0
- a2a_server/opencode_bridge.py +1594 -0
- a2a_server/redis_task_manager.py +518 -0
- a2a_server/server.py +726 -0
- a2a_server/task_manager.py +668 -0
- a2a_server/task_queue.py +742 -0
- a2a_server/tenant_api.py +333 -0
- a2a_server/tenant_middleware.py +219 -0
- a2a_server/tenant_service.py +760 -0
- a2a_server/user_auth.py +721 -0
- a2a_server/vault_client.py +576 -0
- a2a_server/worker_sse.py +873 -0
- agent_worker/__init__.py +8 -0
- agent_worker/worker.py +4877 -0
- codetether/__init__.py +10 -0
- codetether/__main__.py +4 -0
- codetether/cli.py +112 -0
- codetether/worker_cli.py +57 -0
- codetether-1.2.2.dist-info/METADATA +570 -0
- codetether-1.2.2.dist-info/RECORD +66 -0
- codetether-1.2.2.dist-info/WHEEL +5 -0
- codetether-1.2.2.dist-info/entry_points.txt +4 -0
- codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
- codetether-1.2.2.dist-info/top_level.txt +5 -0
- codetether_voice_agent/__init__.py +6 -0
- codetether_voice_agent/agent.py +445 -0
- codetether_voice_agent/codetether_mcp.py +345 -0
- codetether_voice_agent/config.py +16 -0
- codetether_voice_agent/functiongemma_caller.py +380 -0
- codetether_voice_agent/session_playback.py +247 -0
- codetether_voice_agent/tools/__init__.py +21 -0
- codetether_voice_agent/tools/definitions.py +135 -0
- codetether_voice_agent/tools/handlers.py +380 -0
- run_server.py +314 -0
- ui/monitor-tailwind.html +1790 -0
- ui/monitor.html +1775 -0
- ui/monitor.js +2662 -0
|
@@ -0,0 +1,518 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Redis-backed Task Manager for A2A Server.
|
|
3
|
+
|
|
4
|
+
Provides persistent task storage using Redis, ensuring tasks survive server restarts.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
import logging
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Dict, Optional, List, Callable
|
|
12
|
+
from asyncio import Lock
|
|
13
|
+
import asyncio
|
|
14
|
+
|
|
15
|
+
from .models import Task, TaskStatus, TaskStatusUpdateEvent, Message
|
|
16
|
+
from .task_manager import TaskManager
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
try:
    # Optional dependency: the async client ships inside the `redis`
    # package as `redis.asyncio` — presumably redis-py >= 4.2 is expected;
    # TODO confirm against the project's requirements.
    import redis.asyncio as aioredis

    REDIS_AVAILABLE = True
except ImportError:
    # Keep the module importable without redis installed;
    # RedisTaskManager.__init__ raises a descriptive ImportError on use.
    REDIS_AVAILABLE = False
    logger.warning(
        'redis package not installed. Install with: pip install redis'
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RedisTaskManager(TaskManager):
    """Task manager whose state lives in Redis rather than in memory.

    Storage layout:
      * each task is a Redis hash under ``task:{task_id}``
      * per-status indices are Redis sets under ``tasks:status:{status}``
      * the set ``tasks:all`` tracks every known task ID
    """

    def __init__(self, redis_url: str = 'redis://localhost:6379'):
        """Initialize the manager without connecting yet.

        Args:
            redis_url: Redis connection URL (e.g., redis://localhost:6379/0)

        Raises:
            ImportError: if the `redis` package is not installed.
        """
        super().__init__()

        if not REDIS_AVAILABLE:
            raise ImportError(
                'redis package is required for RedisTaskManager. '
                'Install with: pip install redis'
            )

        # Connection state; the client is created lazily by connect().
        self.redis_url = redis_url
        self.redis: Optional[aioredis.Redis] = None
        self._connected = False

        # Redis key layout.
        self.TASK_PREFIX = 'task:'
        self.STATUS_SET_PREFIX = 'tasks:status:'
        self.ALL_TASKS_SET = 'tasks:all'

        # Cached Lua script handles, registered with Redis on first use.
        self._claim_script = None
        self._release_script = None
|
|
67
|
+
|
|
68
|
+
async def connect(self):
|
|
69
|
+
"""Establish connection to Redis."""
|
|
70
|
+
if self._connected and self.redis:
|
|
71
|
+
return
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
self.redis = await aioredis.from_url(
|
|
75
|
+
self.redis_url, encoding='utf-8', decode_responses=True
|
|
76
|
+
)
|
|
77
|
+
# Test connection
|
|
78
|
+
await self.redis.ping()
|
|
79
|
+
self._connected = True
|
|
80
|
+
logger.info(f'Connected to Redis at {self.redis_url}')
|
|
81
|
+
except Exception as e:
|
|
82
|
+
logger.error(f'Failed to connect to Redis: {e}')
|
|
83
|
+
raise
|
|
84
|
+
|
|
85
|
+
async def disconnect(self):
|
|
86
|
+
"""Close Redis connection."""
|
|
87
|
+
if self.redis:
|
|
88
|
+
await self.redis.close()
|
|
89
|
+
self._connected = False
|
|
90
|
+
logger.info('Disconnected from Redis')
|
|
91
|
+
|
|
92
|
+
def _task_key(self, task_id: str) -> str:
|
|
93
|
+
"""Generate Redis key for a task."""
|
|
94
|
+
return f'{self.TASK_PREFIX}{task_id}'
|
|
95
|
+
|
|
96
|
+
def _status_set_key(self, status: TaskStatus) -> str:
|
|
97
|
+
"""Generate Redis set key for tasks with a specific status."""
|
|
98
|
+
return f'{self.STATUS_SET_PREFIX}{status.value}'
|
|
99
|
+
|
|
100
|
+
def _serialize_task(self, task: Task) -> Dict[str, str]:
|
|
101
|
+
"""Serialize task to Redis hash format."""
|
|
102
|
+
return {
|
|
103
|
+
'id': task.id,
|
|
104
|
+
'status': task.status.value,
|
|
105
|
+
'title': task.title or '',
|
|
106
|
+
'description': task.description or '',
|
|
107
|
+
'created_at': task.created_at.isoformat(),
|
|
108
|
+
'updated_at': task.updated_at.isoformat(),
|
|
109
|
+
'progress': str(task.progress or 0.0),
|
|
110
|
+
# Store messages as JSON if present
|
|
111
|
+
'messages': json.dumps(
|
|
112
|
+
[msg.model_dump(mode='json') for msg in (task.messages or [])]
|
|
113
|
+
),
|
|
114
|
+
'worker_id': task.worker_id or '',
|
|
115
|
+
'claimed_at': task.claimed_at.isoformat()
|
|
116
|
+
if task.claimed_at
|
|
117
|
+
else '',
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
def _deserialize_task(self, data: Dict[str, str]) -> Task:
|
|
121
|
+
"""Deserialize task from Redis hash format."""
|
|
122
|
+
messages_json = data.get('messages', '[]')
|
|
123
|
+
messages = []
|
|
124
|
+
try:
|
|
125
|
+
messages_data = json.loads(messages_json)
|
|
126
|
+
messages = [Message.model_validate(msg) for msg in messages_data]
|
|
127
|
+
except (json.JSONDecodeError, Exception) as e:
|
|
128
|
+
logger.warning(f'Failed to deserialize messages: {e}')
|
|
129
|
+
|
|
130
|
+
# Get fields, preserving empty strings as valid values
|
|
131
|
+
title = data.get('title')
|
|
132
|
+
description = data.get('description')
|
|
133
|
+
worker_id = data.get('worker_id')
|
|
134
|
+
claimed_at_str = data.get('claimed_at')
|
|
135
|
+
|
|
136
|
+
return Task(
|
|
137
|
+
id=data['id'],
|
|
138
|
+
status=TaskStatus(data['status']),
|
|
139
|
+
title=title if title else None,
|
|
140
|
+
description=description if description else None,
|
|
141
|
+
created_at=datetime.fromisoformat(data['created_at']),
|
|
142
|
+
updated_at=datetime.fromisoformat(data['updated_at']),
|
|
143
|
+
progress=float(data.get('progress', 0.0)),
|
|
144
|
+
messages=messages if messages else None,
|
|
145
|
+
worker_id=worker_id if worker_id else None,
|
|
146
|
+
claimed_at=datetime.fromisoformat(claimed_at_str)
|
|
147
|
+
if claimed_at_str
|
|
148
|
+
else None,
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
async def create_task(
|
|
152
|
+
self,
|
|
153
|
+
title: Optional[str] = None,
|
|
154
|
+
description: Optional[str] = None,
|
|
155
|
+
task_id: Optional[str] = None,
|
|
156
|
+
) -> Task:
|
|
157
|
+
"""Create a new task and store it in Redis."""
|
|
158
|
+
if not self._connected:
|
|
159
|
+
await self.connect()
|
|
160
|
+
|
|
161
|
+
if task_id is None:
|
|
162
|
+
task_id = str(uuid.uuid4())
|
|
163
|
+
|
|
164
|
+
now = datetime.utcnow()
|
|
165
|
+
task = Task(
|
|
166
|
+
id=task_id,
|
|
167
|
+
status=TaskStatus.PENDING,
|
|
168
|
+
created_at=now,
|
|
169
|
+
updated_at=now,
|
|
170
|
+
title=title,
|
|
171
|
+
description=description,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
async with self._task_lock:
|
|
175
|
+
# Store task in Redis
|
|
176
|
+
task_data = self._serialize_task(task)
|
|
177
|
+
await self.redis.hset(self._task_key(task_id), mapping=task_data)
|
|
178
|
+
|
|
179
|
+
# Add to status index
|
|
180
|
+
await self.redis.sadd(
|
|
181
|
+
self._status_set_key(TaskStatus.PENDING), task_id
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
# Add to all tasks index
|
|
185
|
+
await self.redis.sadd(self.ALL_TASKS_SET, task_id)
|
|
186
|
+
|
|
187
|
+
logger.info(f'Created task {task_id}: {title}')
|
|
188
|
+
return task
|
|
189
|
+
|
|
190
|
+
async def get_task(self, task_id: str) -> Optional[Task]:
|
|
191
|
+
"""Retrieve a task from Redis by ID."""
|
|
192
|
+
if not self._connected:
|
|
193
|
+
await self.connect()
|
|
194
|
+
|
|
195
|
+
async with self._task_lock:
|
|
196
|
+
task_data = await self.redis.hgetall(self._task_key(task_id))
|
|
197
|
+
|
|
198
|
+
if not task_data:
|
|
199
|
+
return None
|
|
200
|
+
|
|
201
|
+
return self._deserialize_task(task_data)
|
|
202
|
+
|
|
203
|
+
    async def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        message: Optional[Message] = None,
        progress: Optional[float] = None,
        final: bool = False,
    ) -> Optional[Task]:
        """Update a task's status in Redis and notify handlers.

        Args:
            task_id: ID of the task to update.
            status: New status value to record.
            message: Optional message appended to the task's message list.
            progress: Optional progress value; overwrites the stored one
                when provided.
            final: Whether the emitted update event is terminal.

        Returns:
            The updated Task, or None if the task does not exist.
        """
        if not self._connected:
            await self.connect()

        async with self._task_lock:
            # Get existing task
            task_data = await self.redis.hgetall(self._task_key(task_id))
            if not task_data:
                return None

            task = self._deserialize_task(task_data)
            old_status = task.status

            # Update task in memory first; the hash is rewritten below.
            task.status = status
            task.updated_at = datetime.utcnow()
            if progress is not None:
                task.progress = progress

            if message:
                if task.messages is None:
                    task.messages = []
                task.messages.append(message)

            # Store updated task (full re-serialization of the hash)
            updated_data = self._serialize_task(task)
            await self.redis.hset(self._task_key(task_id), mapping=updated_data)

            # Update status indices if status changed
            if old_status != status:
                await self.redis.srem(self._status_set_key(old_status), task_id)
                await self.redis.sadd(self._status_set_key(status), task_id)

            # Create update event
            event = TaskStatusUpdateEvent(
                task=task, message=message, final=final
            )

            # Notify handlers (while still holding the task lock)
            await self._notify_handlers(task_id, event)

            logger.info(
                f'Updated task {task_id} status: {old_status.value} -> {status.value}'
            )
            return task
|
|
256
|
+
|
|
257
|
+
async def cancel_task(self, task_id: str) -> Optional[Task]:
|
|
258
|
+
"""Cancel a task."""
|
|
259
|
+
return await self.update_task_status(
|
|
260
|
+
task_id, TaskStatus.CANCELLED, final=True
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
async def delete_task(self, task_id: str) -> bool:
|
|
264
|
+
"""Delete a task from Redis storage."""
|
|
265
|
+
if not self._connected:
|
|
266
|
+
await self.connect()
|
|
267
|
+
|
|
268
|
+
async with self._task_lock:
|
|
269
|
+
# Get task to find its status
|
|
270
|
+
task_data = await self.redis.hgetall(self._task_key(task_id))
|
|
271
|
+
if not task_data:
|
|
272
|
+
return False
|
|
273
|
+
|
|
274
|
+
status = TaskStatus(task_data['status'])
|
|
275
|
+
|
|
276
|
+
# Remove from all indices
|
|
277
|
+
await self.redis.srem(self._status_set_key(status), task_id)
|
|
278
|
+
await self.redis.srem(self.ALL_TASKS_SET, task_id)
|
|
279
|
+
|
|
280
|
+
# Delete the task hash
|
|
281
|
+
await self.redis.delete(self._task_key(task_id))
|
|
282
|
+
|
|
283
|
+
logger.info(f'Deleted task {task_id}')
|
|
284
|
+
return True
|
|
285
|
+
|
|
286
|
+
async def list_tasks(
|
|
287
|
+
self, status: Optional[TaskStatus] = None
|
|
288
|
+
) -> List[Task]:
|
|
289
|
+
"""List all tasks, optionally filtered by status."""
|
|
290
|
+
if not self._connected:
|
|
291
|
+
await self.connect()
|
|
292
|
+
|
|
293
|
+
async with self._task_lock:
|
|
294
|
+
# Get task IDs
|
|
295
|
+
if status is not None:
|
|
296
|
+
task_ids = await self.redis.smembers(
|
|
297
|
+
self._status_set_key(status)
|
|
298
|
+
)
|
|
299
|
+
else:
|
|
300
|
+
task_ids = await self.redis.smembers(self.ALL_TASKS_SET)
|
|
301
|
+
|
|
302
|
+
# Fetch all tasks
|
|
303
|
+
tasks = []
|
|
304
|
+
for task_id in task_ids:
|
|
305
|
+
task_data = await self.redis.hgetall(self._task_key(task_id))
|
|
306
|
+
if task_data:
|
|
307
|
+
tasks.append(self._deserialize_task(task_data))
|
|
308
|
+
|
|
309
|
+
return tasks
|
|
310
|
+
|
|
311
|
+
    async def _get_claim_script(self):
        """Get or create the Lua script for atomic task claiming.

        The script is registered with Redis once and cached on the
        instance; subsequent calls return the cached handle.
        """
        if self._claim_script is None:
            # Lua script for atomic task claiming
            # KEYS[1] = task key (task:{task_id})
            # KEYS[2] = pending status set key
            # KEYS[3] = working status set key
            # ARGV[1] = worker_id
            # ARGV[2] = updated_at timestamp (ISO format)
            # ARGV[3] = pending status value
            # ARGV[4] = working status value
            #
            # Returns 1 on success, nil when the task is missing or not
            # currently in pending status.
            script = """
            -- Get current task data
            local task_data = redis.call('HGETALL', KEYS[1])
            if #task_data == 0 then
                return nil
            end

            -- Parse task data into a table
            local task = {}
            for i = 1, #task_data, 2 do
                task[task_data[i]] = task_data[i + 1]
            end

            -- Check if task is in pending status
            if task['status'] ~= ARGV[3] then
                return nil
            end

            -- Update task fields atomically
            redis.call('HSET', KEYS[1],
                'status', ARGV[4],
                'worker_id', ARGV[1],
                'claimed_at', ARGV[2],
                'updated_at', ARGV[2]
            )

            -- Update status indices
            redis.call('SREM', KEYS[2], task['id'])
            redis.call('SADD', KEYS[3], task['id'])

            -- Return success indicator
            return 1
            """
            self._claim_script = self.redis.register_script(script)
        return self._claim_script
|
|
357
|
+
|
|
358
|
+
    async def _get_release_script(self):
        """Get or create the Lua script for atomic task release.

        The script is registered with Redis once and cached on the
        instance; subsequent calls return the cached handle.
        """
        if self._release_script is None:
            # Lua script for atomic task release
            # KEYS[1] = task key (task:{task_id})
            # KEYS[2] = working status set key
            # KEYS[3] = pending status set key
            # ARGV[1] = worker_id (must match current owner)
            # ARGV[2] = updated_at timestamp (ISO format)
            # ARGV[3] = working status value
            # ARGV[4] = pending status value
            #
            # Returns 1 on success, nil when the task is missing, owned by
            # a different worker, or not in working status.
            script = """
            -- Get current task data
            local task_data = redis.call('HGETALL', KEYS[1])
            if #task_data == 0 then
                return nil
            end

            -- Parse task data into a table
            local task = {}
            for i = 1, #task_data, 2 do
                task[task_data[i]] = task_data[i + 1]
            end

            -- Check if worker owns this task
            if task['worker_id'] ~= ARGV[1] then
                return nil
            end

            -- Check if task is in working status
            if task['status'] ~= ARGV[3] then
                return nil
            end

            -- Update task fields atomically
            redis.call('HSET', KEYS[1],
                'status', ARGV[4],
                'worker_id', '',
                'claimed_at', '',
                'updated_at', ARGV[2]
            )

            -- Update status indices
            redis.call('SREM', KEYS[2], task['id'])
            redis.call('SADD', KEYS[3], task['id'])

            -- Return success indicator
            return 1
            """
            self._release_script = self.redis.register_script(script)
        return self._release_script
|
|
409
|
+
|
|
410
|
+
async def claim_task(self, task_id: str, worker_id: str) -> Optional[Task]:
|
|
411
|
+
"""
|
|
412
|
+
Atomically claim a task for a worker using a Lua script.
|
|
413
|
+
|
|
414
|
+
This method uses a Lua script to ensure atomicity of the check-and-update
|
|
415
|
+
operation in Redis, preventing race conditions between multiple workers.
|
|
416
|
+
|
|
417
|
+
Args:
|
|
418
|
+
task_id: The ID of the task to claim
|
|
419
|
+
worker_id: The ID of the worker claiming the task
|
|
420
|
+
|
|
421
|
+
Returns:
|
|
422
|
+
The claimed Task if successful, None if the task doesn't exist,
|
|
423
|
+
is not in pending status, or was already claimed by another worker.
|
|
424
|
+
"""
|
|
425
|
+
if not self._connected:
|
|
426
|
+
await self.connect()
|
|
427
|
+
|
|
428
|
+
now = datetime.utcnow()
|
|
429
|
+
claim_script = await self._get_claim_script()
|
|
430
|
+
|
|
431
|
+
# Execute the Lua script
|
|
432
|
+
result = await claim_script(
|
|
433
|
+
keys=[
|
|
434
|
+
self._task_key(task_id),
|
|
435
|
+
self._status_set_key(TaskStatus.PENDING),
|
|
436
|
+
self._status_set_key(TaskStatus.WORKING),
|
|
437
|
+
],
|
|
438
|
+
args=[
|
|
439
|
+
worker_id,
|
|
440
|
+
now.isoformat(),
|
|
441
|
+
TaskStatus.PENDING.value,
|
|
442
|
+
TaskStatus.WORKING.value,
|
|
443
|
+
],
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
if result is None:
|
|
447
|
+
logger.debug(
|
|
448
|
+
f'Task {task_id} could not be claimed by worker {worker_id}'
|
|
449
|
+
)
|
|
450
|
+
return None
|
|
451
|
+
|
|
452
|
+
# Fetch and return the updated task
|
|
453
|
+
task = await self.get_task(task_id)
|
|
454
|
+
if task:
|
|
455
|
+
# Notify handlers
|
|
456
|
+
event = TaskStatusUpdateEvent(task=task, message=None, final=False)
|
|
457
|
+
await self._notify_handlers(task_id, event)
|
|
458
|
+
logger.info(f'Task {task_id} claimed by worker {worker_id}')
|
|
459
|
+
|
|
460
|
+
return task
|
|
461
|
+
|
|
462
|
+
async def release_task(
|
|
463
|
+
self, task_id: str, worker_id: str
|
|
464
|
+
) -> Optional[Task]:
|
|
465
|
+
"""
|
|
466
|
+
Release a claimed task back to pending status using a Lua script.
|
|
467
|
+
|
|
468
|
+
This method uses a Lua script to ensure atomicity of the check-and-update
|
|
469
|
+
operation in Redis, preventing race conditions.
|
|
470
|
+
|
|
471
|
+
Args:
|
|
472
|
+
task_id: The ID of the task to release
|
|
473
|
+
worker_id: The ID of the worker releasing the task
|
|
474
|
+
|
|
475
|
+
Returns:
|
|
476
|
+
The released Task if successful, None if the task doesn't exist
|
|
477
|
+
or the worker_id doesn't match the claiming worker.
|
|
478
|
+
"""
|
|
479
|
+
if not self._connected:
|
|
480
|
+
await self.connect()
|
|
481
|
+
|
|
482
|
+
now = datetime.utcnow()
|
|
483
|
+
release_script = await self._get_release_script()
|
|
484
|
+
|
|
485
|
+
# Execute the Lua script
|
|
486
|
+
result = await release_script(
|
|
487
|
+
keys=[
|
|
488
|
+
self._task_key(task_id),
|
|
489
|
+
self._status_set_key(TaskStatus.WORKING),
|
|
490
|
+
self._status_set_key(TaskStatus.PENDING),
|
|
491
|
+
],
|
|
492
|
+
args=[
|
|
493
|
+
worker_id,
|
|
494
|
+
now.isoformat(),
|
|
495
|
+
TaskStatus.WORKING.value,
|
|
496
|
+
TaskStatus.PENDING.value,
|
|
497
|
+
],
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
if result is None:
|
|
501
|
+
logger.debug(
|
|
502
|
+
f'Task {task_id} could not be released by worker {worker_id}'
|
|
503
|
+
)
|
|
504
|
+
return None
|
|
505
|
+
|
|
506
|
+
# Fetch and return the updated task
|
|
507
|
+
task = await self.get_task(task_id)
|
|
508
|
+
if task:
|
|
509
|
+
# Notify handlers
|
|
510
|
+
event = TaskStatusUpdateEvent(task=task, message=None, final=False)
|
|
511
|
+
await self._notify_handlers(task_id, event)
|
|
512
|
+
logger.info(f'Task {task_id} released by worker {worker_id}')
|
|
513
|
+
|
|
514
|
+
return task
|
|
515
|
+
|
|
516
|
+
async def cleanup(self):
|
|
517
|
+
"""Clean up Redis connections."""
|
|
518
|
+
await self.disconnect()
|