keble-task 2.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keble_task/__init__.py +143 -0
- keble_task/actions.py +304 -0
- keble_task/agent/__init__.py +27 -0
- keble_task/agent/chat_provider.py +117 -0
- keble_task/agent/deps.py +38 -0
- keble_task/agent/tools/__init__.py +14 -0
- keble_task/agent/tools/mutation.py +100 -0
- keble_task/agent/tools/query.py +160 -0
- keble_task/crud.py +347 -0
- keble_task/exceptions.py +62 -0
- keble_task/main.py +2708 -0
- keble_task/schemas/__init__.py +1295 -0
- keble_task/schemas/for_agent.py +177 -0
- keble_task/task_tree.py +106 -0
- keble_task/utils.py +58 -0
- keble_task-2.22.0.dist-info/METADATA +1083 -0
- keble_task-2.22.0.dist-info/RECORD +18 -0
- keble_task-2.22.0.dist-info/WHEEL +4 -0
keble_task/main.py
ADDED
|
@@ -0,0 +1,2708 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import secrets
|
|
4
|
+
import string
|
|
5
|
+
import traceback
|
|
6
|
+
import re
|
|
7
|
+
from datetime import datetime, timedelta
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import (
|
|
10
|
+
Any,
|
|
11
|
+
Awaitable,
|
|
12
|
+
Callable,
|
|
13
|
+
List,
|
|
14
|
+
Literal,
|
|
15
|
+
Mapping,
|
|
16
|
+
Optional,
|
|
17
|
+
Union,
|
|
18
|
+
cast,
|
|
19
|
+
overload,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
import keble_exceptions
|
|
23
|
+
import tenacity
|
|
24
|
+
from keble_db import AgentDbDeps, ExtendedAsyncRedis, QueryBase
|
|
25
|
+
from keble_helpers import (
|
|
26
|
+
AgenticActionEventStatus,
|
|
27
|
+
AgenticEventEmitter,
|
|
28
|
+
Currency,
|
|
29
|
+
ExchangeRateInUsd,
|
|
30
|
+
Language,
|
|
31
|
+
Money,
|
|
32
|
+
ObjectId,
|
|
33
|
+
PydanticModelConfig,
|
|
34
|
+
SharingScope,
|
|
35
|
+
UsageAccountingRecorderProtocol,
|
|
36
|
+
UsageAccountingSource,
|
|
37
|
+
utc_now,
|
|
38
|
+
)
|
|
39
|
+
from motor.motor_asyncio import AsyncIOMotorClient
|
|
40
|
+
from neo4j import AsyncDriver as Neo4jAsyncDriver
|
|
41
|
+
from pydantic import BaseModel, ConfigDict, SkipValidation
|
|
42
|
+
from pydantic_ai.usage import RunUsage
|
|
43
|
+
from pymongo import ASCENDING, DESCENDING
|
|
44
|
+
from pymongo.errors import DuplicateKeyError
|
|
45
|
+
from qdrant_client import AsyncQdrantClient
|
|
46
|
+
|
|
47
|
+
from keble_task.exceptions import (
|
|
48
|
+
TaskException,
|
|
49
|
+
TaskExceptionType,
|
|
50
|
+
TaskFailedToStartException,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
from .actions import (
|
|
54
|
+
CreateRelatedTaskAction,
|
|
55
|
+
CreateRelatedTaskActionedResult,
|
|
56
|
+
TaskAction,
|
|
57
|
+
TaskActionCreatedRelation,
|
|
58
|
+
TaskActionCreatedTask,
|
|
59
|
+
TaskActionedResults,
|
|
60
|
+
TaskActionEvent,
|
|
61
|
+
TaskActions,
|
|
62
|
+
TaskActionType,
|
|
63
|
+
TaskEventType,
|
|
64
|
+
TaskLifecycleEvent,
|
|
65
|
+
TaskLifecycleEventPayload,
|
|
66
|
+
TaskUxContext,
|
|
67
|
+
)
|
|
68
|
+
from .crud import CRUDTask, CRUDTaskCost, CRUDTaskRelation
|
|
69
|
+
from .schemas import (
|
|
70
|
+
TaskBase,
|
|
71
|
+
TaskCostAggregateRequest,
|
|
72
|
+
TaskCostAggregateResponse,
|
|
73
|
+
TaskCostCreate,
|
|
74
|
+
TaskCostListRequest,
|
|
75
|
+
TaskCostListResponse,
|
|
76
|
+
TaskCostMetadata,
|
|
77
|
+
TaskCostMongoObject,
|
|
78
|
+
TaskCostTokenRates,
|
|
79
|
+
TaskMetadata,
|
|
80
|
+
TaskMongoObject,
|
|
81
|
+
TaskMongoObjectExtended,
|
|
82
|
+
TaskPublicRef,
|
|
83
|
+
TaskRoomGraphContext,
|
|
84
|
+
TaskRoomResolution,
|
|
85
|
+
TaskRelationCreate,
|
|
86
|
+
TaskRelationMongoObject,
|
|
87
|
+
TaskStage,
|
|
88
|
+
TaskUpdate,
|
|
89
|
+
)
|
|
90
|
+
from .task_tree import build_task_tree, find_task_in_tree
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TokenConsumptionType(str, Enum):
    # Kind of token movement carried by a `TokenConsumptionPayload`.
    # The `str` mixin makes each member compare equal to its own name.
    CONSUME = "CONSUME"  # tokens are being consumed
    RECOVER = "RECOVER"  # tokens are being recovered
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class TokenConsumptionPayload(BaseModel):
    """Request passed to the host token handler to consume or recover tokens.

    Step by step:
    1. identify the movement: consumption type, owner, token amount, and the
       optional task id it belongs to;
    2. expose the DB clients (`amongo`, `extended_aredis`, optional `aneo4j`)
       directly on the payload, replacing the retired `TaskResources` bag;
    3. optionally carry handler metadata for pricing context.

    Side effects:
    - none in this model itself; it is consumed by the host-provided
      `_token_consumption_handler` in keble.backend, which reads
      `payload.amongo`/`extended_aredis`/`aneo4j` directly instead of
      `payload.resources.*`.
    """

    model_config = ConfigDict(
        **PydanticModelConfig.default_dict(),
        # the database client fields below are arbitrary (non-pydantic) types
        arbitrary_types_allowed=True,
    )

    # What is being moved, for whom, and how much.
    consumption_type: TokenConsumptionType
    owner: str
    token: int
    task_id: Optional[ObjectId] = None

    # Clients the handler needs to price and persist the consumption.
    amongo: AsyncIOMotorClient
    extended_aredis: ExtendedAsyncRedis
    aneo4j: Optional[Neo4jAsyncDriver] = None

    # Optional pricing context for the handler.
    metadata: Optional[TaskMetadata] = None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class TaskHandlerResponse(BaseModel):
    # Result a task handler returns for one task run. Per the `task_handler`
    # contract documented on `TaskClient.__init__`, the runtime finalizes
    # success/failure from this response.
    model_config = PydanticModelConfig.default()

    # Task row this response refers to.
    task: TaskMongoObject
    # Whether the handler reports the run as successful.
    success: bool
    # Token amount reported by the handler for this run.
    consuming_token: int

    # Failure details; both default to None.
    # NOTE(review): presumably only populated when `success` is False — confirm.
    exception_type: Optional[TaskExceptionType] = None
    error: Optional[str] = None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class TaskHandlerRequest(AgentDbDeps):
    """Input for one task execution: the DB-rooted agent deps plus the task.

    Step by step:
    1. subclass `AgentDbDeps`, so every DB client and the cross-cutting
       context (owner, owner_type, owner_scope, event_emitter,
       usage_recorder, marketplace, language) is inherited and a handler can
       use `request` directly as its deps object;
    2. add the persisted task row and optional handler metadata;
    3. carry NO completion callback and NO resources bag — a handler RETURNS
       a `TaskHandlerResponse` (the runtime finalizes) or `None` (completion
       was delegated to another process; the task stays PROCESSING).

    Side effects:
    - every `atask_handler` (positioning, amz-product-report, backend
      handlers) now reads clients off `request` directly and returns a
      response or None.
    """

    model_config = PydanticModelConfig.default(arbitrary_types_allowed=True)

    # Persisted task row being executed.
    task: TaskMongoObject
    # Optional handler metadata.
    metadata: Optional[TaskMetadata] = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
logger = logging.getLogger(__name__)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _merge_metadata(
|
|
162
|
+
base: Optional[Mapping[str, Any]], override: Optional[Mapping[str, Any]]
|
|
163
|
+
) -> Optional[dict[str, Any]]:
|
|
164
|
+
if base is None and override is None:
|
|
165
|
+
return None
|
|
166
|
+
if base is None:
|
|
167
|
+
assert override is not None
|
|
168
|
+
return dict(override)
|
|
169
|
+
if override is None:
|
|
170
|
+
return dict(base)
|
|
171
|
+
return {**base, **override}
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _task_lifecycle_event_status(*, stage: TaskStage) -> AgenticActionEventStatus:
    """Translate a persisted task stage into the shared agentic event status.

    Step by step:
    1. PROCESSING maps to STARTED (lifecycle start/progress signal for room UIs);
    2. SUCCESS and FAILURE map to the terminal SUCCEEDED and FAILED statuses;
    3. any other stage (e.g. pending) maps to the non-terminal PROGRESSED.
    """

    # Stages are matched by identity, in this order.
    stage_to_status = (
        (TaskStage.PROCESSING, AgenticActionEventStatus.STARTED),
        (TaskStage.SUCCESS, AgenticActionEventStatus.SUCCEEDED),
        (TaskStage.FAILURE, AgenticActionEventStatus.FAILED),
    )
    for known_stage, status in stage_to_status:
        if stage is known_stage:
            return status
    return AgenticActionEventStatus.PROGRESSED
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class TaskClient:
|
|
193
|
+
def __init__(
    self,
    *,
    # Token consumption handler can be synchronous or asynchronous
    # If sync, it should return a bool
    # If async, it should return a coroutine that resolves to a bool
    token_consumption_handler: Callable[
        [TokenConsumptionPayload], Union[bool, Awaitable[bool]]
    ],
    # Run one task. The handler RETURNS a `TaskHandlerResponse` (the runtime
    # finalizes success/failure and emits the terminal lifecycle event) OR
    # `None` when it deliberately delegated completion to another process
    # (e.g. fire-and-forget bootstrap, multi-invocation pipeline) — the task
    # then stays PROCESSING until that process finalizes it.
    task_handler: Callable[
        [TaskHandlerRequest], Awaitable[Optional[TaskHandlerResponse]]
    ],
    # Database config
    mongo_database: str = "__keble_task__",
    task_collection: str = "__keble_task__task__",
    task_relation_collection: str = "__keble_task__task_relation__",
    task_cost_collection: str = "__keble_task__task_cost__",
    # Public short id config (for shared URLs)
    public_id_prefix: str = "t",
    public_id_length: int = 10,
):
    """Set up a TaskClient without touching any database.

    Step by step:
    1. remember the two runtime handlers;
    2. remember the database/collection names and the public-id settings,
       then validate the public-id settings up front;
    3. build the task, relation, and cost CRUD helpers bound to those names.

    Note: database clients are not held here; each method receives the
    clients it needs as arguments.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when the public-id prefix
            or length is invalid.
    """
    # Runtime handlers.
    self._token_consumption_handler = token_consumption_handler
    self._task_handler = task_handler

    # Storage names.
    self._mongo_database = mongo_database
    self._task_collection = task_collection
    self._task_relation_collection = task_relation_collection
    self._task_cost_collection = task_cost_collection

    # Public-id settings; suffix characters are drawn from letters + digits.
    self._public_id_prefix = public_id_prefix
    self._public_id_length = public_id_length
    self._public_id_alphabet = string.ascii_letters + string.digits
    self._validate_public_id_config(
        prefix=public_id_prefix, length=public_id_length
    )

    # One CRUD helper per collection, all in the same database.
    self.crud_task = CRUDTask(
        model=TaskMongoObject,
        collection=task_collection,
        database=mongo_database,
    )
    self.crud_task_relation = CRUDTaskRelation(
        model=TaskRelationMongoObject,
        collection=task_relation_collection,
        database=mongo_database,
    )
    self.crud_task_cost = CRUDTaskCost(
        model=TaskCostMongoObject,
        collection=task_cost_collection,
        database=mongo_database,
    )
|
|
257
|
+
|
|
258
|
+
def _validate_public_id_config(self, *, prefix: str, length: int) -> None:
|
|
259
|
+
if not isinstance(prefix, str):
|
|
260
|
+
raise keble_exceptions.ServerSideInvalidParams(
|
|
261
|
+
admin_note={"prefix": prefix},
|
|
262
|
+
alert_admin=True,
|
|
263
|
+
but_got="prefix is not str",
|
|
264
|
+
expected="prefix is str",
|
|
265
|
+
invalid_params="prefix",
|
|
266
|
+
)
|
|
267
|
+
if not isinstance(length, int):
|
|
268
|
+
raise keble_exceptions.ServerSideInvalidParams(
|
|
269
|
+
admin_note={"length": length},
|
|
270
|
+
alert_admin=True,
|
|
271
|
+
but_got="length is not int",
|
|
272
|
+
expected="length is int",
|
|
273
|
+
invalid_params="length",
|
|
274
|
+
)
|
|
275
|
+
if length <= len(prefix):
|
|
276
|
+
raise keble_exceptions.ServerSideInvalidParams(
|
|
277
|
+
admin_note={"prefix": prefix, "length": length},
|
|
278
|
+
alert_admin=True,
|
|
279
|
+
but_got="length <= len(prefix)",
|
|
280
|
+
expected="length > len(prefix)",
|
|
281
|
+
invalid_params="length,prefix",
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
async def _ensure_public_id_index(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the public-id index on the task collection.

    The index definition is owned by the task CRUD helper; this method only
    delegates to it.
    """
    task_crud = self.crud_task
    await task_crud.aensure_public_id_index(amongo)
|
|
286
|
+
|
|
287
|
+
async def _ensure_task_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the tree/stage read indexes on the task collection.

    Step by step:
    1. the index definitions live in the task CRUD helper, so delegate to it;
    2. this keeps index creation out of request handlers and worker loops;
    3. startup and package methods share this one idempotent setup.
    """
    task_crud = self.crud_task
    await task_crud.aensure_task_indexes(amongo)
|
|
297
|
+
|
|
298
|
+
def _generate_public_id_candidate(self, *, prefix: str, length: int) -> str:
|
|
299
|
+
suffix_len = length - len(prefix)
|
|
300
|
+
suffix = "".join(
|
|
301
|
+
secrets.choice(self._public_id_alphabet) for _ in range(suffix_len)
|
|
302
|
+
)
|
|
303
|
+
return f"{prefix}{suffix}"
|
|
304
|
+
|
|
305
|
+
async def _ensure_task_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    public_id_prefix: Optional[str],
    public_id_length: Optional[int],
) -> str:
    """Return the task's public id, assigning a freshly generated one if missing.

    Step by step:
    1. resolve prefix/length from the arguments, falling back to the client
       defaults, and validate them;
    2. ensure the public-id index, then return the stored public id when the
       task already has a non-empty string one;
    3. otherwise try up to 20 random candidates, writing each with a filter
       that requires the stored `public_id` to match `None`, and moving on
       to the next candidate when the write raises `DuplicateKeyError`;
    4. when the conditional write matches no row, re-read the task and return
       the public id stored there, if any.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when the task does not
            exist, when the write matched nothing and no public id is stored,
            or when all 20 candidates hit a duplicate key.
    """
    prefix = self._public_id_prefix if public_id_prefix is None else public_id_prefix
    length = self._public_id_length if public_id_length is None else public_id_length
    self._validate_public_id_config(prefix=prefix, length=length)
    await self._ensure_public_id_index(amongo=amongo)

    tasks = amongo[self._mongo_database][self._task_collection]
    stored = await tasks.find_one({"_id": task_id}, {"public_id": 1})
    if stored is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task not found",
            expected="task exists",
            invalid_params="task_id",
        )

    current_public_id = stored.get("public_id")
    if isinstance(current_public_id, str) and current_public_id:
        # Already assigned: never overwrite an existing public id.
        return str(current_public_id)

    attempts_left = 20
    while attempts_left > 0:
        attempts_left -= 1
        candidate = self._generate_public_id_candidate(prefix=prefix, length=length)
        try:
            write = await tasks.update_one(
                {"_id": task_id, "public_id": None},
                {"$set": {"public_id": candidate}},
            )
        except DuplicateKeyError:
            # Candidate rejected as a duplicate key; try another one.
            continue

        if write.matched_count != 0:
            return candidate

        # The filter matched nothing: re-read to see whether a public id is
        # now stored on the task.
        latest = await tasks.find_one({"_id": task_id}, {"public_id": 1})
        if latest and latest.get("public_id"):
            return str(latest["public_id"])
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id, "candidate": candidate},
            alert_admin=True,
            but_got="task public_id not updated",
            expected="task public_id updated",
            invalid_params="task_id,public_id",
        )

    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={"task_id": task_id, "prefix": prefix, "length": length},
        alert_admin=True,
        but_got="exceeded retries to generate unique public_id",
        expected="unique public_id generated",
        invalid_params="public_id",
    )
|
|
367
|
+
|
|
368
|
+
async def _ensure_task_relation_indexes(
    self, *, amongo: AsyncIOMotorClient
) -> None:
    """Ensure the query indexes on the relation collection.

    Step by step:
    1. the relation CRUD helper owns the index definitions, so delegate to it;
    2. the method stays private so callers go through relation APIs rather
       than storage details.
    """
    relation_crud = self.crud_task_relation
    await relation_crud.aensure_relation_indexes(amongo)
|
|
379
|
+
|
|
380
|
+
async def _ensure_task_cost_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the query indexes on the cost collection.

    Step by step:
    1. the task-cost CRUD helper owns the index definitions, so delegate to it;
    2. this keeps index creation outside request handlers and worker loops;
    3. startup and package methods share this one idempotent setup.
    """
    cost_crud = self.crud_task_cost
    await cost_crud.aensure_cost_indexes(amongo)
|
|
390
|
+
|
|
391
|
+
async def aensure_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Create every Mongo index owned by the task package, one after another.

    Step by step:
    1. the public task id index used by shared task URLs;
    2. the task tree/stage indexes for root, child, and retry reads;
    3. the relation indexes for root-scoped graph lookups;
    4. the cost indexes for admin list and aggregate reports.
    """
    # Run sequentially, in this order.
    ensure_steps = (
        self._ensure_public_id_index,
        self._ensure_task_indexes,
        self._ensure_task_relation_indexes,
        self._ensure_task_cost_indexes,
    )
    for ensure in ensure_steps:
        await ensure(amongo=amongo)
|
|
405
|
+
|
|
406
|
+
async def ainit(self, *, amongo: AsyncIOMotorClient) -> None:
    """Public startup hook: initialize the task package's Mongo indexes.

    Step by step:
    1. give backend lifespan wiring a single entry point to call;
    2. hand off to `aensure_indexes`, which does the actual work.
    """
    ensure_all = self.aensure_indexes
    await ensure_all(amongo=amongo)
|
|
415
|
+
|
|
416
|
+
def _task_workspace_root_id(self, task: TaskMongoObject) -> ObjectId:
    """Return the root-task id that scopes workspace-local relation rows.

    Step by step:
    1. use the persisted `root_task` field when it is set (normal modern rows);
    2. otherwise use the task's own id (old root rows that predate root
       backfill).
    """
    persisted_root = task.root_task
    if persisted_root:
        return persisted_root
    return task.id
|
|
425
|
+
|
|
426
|
+
async def _aget_relation_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    field_name: str,
) -> TaskMongoObject:
    """Load the task a relation row points at, or fail with a typed error.

    Step by step:
    1. fetch the task by id with no task-type filter;
    2. hand back the typed task row when it exists;
    3. otherwise raise the package's invalid-params error, naming the
       relation field (`field_name`) whose reference is stale.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when no task has `task_id`.
    """
    found = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if found is not None:
        return found

    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={field_name: task_id},
        alert_admin=True,
        but_got=f"{field_name} task not found",
        expected=f"{field_name} task exists",
        invalid_params=field_name,
    )
|
|
451
|
+
|
|
452
|
+
def _validate_relation_task_root(
    self,
    *,
    relation: TaskRelationCreate,
    task: TaskMongoObject,
    field_name: str,
) -> None:
    """Check that one relation endpoint lives under the relation's root.

    Step by step:
    1. resolve the endpoint task's canonical workspace root;
    2. compare it with `relation.root_task`;
    3. raise a typed error on a mismatch, i.e. when a caller attempts
       cross-workspace lineage.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when the endpoint belongs
            to a different root task.
    """
    endpoint_root = self._task_workspace_root_id(task)
    if endpoint_root != relation.root_task:
        # Record the declared root, the endpoint, and the endpoint's real root.
        mismatch_note = {
            "root_task": relation.root_task,
            field_name: task.id,
            f"{field_name}_root_task": endpoint_root,
        }
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note=mismatch_note,
            alert_admin=True,
            but_got=f"{field_name} belongs to a different root task",
            expected=f"{field_name} belongs to relation.root_task",
            invalid_params=f"root_task,{field_name}",
        )
|
|
480
|
+
|
|
481
|
+
@staticmethod
def _validate_unique_task_relation_edges(
    *,
    relations: list[TaskRelationCreate],
    invalid_params: str,
) -> None:
    """Fail fast when the payload repeats a relation edge; nothing is written.

    Step by step:
    1. key each relation by (root, from, to, relation type) — the edge
       identity that, per the package's design, the Mongo unique index
       protects;
    2. record every repeated key without touching task or relation
       collections;
    3. raise a caller-facing validation error listing the repeats, instead of
       relying on a partial insert followed by a database duplicate-key
       failure.

    Raises:
        keble_exceptions.ClientSideInvalidParams: when any edge key repeats.
    """
    known_keys: set[tuple[ObjectId, ObjectId, ObjectId, str]] = set()
    repeated: list[str] = []
    for rel in relations:
        key = (
            rel.root_task,
            rel.from_task_id,
            rel.to_task_id,
            rel.relation_type.value,
        )
        if key not in known_keys:
            known_keys.add(key)
            continue
        root_id, from_id, to_id, relation_kind = key
        repeated.append(
            f"root_task={root_id};from_task_id={from_id};"
            f"to_task_id={to_id};relation_type={relation_kind}"
        )

    if not repeated:
        return
    raise keble_exceptions.ClientSideInvalidParams(
        invalid_params=invalid_params,
        expected="unique relation edge keys",
        but_got="; ".join(repeated),
        alert_admin=False,
    )
|
|
521
|
+
|
|
522
|
+
@staticmethod
def _validate_unique_create_related_source_ids(
    *,
    from_task_ids: list[ObjectId],
) -> None:
    """Fail fast when an action lists the same source task id more than once.

    Step by step:
    1. walk the source ids that will become relation edges;
    2. collect every repeated id, so the caller can stop before the child
       task is inserted;
    3. return normally when each source id appears once.

    Raises:
        keble_exceptions.ClientSideInvalidParams: when any source id repeats.
    """
    known_ids: set[ObjectId] = set()
    repeated_ids: list[str] = []
    for source_id in from_task_ids:
        if source_id not in known_ids:
            known_ids.add(source_id)
        else:
            repeated_ids.append(str(source_id))

    if not repeated_ids:
        return
    raise keble_exceptions.ClientSideInvalidParams(
        invalid_params="actions[].from_task_ids",
        expected="unique source task ids per created relation edge",
        but_got=",".join(repeated_ids),
        alert_admin=False,
    )
|
|
549
|
+
|
|
550
|
+
async def _avalidate_task_relation_create(
    self,
    *,
    amongo: AsyncIOMotorClient,
    relation: TaskRelationCreate,
) -> None:
    """Validate one task-relation create payload; nothing is inserted here.

    Step by step:
    1. load the declared root, the source task, and the derived task, in
       that order;
    2. right after each load, check that the task's workspace root equals
       `relation.root_task` — for the root itself this proves it is the
       canonical workspace root;
    3. stop at the first missing task or root mismatch.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when a referenced task is
            missing or belongs to a different root.
    """
    endpoints = (
        ("root_task", relation.root_task),
        ("from_task_id", relation.from_task_id),
        ("to_task_id", relation.to_task_id),
    )
    for field_name, endpoint_id in endpoints:
        endpoint_task = await self._aget_relation_task(
            amongo=amongo,
            task_id=endpoint_id,
            field_name=field_name,
        )
        self._validate_relation_task_root(
            relation=relation,
            task=endpoint_task,
            field_name=field_name,
        )
|
|
594
|
+
|
|
595
|
+
async def _execute_token_consumption_handler(
|
|
596
|
+
self, payload: TokenConsumptionPayload
|
|
597
|
+
) -> bool:
|
|
598
|
+
"""Execute token consumption handler supporting both sync and async implementations.
|
|
599
|
+
|
|
600
|
+
Args:
|
|
601
|
+
payload: The token consumption payload
|
|
602
|
+
|
|
603
|
+
Returns:
|
|
604
|
+
bool: True if successful, False otherwise
|
|
605
|
+
"""
|
|
606
|
+
result = self._token_consumption_handler(payload)
|
|
607
|
+
if asyncio.iscoroutine(result):
|
|
608
|
+
# If it's a coroutine, await it
|
|
609
|
+
return await cast(Awaitable[bool], result)
|
|
610
|
+
# If it's a boolean, return it directly
|
|
611
|
+
return cast(bool, result)
|
|
612
|
+
|
|
613
|
+
#
|
|
614
|
+
#
|
|
615
|
+
#
|
|
616
|
+
#
|
|
617
|
+
# Core running Task API
|
|
618
|
+
#
|
|
619
|
+
#
|
|
620
|
+
#
|
|
621
|
+
#
|
|
622
|
+
# Create
|
|
623
|
+
async def acreate(
    self,
    *,
    amongo: AsyncIOMotorClient,
    expected_token: int,
    owner: str,
    progress_key: Optional[str] = None,
    image: Optional[str] = None,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    metadata: Optional[TaskMetadata] = None,
    language: Optional[Language] = None,
    task_type: str,
    timeout_mins: int = 120,
    sharing_scope: SharingScope = SharingScope.PRIVATE,
    stage: TaskStage = TaskStage.PENDING,
    attempts: int = 0,
    consumed_token: int = 0,
    # task tree fields
    root_task: Optional[ObjectId] = None,
    parent_task: Optional[ObjectId] = None,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
) -> TaskMongoObject:
    """Create and persist one task row, then return it reloaded from Mongo.

    Step by step:
    1. reject a `root_task` passed without a `parent_task`;
    2. when only `parent_task` is given, load the parent and inherit its root
       (`parent.root_task`, else the parent's own id);
    3. for PUBLIC tasks, validate the public-id prefix/length before any
       write happens;
    4. insert the row; a task created with neither parent nor root is then
       updated so `root_task` points at its own id;
    5. for PUBLIC tasks, assign a public id; finally reload and return the
       stored task.

    Raises:
        keble_exceptions.ServerSideInvalidParams: on an inconsistent
            root/parent pair, a missing parent task, or invalid public-id
            settings.
    """
    creating_root = parent_task is None
    if creating_root and root_task is not None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"root_task": root_task, "parent_task": parent_task},
            alert_admin=True,
            but_got="root_task provided for root task creation",
            expected="root_task is None when parent_task is None",
            invalid_params="root_task,parent_task",
        )

    if not creating_root and root_task is None:
        # Child task without an explicit root: inherit it from the parent.
        parent = await self.aget(amongo=amongo, task_id=parent_task, task_type=None)
        if parent is None:
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={"parent_task": parent_task},
                alert_admin=True,
                but_got="parent task not found",
                expected="parent task exists",
                invalid_params="parent_task",
            )
        root_task = parent.root_task or parent.id

    is_public = sharing_scope == SharingScope.PUBLIC
    if is_public:
        # Validate up front so invalid settings fail before the insert.
        prefix = (
            self._public_id_prefix if public_id_prefix is None else public_id_prefix
        )
        length = (
            self._public_id_length if public_id_length is None else public_id_length
        )
        self._validate_public_id_config(prefix=prefix, length=length)

    new_row = TaskBase(
        progress_key=progress_key,
        image=image,
        title=title,
        subtitle=subtitle,
        metadata=metadata,
        language=language,
        owner=str(owner),
        consumed_token=consumed_token,
        expected_token=expected_token,
        attempts=attempts,
        sharing_scope=sharing_scope,
        stage=stage,
        task_type=task_type,
        timeout_mins=timeout_mins,
        root_task=root_task,
        parent_task=parent_task,
    )
    inserted = await self.crud_task.acreate(
        amongo,
        obj_in=cast(TaskMongoObject, new_row),
    )
    new_id = inserted.inserted_id

    if creating_root and root_task is None:
        # The id only exists after the insert, so a root task gets its
        # `root_task` set to itself in a follow-up update.
        await self.crud_task.aupdate(amongo, _id=new_id, obj_in={"root_task": new_id})

    if is_public:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=new_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )

    # Reload so the returned object reflects the follow-up updates above.
    task = await self.crud_task.afirst_by_id(amongo, _id=new_id)
    assert task is not None
    return task
|
|
718
|
+
|
|
719
|
+
async def acreate_task_relation(
    self,
    *,
    amongo: AsyncIOMotorClient,
    obj_in: TaskRelationCreate,
) -> TaskRelationMongoObject:
    """Create a single extra lineage relation; task parentage is untouched.

    Step by step:
    1. wrap the payload in a one-element batch and delegate to
       `acreate_task_relations`, which validates and inserts it;
    2. return the first persisted relation of that batch for downstream API
       responses.
    """
    single_batch = [obj_in]
    created = await self.acreate_task_relations(amongo=amongo, objs_in=single_batch)
    return created[0]
|
|
735
|
+
|
|
736
|
+
async def acreate_task_relations(
    self,
    *,
    amongo: AsyncIOMotorClient,
    objs_in: list[TaskRelationCreate],
) -> list[TaskRelationMongoObject]:
    """Create several extra lineage relations, one row per relation.

    Step by step:
    1. return an empty list right away for an empty batch;
    2. reject duplicate edges inside the batch, then ensure the relation
       indexes exist;
    3. validate every relation before inserting any row;
    4. insert the relations one at a time, reloading each persisted row,
       without touching `root_task` or `parent_task` on the tasks.

    Raises:
        keble_exceptions.ClientSideInvalidParams: when the batch repeats an
            edge.
        keble_exceptions.ServerSideInvalidParams: when a referenced task is
            missing or belongs to a different root.
    """
    if len(objs_in) == 0:
        return []

    self._validate_unique_task_relation_edges(
        relations=objs_in,
        invalid_params="objs_in",
    )
    await self._ensure_task_relation_indexes(amongo=amongo)

    # Validation pass: nothing is inserted until every relation has passed.
    for pending in objs_in:
        await self._avalidate_task_relation_create(
            amongo=amongo,
            relation=pending,
        )

    # Insert pass: one write plus one reload per relation, in input order.
    relation_crud = self.crud_task_relation
    persisted: list[TaskRelationMongoObject] = []
    for pending in objs_in:
        inserted = await relation_crud.acreate_relation(
            amongo,
            obj_in=pending,
        )
        stored = await relation_crud.afirst_by_id(
            amongo,
            _id=inserted.inserted_id,
        )
        assert stored is not None
        persisted.append(stored)
    return persisted
|
|
777
|
+
|
|
778
|
+
async def acreate_task_cost(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    tags: list[str] | None = None,
    run_usage: RunUsage,
    source: UsageAccountingSource | None = None,
    token_rates_per_million: TaskCostTokenRates | None = None,
    additional_cost: Money | None = None,
    seconds: float = 0,
    retry: int = 0,
    occurred_at: datetime | None = None,
    metadata: TaskCostMetadata | None = None,
) -> TaskCostMongoObject:
    """Record one cost row for an existing task.

    Step by step:
    1. load the task by id (no task-type filter) and fail when it is missing;
    2. build the create schema via `TaskCostCreate.from_task_usage(...)`,
       passing the loaded task and its workspace root id along with the
       caller-provided usage, rates, timing, and metadata;
    3. insert the row through `crud_task_cost` and reload it by inserted id.

    Raises:
        keble_exceptions.ServerSideInvalidParams: no task exists for `task_id`.
    """

    task = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if task is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task not found",
            expected="task exists",
            invalid_params="task_id",
        )

    cost_in = TaskCostCreate.from_task_usage(
        task=task,
        root_task=self._task_workspace_root_id(task),
        tags=tags,
        run_usage=run_usage,
        source=source,
        token_rates_per_million=token_rates_per_million,
        additional_cost=additional_cost,
        seconds=seconds,
        retry=retry,
        occurred_at=occurred_at,
        metadata=metadata,
    )
    insert_result = await self.crud_task_cost.acreate_cost(amongo, obj_in=cost_in)

    # Reload so the caller receives the stored, typed representation.
    stored_cost = await self.crud_task_cost.afirst_by_id(
        amongo,
        _id=insert_result.inserted_id,
    )
    assert stored_cost is not None
    return stored_cost
|
|
832
|
+
|
|
833
|
+
async def alist_task_costs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    query: TaskCostListRequest,
) -> TaskCostListResponse:
    """Return one page of task-cost rows plus the total for the same query.

    Step by step:
    1. fetch the requested page through `crud_task_cost.alist_costs`;
    2. count rows for the same `query` through `crud_task_cost.acount_costs`;
    3. echo `query.skip` / `query.limit` back in the response.

    No index creation happens here.
    """

    page = await self.crud_task_cost.alist_costs(amongo, query=query)
    matching_total = await self.crud_task_cost.acount_costs(amongo, query=query)
    response = TaskCostListResponse(
        costs=page,
        total=matching_total,
        skip=query.skip,
        limit=query.limit,
    )
    return response
|
|
855
|
+
|
|
856
|
+
async def aaggregate_task_costs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    query: TaskCostAggregateRequest,
    exchange_rates: list[ExchangeRateInUsd],
) -> TaskCostAggregateResponse:
    """Load the cost rows for `query` and hand them to the response builder.

    Step by step:
    1. load rows via `crud_task_cost.alist_costs_for_aggregate`;
    2. delegate grouping and currency conversion to
       `TaskCostAggregateResponse.build`, passing the rows, the original
       query, and the caller-supplied exchange rates.
    """

    rows = await self.crud_task_cost.alist_costs_for_aggregate(
        amongo,
        query=query,
    )
    aggregated = TaskCostAggregateResponse.build(
        costs=rows,
        query=query,
        exchange_rates=exchange_rates,
    )
    return aggregated
|
|
880
|
+
|
|
881
|
+
def _validate_task_action_root_context(
    self,
    *,
    current_task: TaskMongoObject,
    ux_context: TaskUxContext | None,
) -> None:
    """Reject a UI-selected root that differs from the current task's root.

    Step by step:
    1. do nothing when there is no UX context or it carries no selected root;
    2. resolve the current task's workspace root id;
    3. raise when the selected root differs from the resolved one.

    Raises:
        keble_exceptions.ClientSideInvalidParams: the selected root id does not
            match the current task's workspace root id.
    """

    selected_root_id = (
        None if ux_context is None else ux_context.selected_root_task_id
    )
    if selected_root_id is None:
        return

    canonical_root_id = self._task_workspace_root_id(current_task)
    if selected_root_id != canonical_root_id:
        raise keble_exceptions.ClientSideInvalidParams(
            invalid_params="ux_context.selected_root_task_id",
            expected=str(canonical_root_id),
            but_got=str(selected_root_id),
            alert_admin=False,
        )
|
|
905
|
+
|
|
906
|
+
async def _aprepare_create_related_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: CreateRelatedTaskAction,
) -> TaskMongoObject:
    """Resolve and load the parent task for a create-related-task action.

    Step by step:
    1. use `action.parent_task_id` when set, otherwise the current task's id;
    2. load that task through `_aget_relation_task`, reporting problems
       against the `parent_task_id` field;
    3. return the loaded task as-is.
    """

    requested_parent_id = action.parent_task_id
    resolved_parent_id = (
        requested_parent_id if requested_parent_id else current_task.id
    )
    parent = await self._aget_relation_task(
        amongo=amongo,
        task_id=resolved_parent_id,
        field_name="parent_task_id",
    )
    return parent
|
|
927
|
+
|
|
928
|
+
async def _acomplete_created_task_if_needed(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis | None,
    created_task: TaskMongoObject,
    complete_created_task: bool,
    consuming_token: int,
) -> TaskMongoObject:
    """Finalize a freshly created task as successful when asked to.

    Step by step:
    1. when completion is requested, require `extended_aredis`;
    2. delegate to `aon_task_success` with the given `consuming_token`;
    3. otherwise hand the created task back untouched.

    Raises:
        keble_exceptions.ServerSideMissingParams: completion was requested but
            `extended_aredis` is None.
    """

    if complete_created_task:
        if extended_aredis is None:
            raise keble_exceptions.ServerSideMissingParams(
                missing_params="extended_aredis",
                alert_admin=True,
            )
        finalized = await self.aon_task_success(
            amongo=amongo,
            extended_aredis=extended_aredis,
            task=created_task,
            consuming_token=consuming_token,
        )
        return finalized

    return created_task
|
|
958
|
+
|
|
959
|
+
async def _arollback_created_related_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    created_task_id: ObjectId,
) -> None:
    """Undo the writes of a create-related-task action that failed midway.

    Step by step:
    1. delete every relation row whose `to_task_id` is the created task;
    2. delete the created task row itself.

    The caller is responsible for re-raising the failure that triggered the
    rollback. This compensation exists because the child insert and the
    relation inserts are separate writes rather than one atomic transaction.
    """

    # Relations first, so no relation row is left pointing at a missing task.
    relations_to_child = QueryBase(filters={"to_task_id": created_task_id})
    await self.crud_task_relation.adelete_multi(
        amongo,
        query=relations_to_child,
    )
    await self.crud_task.adelete(amongo, _id=created_task_id)
|
|
982
|
+
|
|
983
|
+
async def _aapply_create_related_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: CreateRelatedTaskAction,
    extended_aredis: ExtendedAsyncRedis | None,
) -> CreateRelatedTaskActionedResult:
    """Apply one generic create-related-task action.

    Step by step:
    1. resolve parent/root context;
    2. validate every source task belongs to the same root room;
    3. reject requests that could not be finalized, before writing anything;
    4. create the child task from generic task fields only;
    5. create relation rows from source tasks to the created task, rolling
       the child back when that write fails;
    6. optionally mark the child task successful for lightweight generic work.

    Raises:
        keble_exceptions.ClientSideInvalidParams: a source task resolves to a
            different workspace root than the parent task.
        keble_exceptions.ClientSideMissingParams: `action.task_type` is None.
        keble_exceptions.ServerSideMissingParams: `action.complete_created_task`
            is set but `extended_aredis` is None.
    """

    parent_task = await self._aprepare_create_related_task_action(
        amongo=amongo,
        current_task=current_task,
        action=action,
    )
    # Default to the current task as the single relation source.
    from_task_ids = action.from_task_ids or [current_task.id]
    self._validate_unique_create_related_source_ids(from_task_ids=from_task_ids)
    parent_root_id = self._task_workspace_root_id(parent_task)
    for from_task_id in from_task_ids:
        from_task = await self._aget_relation_task(
            amongo=amongo,
            task_id=from_task_id,
            field_name="from_task_ids",
        )
        from_root_id = self._task_workspace_root_id(from_task)
        if from_root_id != parent_root_id:
            raise keble_exceptions.ClientSideInvalidParams(
                invalid_params="from_task_ids",
                expected=str(parent_root_id),
                but_got=f"from_task_id={from_task_id}; from_task_root_id={from_root_id}",
                alert_admin=False,
            )
    if action.task_type is None:
        raise keble_exceptions.ClientSideMissingParams(
            missing_params="actions[].task_type",
            alert_admin=False,
        )
    # Fail fast: finalization needs Redis. Raising the same error here, before
    # any write, avoids persisting a child task and its relations and only
    # then discovering that the requested completion cannot run.
    if action.complete_created_task and extended_aredis is None:
        raise keble_exceptions.ServerSideMissingParams(
            missing_params="extended_aredis",
            alert_admin=True,
        )
    created_task = await self.acreate(
        amongo=amongo,
        expected_token=action.expected_token,
        owner=current_task.owner,
        progress_key=action.progress_key,
        image=action.image,
        title=action.title,
        subtitle=action.subtitle,
        metadata=action.metadata,
        task_type=action.task_type,
        timeout_mins=action.timeout_mins,
        sharing_scope=action.sharing_scope or current_task.sharing_scope,
        parent_task=parent_task.id,
    )
    # Saga compensation: standalone Mongo has no multi-document transactions, so the
    # child task and its relation rows are written as separate operations. If any
    # relation write fails (a concurrent duplicate hitting the unique edge index, or
    # an infra error mid-loop), roll back the just-created child and any partial
    # relation rows pointing at it, then re-raise unchanged so the workspace never
    # keeps an orphaned child without its intended relations.
    try:
        relations = await self.acreate_task_relations(
            amongo=amongo,
            objs_in=[
                TaskRelationCreate(
                    root_task=parent_root_id,
                    from_task_id=from_task_id,
                    to_task_id=created_task.id,
                    relation_type=action.relation_type,
                    metadata=action.metadata,
                )
                for from_task_id in from_task_ids
            ],
        )
    except Exception:
        try:
            await self._arollback_created_related_task(
                amongo=amongo, created_task_id=created_task.id
            )
        except Exception:
            # A failing rollback must not replace the relation-write error the
            # caller needs to see; record it and fall through to the re-raise.
            logger.exception(
                "[Task] Failed to roll back created related task %s",
                created_task.id,
            )
        raise
    finalized_task = await self._acomplete_created_task_if_needed(
        amongo=amongo,
        extended_aredis=extended_aredis,
        created_task=created_task,
        complete_created_task=action.complete_created_task,
        consuming_token=action.consuming_token,
    )
    return CreateRelatedTaskActionedResult(
        created_task=TaskActionCreatedTask.build(task=finalized_task),
        created_relations=[
            TaskActionCreatedRelation.build(relation=relation)
            for relation in relations
        ],
    )
|
|
1081
|
+
|
|
1082
|
+
async def _aapply_one_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: TaskAction,
    extended_aredis: ExtendedAsyncRedis | None,
) -> tuple[CreateRelatedTaskActionedResult, TaskActionEvent]:
    """Run a single task action and build the event describing it.

    Both `aapply_action` and `aapply_actions` go through this method, so the
    action handling and the event layout are defined in one place.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the action type is anything
            other than `TaskActionType.CREATE_RELATED_TASK`.
    """

    # Only CREATE_RELATED_TASK is handled; any other type is rejected loudly
    # instead of being ignored.
    if action.action_type is not TaskActionType.CREATE_RELATED_TASK:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"action_type": str(action.action_type)},
            alert_admin=True,
            but_got=f"unsupported task action type {action.action_type}",
            expected="a supported TaskActionType (CREATE_RELATED_TASK)",
            invalid_params="action_type",
        )

    outcome = await self._aapply_create_related_task_action(
        amongo=amongo,
        current_task=current_task,
        action=action,
        extended_aredis=extended_aredis,
    )

    new_task = outcome.created_task
    # Fall back to the created task's own id when it has no root id set.
    event_root = new_task.root_task_id or new_task.task_id
    action_event = TaskActionEvent(
        action_type=outcome.action_type.value,
        payload=outcome,
        root_id=str(event_root),
        object_id=str(new_task.task_id),
        correlation_id=new_task.progress_key,
    )
    return outcome, action_event
|
|
1123
|
+
|
|
1124
|
+
async def aapply_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: TaskAction,
    ux_context: TaskUxContext | None = None,
    event_emitter: AgenticEventEmitter | None = None,
    extended_aredis: ExtendedAsyncRedis | None = None,
) -> CreateRelatedTaskActionedResult:
    """Apply exactly one task action (the one-action-per-call agent surface).

    A caller that needs several actions invokes this method several times;
    the method deliberately accepts a single action rather than a list.

    Step by step:
    1. validate the optional UI root context against the current task;
    2. run the action through `_aapply_one_task_action`;
    3. when an emitter is given, emit the event and drain the emitter;
    4. return the action result.
    """

    self._validate_task_action_root_context(
        current_task=current_task,
        ux_context=ux_context,
    )
    outcome, action_event = await self._aapply_one_task_action(
        amongo=amongo,
        current_task=current_task,
        action=action,
        extended_aredis=extended_aredis,
    )
    if event_emitter is None:
        return outcome

    await event_emitter.aemit(action_event)
    # Drain before returning (as the batch path does) so pending emitter
    # work is finished by the time the caller gets the result.
    await event_emitter.adrain()
    return outcome
|
|
1165
|
+
|
|
1166
|
+
async def aapply_actions(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    payload: TaskActions,
    ux_context: TaskUxContext | None = None,
    event_emitter: AgenticEventEmitter | None = None,
    extended_aredis: ExtendedAsyncRedis | None = None,
) -> TaskActionedResults:
    """Apply a batch of task actions sequentially, in the order given.

    This is the batch counterpart of `aapply_action`; both run every action
    through `_aapply_one_task_action`.

    Step by step:
    1. validate the optional UI root context against the current task;
    2. run each action of `payload.actions` in order;
    3. emit each action's event right after it ran, when an emitter is given;
    4. drain the emitter once after the whole batch;
    5. return the ordered results together with their events.
    """

    self._validate_task_action_root_context(
        current_task=current_task,
        ux_context=ux_context,
    )

    has_emitter = event_emitter is not None
    applied: list[tuple[CreateRelatedTaskActionedResult, TaskActionEvent]] = []
    for one_action in payload.actions:
        outcome, action_event = await self._aapply_one_task_action(
            amongo=amongo,
            current_task=current_task,
            action=one_action,
            extended_aredis=extended_aredis,
        )
        applied.append((outcome, action_event))
        if has_emitter:
            await event_emitter.aemit(action_event)

    # One drain at the end of the batch: wait for the emitter's outstanding
    # work so the caller does not return while it is still in flight.
    if has_emitter:
        await event_emitter.adrain()

    return TaskActionedResults(
        actioned_results=[outcome for outcome, _ in applied],
        events=[action_event for _, action_event in applied],
    )
|
|
1216
|
+
|
|
1217
|
+
#
|
|
1218
|
+
#
|
|
1219
|
+
#
|
|
1220
|
+
#
|
|
1221
|
+
# Read
|
|
1222
|
+
#
|
|
1223
|
+
#
|
|
1224
|
+
#
|
|
1225
|
+
#
|
|
1226
|
+
@overload
async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by id, optionally expanded into its tree node.

    The lookup filters on `_id` and, when `task_type` is given, on
    `task_type` as well. With `include_childs` set and a task found, the
    result comes from `_aget_task_tree_node`, which receives the same
    optional `task_type` restriction as its base filter.

    Returns None when no task matches.
    """
    # Restriction shared by the direct lookup and the tree expansion.
    type_scope: dict = {} if task_type is None else {"task_type": task_type}

    lookup: dict = {"_id": task_id, **type_scope}
    found = await self.crud_task.afirst(amongo, query=QueryBase(filters=lookup))
    if found is None:
        return None
    if not include_childs:
        return found

    return await self._aget_task_tree_node(
        amongo=amongo,
        task=found,
        target_task_id=task_id,
        base_filter=type_scope,
    )
|
|
1273
|
+
|
|
1274
|
+
@overload
async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by id restricted to a given owner.

    The lookup always filters on `_id` and `owner`; `task_type` and
    `sharing_scope` are added only when they are not None. With
    `include_childs` set and a task found, the result comes from
    `_aget_task_tree_node`, using the owner/type/scope restrictions
    (without `_id`) as its base filter.

    Returns None when no task matches.
    """
    # Restrictions shared by the direct lookup and the tree expansion.
    owner_scope: dict = {"owner": owner}
    if task_type is not None:
        owner_scope["task_type"] = task_type
    if sharing_scope is not None:
        owner_scope["sharing_scope"] = sharing_scope

    lookup: dict = {"_id": task_id, **owner_scope}
    found = await self.crud_task.afirst(amongo, query=QueryBase(filters=lookup))
    if found is None:
        return None
    if not include_childs:
        return found

    return await self._aget_task_tree_node(
        amongo=amongo,
        task=found,
        target_task_id=task_id,
        base_filter=owner_scope,
    )
|
|
1329
|
+
|
|
1330
|
+
async def aresolve_task_room(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
) -> TaskRoomResolution:
    """Map an owner's task to the root-task room it belongs to.

    Step by step:
    1. load the task through `aowner_get` with no type or scope restriction;
    2. resolve its workspace root id from the loaded row;
    3. return the root id together with the requested task's id.

    Raises:
        keble_exceptions.NoObjectPermission: no task with this id is found
            for the given owner.
    """

    owned_task = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        task_id=task_id,
        task_type=None,
        sharing_scope=None,
    )
    if owned_task is None:
        raise keble_exceptions.NoObjectPermission(
            object_id=str(task_id),
            object_type="task",
        )

    room_root_id = self._task_workspace_root_id(owned_task)
    return TaskRoomResolution(
        root_task_id=room_root_id,
        requested_task_id=owned_task.id,
    )
|
|
1361
|
+
|
|
1362
|
+
async def aget_task_room_graph_context(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    root_task_id: ObjectId,
    focused_task_id: ObjectId | None = None,
) -> TaskRoomGraphContext:
    """Build the task-graph context of one root-task room for an owner.

    Step by step:
    1. load the root task with its children through `aowner_get`;
    2. verify the loaded task is its own workspace root;
    3. default the focused task to the root; when another task is focused,
       load it for the same owner and verify it resolves to this root;
    4. list the relation rows of the root and assemble the context.

    Raises:
        keble_exceptions.NoObjectPermission: the root task or the focused task
            is not found for the given owner.
        keble_exceptions.ServerSideInvalidParams: `root_task_id` is not a root
            task, or the focused task resolves to a different root.
    """

    room_root = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        task_id=root_task_id,
        task_type=None,
        sharing_scope=None,
        include_childs=True,
    )
    if room_root is None:
        raise keble_exceptions.NoObjectPermission(
            object_id=str(root_task_id),
            object_type="task",
        )

    resolved_root_id = self._task_workspace_root_id(room_root)
    if resolved_root_id != room_root.id:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={
                "root_task_id": root_task_id,
                "resolved_root_task_id": resolved_root_id,
            },
            alert_admin=True,
            but_got="requested task is not a root task",
            expected="root_task_id points to the root task room",
            invalid_params="root_task_id",
        )

    focus_id = focused_task_id if focused_task_id else room_root.id
    # The root itself needs no further check; any other focus must be
    # loadable by this owner and must resolve to this root.
    if focus_id != room_root.id:
        focus_task = await self.aowner_get(
            amongo=amongo,
            owner=owner,
            task_id=focus_id,
            task_type=None,
            sharing_scope=None,
        )
        if focus_task is None:
            raise keble_exceptions.NoObjectPermission(
                object_id=str(focus_id),
                object_type="task",
            )
        focus_root_id = self._task_workspace_root_id(focus_task)
        if focus_root_id != room_root.id:
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={
                    "root_task_id": room_root.id,
                    "focused_task_id": focus_id,
                    "focused_root_task_id": focus_root_id,
                },
                alert_admin=True,
                but_got="focused task belongs to another root task room",
                expected="focused task belongs to root_task_id",
                invalid_params="focused_task_id",
            )

    room_relations = await self.alist_task_relations_by_root(
        amongo=amongo,
        root_task=room_root.id,
    )
    return TaskRoomGraphContext.build(
        root_task=room_root,
        focused_task_id=focus_id,
        relations=room_relations,
    )
|
|
1441
|
+
|
|
1442
|
+
@overload
async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by id, restricted to `SharingScope.PUBLIC`.

    The lookup filters on `_id` and the public sharing scope, plus
    `task_type` when it is given. With `include_childs` set and a task
    found, the result comes from `_aget_task_tree_node`, using the public
    scope (and optional type) as its base filter.

    Returns None when no task matches.
    """
    # Restrictions shared by the direct lookup and the tree expansion.
    public_scope: dict = {"sharing_scope": SharingScope.PUBLIC}
    if task_type is not None:
        public_scope["task_type"] = task_type

    lookup: dict = {"_id": task_id, **public_scope}
    found = await self.crud_task.afirst(amongo, query=QueryBase(filters=lookup))
    if found is None:
        return None
    if not include_childs:
        return found

    return await self._aget_task_tree_node(
        amongo=amongo,
        task=found,
        target_task_id=task_id,
        base_filter=public_scope,
    )
|
|
1487
|
+
|
|
1488
|
+
@overload
async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by `public_id`, restricted to `SharingScope.PUBLIC`.

    The lookup filters on `public_id` and the public sharing scope, plus
    `task_type` when it is given. With `include_childs` set and a task
    found, the result comes from `_aget_task_tree_node`, targeting the found
    task's own id and using the public scope (and optional type) as its
    base filter.

    Returns None when no task matches.
    """
    # Restrictions shared by the direct lookup and the tree expansion.
    public_scope: dict = {"sharing_scope": SharingScope.PUBLIC}
    if task_type is not None:
        public_scope["task_type"] = task_type

    lookup: dict = {"public_id": public_id, **public_scope}
    found = await self.crud_task.afirst(amongo, query=QueryBase(filters=lookup))
    if found is None:
        return None
    if not include_childs:
        return found

    return await self._aget_task_tree_node(
        amongo=amongo,
        task=found,
        target_task_id=found.id,
        base_filter=public_scope,
    )
|
|
1533
|
+
|
|
1534
|
+
async def _aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: Optional[int],
    limit: Optional[int],
    filter_: dict,
) -> List[TaskMongoObject]:
    """List tasks for a filter, newest first, handling timed-out ones on read.

    Step by step:
    1. query tasks matching `filter_`, ordered by `created` descending, with
       `skip` / `limit` applied as offset / limit;
    2. for every task flagged `unfinshed_timeout`: restart it via `astart`
       when `allow_to_retry` is set, otherwise mark it through
       `aon_task_timeout`;
    3. reload each handled task by id and put the fresh row in its slot.

    Returns the list with handled tasks replaced by their reloaded versions.
    """
    listing = QueryBase(
        filters=filter_,
        order_by=[("created", DESCENDING)],  # by created time
        offset=skip,
        limit=limit,
    )
    tasks = await self.crud_task.aget_multi(amongo, query=listing)

    for position, task in enumerate(tasks):
        if not task.unfinshed_timeout:
            continue

        logger.warning(
            f"[Task] Found an unfinished timeout task: {task.id}. Task stage = {task.stage}, created = {task.created}, task type = {task.task_type}, attempts = {task.attempts}"
        )
        if not task.allow_to_retry:
            # No retry allowed: record the task as timed out.
            logger.warning(
                f"[Task] Found an unfinished timeout task: {task.id}. Make it as Timeout."
            )
            await self.aon_task_timeout(
                amongo=amongo,
                task=task,
            )
        else:
            # Retry allowed: start the task again.
            logger.warning(
                f"[Task] Found an unfinished timeout task: {task.id}. Start to retry."
            )
            await self.astart(
                amongo=amongo,
                extended_aredis=extended_aredis,
                task=task,
            )

        # Swap in the stored row so the caller sees the post-handling state.
        refreshed = await self.crud_task.afirst_by_id(amongo, _id=task.id)
        assert refreshed is not None, (
            f"[Task] Internal error, expected not none object for id {task.id}, but got none."
        )
        tasks[position] = refreshed

    return tasks
|
|
1586
|
+
|
|
1587
|
+
async def _aget_root_tasks(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Return one page of tasks from ``base_filter`` that have no parent.

    ``base_filter`` is copied, never mutated; the copy is narrowed with
    ``parent_task = None`` and delegated to ``_aget_multi``.
    """
    roots_only = {**base_filter, "parent_task": None}
    return await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        filter_=roots_only,
    )
|
|
1605
|
+
|
|
1606
|
+
async def _aget_descendant_tasks(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    root_ids: List[ObjectId],
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Return every non-root task whose ``root_task`` is one of ``root_ids``.

    The query is unpaginated (``skip`` and ``limit`` are both ``None``) and
    excludes the roots themselves via ``_id $nin``. An empty ``root_ids``
    short-circuits to an empty list without touching the database.
    """
    if not root_ids:
        return []

    descendants_filter = {
        **base_filter,
        "root_task": {"$in": root_ids},
        "_id": {"$nin": root_ids},
    }
    return await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=None,
        limit=None,
        filter_=descendants_filter,
    )
|
|
1626
|
+
|
|
1627
|
+
async def _aget_tasks_in_root_tree(
    self,
    *,
    amongo: AsyncIOMotorClient,
    root_id: ObjectId,
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Delegate to ``crud_task.aget_tasks_in_root_tree`` for one root id."""
    tree_tasks = await self.crud_task.aget_tasks_in_root_tree(
        amongo,
        root_id=root_id,
        base_filter=base_filter,
    )
    return tree_tasks
|
|
1637
|
+
|
|
1638
|
+
async def alist_task_relations_by_root(
    self,
    *,
    amongo: AsyncIOMotorClient,
    root_task: ObjectId,
) -> list[TaskRelationMongoObject]:
    """List the relation rows recorded under one root task.

    Step by step:
    1. make sure the relation indexes exist;
    2. select relation rows whose ``root_task`` equals the given id;
    3. return them oldest-first (``created`` ascending); building a tree
       out of them is left to the caller.
    """
    await self._ensure_task_relation_indexes(amongo=amongo)

    relations_query = QueryBase(
        filters={"root_task": root_task},
        order_by=[("created", ASCENDING)],
    )
    return await self.crud_task_relation.aget_multi(amongo, query=relations_query)
|
|
1660
|
+
|
|
1661
|
+
async def _aget_task_tree_node(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task: TaskMongoObject,
    target_task_id: ObjectId,
    base_filter: dict,
) -> TaskMongoObjectExtended:
    """Return the tree node for ``target_task_id`` inside ``task``'s subtree.

    The tasks of the owning root tree are loaded (the root id is
    ``task.root_task``, falling back to ``task.id``), a tree rooted at
    ``task`` is assembled from them, and the node with ``target_task_id`` is
    looked up. When no such node is found, ``task`` itself is returned as an
    extended object with an empty ``childs`` list.
    """
    tree_tasks = await self._aget_tasks_in_root_tree(
        amongo=amongo,
        root_id=task.root_task or task.id,
        base_filter=base_filter,
    )
    assembled = self._build_task_tree(roots=[task], tasks=tree_tasks)
    node = self._find_task_in_tree(nodes=assembled, task_id=target_task_id)
    if node is None:
        return TaskMongoObjectExtended(**task.model_dump(), childs=[])
    return node
|
|
1678
|
+
|
|
1679
|
+
def _build_task_tree(
    self,
    *,
    roots: List[TaskMongoObject],
    tasks: List[TaskMongoObject],
) -> List[TaskMongoObjectExtended]:
    """Method-level hook around the module-level ``build_task_tree`` helper."""
    forest = build_task_tree(roots=roots, tasks=tasks)
    return forest
|
|
1683
|
+
|
|
1684
|
+
def _find_task_in_tree(
    self,
    *,
    nodes: List[TaskMongoObjectExtended],
    task_id: ObjectId,
) -> Optional[TaskMongoObjectExtended]:
    """Method-level hook around the module-level ``find_task_in_tree`` helper."""
    match = find_task_in_tree(nodes=nodes, task_id=task_id)
    return match
|
|
1688
|
+
|
|
1689
|
+
async def _aget_multi_with_optional_childs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    base_filter: dict,
    include_childs: bool,
    root_task: Optional[ObjectId],
    parent_task: Optional[ObjectId],
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """Serve the flat and the tree listing variants from one shared filter.

    Three shapes are produced:
    1. flat (``include_childs`` false): one page of ``base_filter``; when no
       ``parent_task`` was requested the page is restricted to parentless
       (root) tasks;
    2. tree with an explicit ``parent_task``: the page selected by
       ``base_filter`` supplies the tree roots, and every task of the
       surrounding root trees is loaded (without the ``parent_task`` /
       ``root_task`` restrictions of ``base_filter``) to attach children;
    3. tree without ``parent_task``: one page of root tasks plus all of
       their descendants.
    """
    # --- 1. flat listing ---------------------------------------------------
    if not include_childs:
        flat_filter = dict(base_filter)
        if parent_task is None:
            flat_filter["parent_task"] = None
        return await self._aget_multi(
            amongo=amongo,
            extended_aredis=extended_aredis,
            skip=skip,
            limit=limit,
            filter_=flat_filter,
        )

    # --- 3. tree listing rooted at real root tasks --------------------------
    if parent_task is None:
        roots = await self._aget_root_tasks(
            amongo=amongo,
            extended_aredis=extended_aredis,
            skip=skip,
            limit=limit,
            base_filter=base_filter,
        )
        children = await self._aget_descendant_tasks(
            amongo=amongo,
            extended_aredis=extended_aredis,
            root_ids=[root.id for root in roots],
            base_filter=base_filter,
        )
        return self._build_task_tree(roots=roots, tasks=roots + children)

    # --- 2. tree listing below an explicit parent ---------------------------
    page_roots = await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        filter_=base_filter,
    )
    if not page_roots:
        return []

    # Distinct root-tree ids covering the page; a task without `root_task`
    # counts as its own root.
    top_root_ids = list(
        {
            page_task.id if page_task.root_task is None else page_task.root_task
            for page_task in page_roots
        }
    )

    # Reuse every restriction except the two that pin the page to one level.
    tree_filter = {
        key: value
        for key, value in base_filter.items()
        if key not in {"parent_task", "root_task"}
    }
    tree_filter["root_task"] = (
        root_task if root_task is not None else {"$in": top_root_ids}
    )

    tasks_in_trees = await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=None,
        limit=None,
        filter_=tree_filter,
    )
    return self._build_task_tree(roots=page_roots, tasks=tasks_in_trees)
|
|
1769
|
+
|
|
1770
|
+
@overload
async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """List tasks without any owner or sharing restriction.

    Each optional argument that is not ``None`` adds one restriction:
    ``task_types`` / ``stages`` as ``$in``, ``root_task`` / ``parent_task``
    as equality, and a non-blank ``title_contains`` as an escaped,
    case-insensitive regex on ``title``.

    Side effect if changes:
    - `aowner_get_multi` and the agent query registry tests mirror this
      generic read shape; keep the stage/title filters aligned across the
      owner/public variants.
    - Tree/list endpoints rely on `_aget_multi_with_optional_childs` for the
      `include_childs` return-type split.
    """
    needle = (title_contains or "").strip()
    # Candidates in filter-key order; a `None` value means "not requested".
    candidates = (
        ("task_type", None if task_types is None else {"$in": task_types}),
        ("root_task", root_task),
        ("parent_task", parent_task),
        ("stage", None if stages is None else {"$in": stages}),
        (
            "title",
            {"$regex": re.escape(needle), "$options": "i"} if needle else None,
        ),
    )
    base_filter: dict = {
        key: value for key, value in candidates if value is not None
    }
    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=base_filter,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
|
|
1849
|
+
|
|
1850
|
+
@overload
async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """List the tasks that belong to ``owner``.

    Collection: the task collection. Filter shape: the ``owner`` equality
    rides the `owner + parent_task + created` index prefix; `task_type`,
    `sharing_scope`, `root_task` and `parent_task` are indexed-or-residual
    filters. Additive (2.10.0): `stages` (``$in``) and `title_contains`
    (escaped, case-insensitive regex; ignored when blank) are RESIDUAL
    filters inside the owner-bounded shape — intended for bounded agent/list
    reads (limit <= 50), never unbounded sweeps.
    """
    needle = (title_contains or "").strip()
    # Optional restrictions in filter-key order; `None` means "not requested".
    optional_parts = (
        ("task_type", None if task_types is None else {"$in": task_types}),
        (
            "sharing_scope",
            None if sharing_scopes is None else {"$in": sharing_scopes},
        ),
        ("root_task", root_task),
        ("parent_task", parent_task),
        ("stage", None if stages is None else {"$in": stages}),
        (
            "title",
            {"$regex": re.escape(needle), "$options": "i"} if needle else None,
        ),
    )
    base_filter: dict = {
        "owner": owner,
        **{key: value for key, value in optional_parts if value is not None},
    }
    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=base_filter,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
|
|
1938
|
+
|
|
1939
|
+
@overload
async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """List tasks whose sharing scope is PUBLIC.

    `stages` (``$in`` on stage) and `title_contains` (escaped,
    case-insensitive regex; ignored when blank) are RESIDUAL filters inside
    the PUBLIC-bounded shape, mirroring `aowner_get_multi`. The overloads
    always promised them; the implementation honors them instead of silently
    dropping them.
    """
    needle = (title_contains or "").strip()
    # Optional restrictions in filter-key order; `None` means "not requested".
    optional_parts = (
        ("task_type", None if task_types is None else {"$in": task_types}),
        ("root_task", root_task),
        ("parent_task", parent_task),
        ("stage", None if stages is None else {"$in": stages}),
        (
            "title",
            {"$regex": re.escape(needle), "$options": "i"} if needle else None,
        ),
    )
    base_filter: dict = {
        "sharing_scope": SharingScope.PUBLIC,
        **{key: value for key, value in optional_parts if value is not None},
    }
    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=base_filter,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
|
|
2016
|
+
|
|
2017
|
+
async def apublic_list_indexable(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_types: List[str],
    stages: List[str],
    limit: int,
) -> List[TaskPublicRef]:
    """Return lightweight refs of PUBLIC tasks, most recently updated first.

    Objective: the lean read backing the dynamic-page sitemap. Compared with
    `apublic_get_multi`, this read is PROJECTED (only `_id` + `updated`),
    takes no `extended_aredis`, loads no childs, and yields `TaskPublicRef`
    rather than the full `TaskMongoObject`, so listing many public reports at
    crawl time stays cheap.

    Filter shape (served by `sharing_scope + task_type + stage + updated DESC`):
    `sharing_scope = PUBLIC`, `task_type $in`, `stage $in`, ordered by
    `updated` descending and capped by `limit`. Each raw projection doc is
    converted into a `TaskPublicRef`.
    """
    public_filter = {
        "sharing_scope": SharingScope.PUBLIC,
        "task_type": {"$in": task_types},
        "stage": {"$in": stages},
    }
    sitemap_query = QueryBase(
        filters=public_filter,
        order_by=[("updated", DESCENDING)],
        limit=limit,
    )
    raw_docs = await self.crud_task.aget_multi_docs(
        amongo,
        query=sitemap_query,
        project={"_id": 1, "updated": 1},
    )
    return [TaskPublicRef(**raw) for raw in raw_docs]
|
|
2052
|
+
|
|
2053
|
+
#
|
|
2054
|
+
#
|
|
2055
|
+
#
|
|
2056
|
+
#
|
|
2057
|
+
# Update
|
|
2058
|
+
#
|
|
2059
|
+
#
|
|
2060
|
+
#
|
|
2061
|
+
#
|
|
2062
|
+
async def aupdate(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    task_id: ObjectId,
    obj_in: TaskUpdate,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
):
    """Apply ``obj_in`` to the task ``task_id`` without an ownership check.

    Only the fields explicitly set on ``obj_in`` are written. When the
    update sets the sharing scope to PUBLIC, ``_ensure_task_public_id`` runs
    first with the given prefix/length.

    Raises:
        keble_exceptions.ServerSideInvalidParams: no task has ``task_id``.
    """
    existing = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if existing is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task is None",
            expected="task exists",
            invalid_params=f"task_id={task_id}",
        )

    becomes_public = obj_in.sharing_scope == SharingScope.PUBLIC
    if becomes_public:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=task_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )

    changes = obj_in.model_dump(exclude_unset=True)
    await self.crud_task.aupdate(amongo, _id=task_id, obj_in=changes)
|
|
2094
|
+
|
|
2095
|
+
async def aowner_update(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    obj_in: TaskUpdate,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
):
    """Apply ``obj_in`` to a task after confirming ``owner`` can read it.

    The task is first looked up through ``aowner_get`` (any sharing scope,
    any task type). Only the fields explicitly set on ``obj_in`` are written.
    When the update sets the sharing scope to PUBLIC,
    ``_ensure_task_public_id`` runs first with the given prefix/length.

    Raises:
        keble_exceptions.ServerSideInvalidParams: ``aowner_get`` found no
            task for this ``owner`` / ``task_id`` pair.
    """
    owned_task = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        sharing_scope=None,
        task_id=task_id,
        task_type=None,
    )
    if owned_task is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id, "owner": owner},
            alert_admin=True,
            but_got="task is None",
            expected="task exists",
            invalid_params=f"task_id={task_id}",
        )

    becomes_public = obj_in.sharing_scope == SharingScope.PUBLIC
    if becomes_public:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=task_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )

    changes = obj_in.model_dump(exclude_unset=True)
    await self.crud_task.aupdate(amongo, _id=task_id, obj_in=changes)
|
|
2133
|
+
|
|
2134
|
+
#
|
|
2135
|
+
#
|
|
2136
|
+
#
|
|
2137
|
+
#
|
|
2138
|
+
# Event Handler
|
|
2139
|
+
#
|
|
2140
|
+
#
|
|
2141
|
+
#
|
|
2142
|
+
#
|
|
2143
|
+
def _get_task_start_lock_key(self, task: TaskMongoObject) -> str:
    """Build the Redis key under which this task's start lock is stored."""
    task_ref = task.id
    return f"just_started:{task_ref}"
|
|
2146
|
+
|
|
2147
|
+
async def _acheck_start_lock(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
) -> bool:
    """Return True when no start lock is stored for ``task`` (free to start)."""
    lock_key = self._get_task_start_lock_key(task)
    stored = await extended_aredis.get(lock_key)
    return stored is None
|
|
2152
|
+
|
|
2153
|
+
async def _alock_start(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
):
    """Store the start lock for ``task``; Redis expires it after 60 seconds."""
    lock_key = self._get_task_start_lock_key(task)
    await extended_aredis.set(lock_key, "1", ex=60)  # expire after 1 min
|
|
2162
|
+
|
|
2163
|
+
async def _aclear_lock(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
):
    """Delete the start lock stored for ``task``."""
    lock_key = self._get_task_start_lock_key(task)
    await extended_aredis.delete(lock_key)
|
|
2168
|
+
|
|
2169
|
+
# `before_sleep` logs EVERY failed attempt with its full traceback.
|
|
2170
|
+
# Without it, `reraise=True` surfaces only the LAST attempt's exception —
|
|
2171
|
+
# in a live incident the real attempt-1 failure (a crashed spawn tool) was
|
|
2172
|
+
# silently discarded and only a misleading downstream error reached the
|
|
2173
|
+
# task row, costing the diagnosis.
|
|
2174
|
+
@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    reraise=True,
    before_sleep=tenacity.before_sleep_log(logger, logging.ERROR, exc_info=True),
)
async def _ahandle_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    qdrant_client: Optional[AsyncQdrantClient] = None,
    task: TaskMongoObject,
    task_handler_metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter,
):
    """Execute the configured handler once and finalize from its response.

    The whole method is retried by tenacity (up to three attempts, last
    error re-raised), so every step below runs again on each attempt.

    Step by step:
    1. bump the stored attempt counter and continue with the returned row;
    2. merge the stored task metadata with ``task_handler_metadata``;
    3. build the `TaskHandlerRequest` (DB clients, owner, language, the
       shared `event_emitter`, task and merged metadata) and await the
       handler;
    4. a `None` response means the handler delegated completion elsewhere:
       nothing is finalized here;
    5. otherwise the response is routed to `aon_task_success` or
       `aon_task_failure`, which receive the same `event_emitter`.
    """
    # increment attempts for each retry
    task = await self.crud_task.aincrement_attempts(amongo, _id=task.id)
    merged_metadata = _merge_metadata(task.metadata, task_handler_metadata)

    # The same `ExtendedAsyncRedis` object is supplied for both the `aredis`
    # and the `extended_aredis` request fields (it is an `AsyncRedis`
    # subclass, so it satisfies both).
    handler_request = TaskHandlerRequest(
        amongo=amongo,
        aredis=extended_aredis,
        extended_aredis=extended_aredis,
        aneo4j=aneo4j,
        aqdrant=qdrant_client,
        owner=task.owner,
        language=task.language or Language.ENGLISH,
        event_emitter=event_emitter,
        task=task,
        metadata=merged_metadata,
    )
    response = await self._task_handler(handler_request)

    # None => the handler delegated completion elsewhere; do NOT finalize.
    if response is None:
        return

    if not response.success:
        await self.aon_task_failure(
            amongo=amongo,
            extended_aredis=extended_aredis,
            error=response.error
            or f"Task handled response has a status = {response.success}",
            exception_type=response.exception_type or TaskExceptionType.UNKNOWN,
            task=response.task,
            aneo4j=aneo4j,
            metadata=merged_metadata,
            event_emitter=event_emitter,
        )
        return

    await self.aon_task_success(
        amongo=amongo,
        task=response.task,
        consuming_token=response.consuming_token,
        extended_aredis=extended_aredis,
        aneo4j=aneo4j,
        metadata=merged_metadata,
        event_emitter=event_emitter,
    )
|
|
2253
|
+
|
|
2254
|
+
async def astart(
    self,
    *,
    get_amongo: Optional[Callable[[], AsyncIOMotorClient]] = None,
    get_extended_aredis: Optional[Callable[[], ExtendedAsyncRedis]] = None,
    get_aneo4j: Optional[Callable[[], Optional[Neo4jAsyncDriver]]] = None,
    get_qdrant_client: Optional[Callable[[], Optional[AsyncQdrantClient]]] = None,
    amongo: Optional[AsyncIOMotorClient] = None,
    extended_aredis: Optional[ExtendedAsyncRedis] = None,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    qdrant_client: Optional[AsyncQdrantClient] = None,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    # metadata to pass into handler request
    task_handler_metadata: Optional[TaskMetadata] = None,
    task_lifecycle_event_emitter: AgenticEventEmitter | None = None,
):
    """Run one task end to end and return its row as stored afterwards.

    Step by step:
    1. resolve resources — an explicit client wins, otherwise its getter is
       called; Mongo and Redis are mandatory, Neo4j and Qdrant are optional;
    2. load the task and consult the start lock; a locked task is returned
       untouched;
    3. refuse to run (`TaskFailedToStartException`) when the task's stage
       does not allow starting;
    4. take the lock, move the task to processing, and run the handler
       through `_ahandle_task` (which retries);
    5. if the handler still raises, log it and mark the task failed unless
       its stored stage is already SUCCESS or FAILURE;
    6. on the way out — including when the processing transition itself
       raised — clear the lock and drain the lifecycle emitter if one was
       given;
    7. re-read and return the task.

    Raises:
        keble_exceptions.ServerSideMissingParams: neither a Mongo/Redis
            client nor its getter was supplied.
        TaskFailedToStartException: the task's stage does not allow a start.
    """
    # -- 1. resources --------------------------------------------------------
    if amongo is None:
        if get_amongo is None:
            raise keble_exceptions.ServerSideMissingParams(
                admin_note={"get_amongo": get_amongo},
                alert_admin=True,
                missing_params="get_amongo",
            )
        amongo = get_amongo()
    if extended_aredis is None:
        if get_extended_aredis is None:
            raise keble_exceptions.ServerSideMissingParams(
                admin_note={"get_extended_aredis": get_extended_aredis},
                alert_admin=True,
                missing_params="get_extended_aredis",
            )
        extended_aredis = get_extended_aredis()
    # Neo4j and Qdrant may legitimately stay None (client and getter absent).
    if get_aneo4j is not None and aneo4j is None:
        aneo4j = get_aneo4j()
    if get_qdrant_client is not None and qdrant_client is None:
        qdrant_client = get_qdrant_client()

    # -- 2. task + start lock -------------------------------------------------
    db_obj = await self._aget_task_for_event(
        amongo=amongo, task_id=task_id, task=task
    )
    lock_is_free = await self._acheck_start_lock(
        extended_aredis=extended_aredis, task=db_obj
    )
    if not lock_is_free:
        logger.warning(f"[Task] Task {db_obj.id} already started, skip to start.")
        return db_obj

    # -- 3. stage gate --------------------------------------------------------
    if not db_obj.stage.allow_to_start:
        logger.warning(
            f"[Task] Can not start task: {db_obj.id}, current stage = {db_obj.stage}"
        )
        raise TaskFailedToStartException(
            f"Can not start task: {db_obj.id}, current stage = {db_obj.stage}"
        )

    # -- 4-6. locked execution ------------------------------------------------
    try:
        await self._alock_start(extended_aredis=extended_aredis, task=db_obj)
        await self.aon_task_processing(
            amongo=amongo,
            task=db_obj,
            event_emitter=task_lifecycle_event_emitter,
        )
        try:
            # The handler always receives an emitter, even when the caller
            # supplied none.
            handler_emitter = task_lifecycle_event_emitter or AgenticEventEmitter()
            await self._ahandle_task(
                amongo=amongo,
                extended_aredis=extended_aredis,
                task=db_obj,
                aneo4j=aneo4j,
                qdrant_client=qdrant_client,
                task_handler_metadata=task_handler_metadata,
                event_emitter=handler_emitter,
            )
        except Exception as exc:
            logger.critical(
                f"[Task] Task failed during execution: {db_obj.id}, with error: {str(exc)}\n{traceback.format_exc()}"
            )
            latest_task = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
            assert latest_task is not None, (
                f"Task with ID {db_obj.id} not found after handler failure"
            )
            already_final = latest_task.stage in {
                TaskStage.SUCCESS,
                TaskStage.FAILURE,
            }
            if not already_final:
                # A TaskException carries its own message and type; any other
                # exception is reported with its text and type UNKNOWN.
                if isinstance(exc, TaskException):
                    failure_error = exc.error or str(exc)
                    failure_exception_type = exc.exception_type
                else:
                    failure_error = str(exc)
                    failure_exception_type = TaskExceptionType.UNKNOWN
                await self.aon_task_failure(
                    amongo=amongo,
                    extended_aredis=extended_aredis,
                    aneo4j=aneo4j,
                    task=latest_task,
                    error=failure_error,
                    exception_type=failure_exception_type,
                    metadata=task_handler_metadata,
                    event_emitter=task_lifecycle_event_emitter,
                )
    finally:
        await self._aclear_lock(extended_aredis=extended_aredis, task=db_obj)
        # Backstop drain: every aon_task_* lifecycle method already self-drains,
        # but join any residual detached emit here so the worker envelope never
        # returns (and releases the Celery lease) with spawn-on-terminal or
        # terminal-subtree-settle callbacks still in flight.
        if task_lifecycle_event_emitter is not None:
            await task_lifecycle_event_emitter.adrain()

    # -- 7. fresh row ---------------------------------------------------------
    result = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
    assert result is not None, f"Task with ID {db_obj.id} not found after start"
    return result
|
|
2380
|
+
|
|
2381
|
+
async def aretry_all_undone(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    get_aneo4j: Optional[Callable[[], Optional[Neo4jAsyncDriver]]] = None,
    task_handler_metadata: Optional[TaskMetadata] = None,
    include_tasks_created_within_minutes: int = 30,
):
    """Restart or time out every recently created task that has not finished.

    Step by step:
    1. load tasks whose stage is PENDING or PROCESSING and whose ``created``
       value is newer than ``include_tasks_created_within_minutes`` ago,
       ordered by ``created`` descending;
    2. pass each task that reports ``allow_to_retry`` to ``self.astart``;
    3. mark every other task via ``self.aon_task_timeout``.

    Tasks are handled one at a time, each call awaited before the next.
    No event emitter is passed to ``aon_task_timeout`` from here.
    """
    now = utc_now()
    created_after = now - timedelta(minutes=include_tasks_created_within_minutes)
    undone_query = QueryBase(
        filters={
            "stage": {"$in": [TaskStage.PENDING, TaskStage.PROCESSING]},
            "created": {"$gt": created_after},
        },
        order_by=[("created", DESCENDING)],  # by created time
    )
    undone_tasks = await self.crud_task.aget_multi(amongo, query=undone_query)

    for undone in undone_tasks:
        if not undone.allow_to_retry:
            # Not eligible for another attempt: record it as a timeout failure.
            await self.aon_task_timeout(amongo=amongo, task=undone)
            continue
        # Eligible for another attempt: restart it.
        await self.astart(
            amongo=amongo,
            extended_aredis=extended_aredis,
            aneo4j=aneo4j,
            get_aneo4j=get_aneo4j,
            task_handler_metadata=task_handler_metadata,
            task=undone,
        )
|
|
2420
|
+
|
|
2421
|
+
async def aget_tasks_by_stage_since(
    self,
    *,
    amongo: AsyncIOMotorClient,
    stages: Optional[List[TaskStage]] = None,
    hours: int = 1,
) -> List[TaskMongoObject]:
    """Return tasks created within the last ``hours`` hours, filtered by stage.

    Args:
        amongo: Mongo client handed to ``crud_task.aget_multi``.
        stages: Stages to match. A falsy value (``None`` or an empty list)
            falls back to PENDING and PROCESSING.
        hours: Size of the look-back window, in hours.

    Returns:
        The matching tasks, ordered by ``created`` descending.
    """
    created_after = utc_now() - timedelta(hours=hours)
    wanted_stages = stages or [TaskStage.PENDING, TaskStage.PROCESSING]
    recent_query = QueryBase(
        filters={
            "stage": {"$in": wanted_stages},
            "created": {"$gt": created_after},
        },
        order_by=[("created", DESCENDING)],
    )
    return await self.crud_task.aget_multi(amongo, query=recent_query)
|
|
2442
|
+
|
|
2443
|
+
async def _aget_task_for_event(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: Optional[ObjectId],
    task: Optional[TaskMongoObject],
):
    """Resolve the task that a lifecycle handler should operate on.

    ``task_id`` takes precedence: when it is given, the task is loaded through
    ``self.aget`` (with ``task_type=None``) and the ``task`` argument is
    ignored. Otherwise the provided ``task`` object is returned unchanged.

    Args:
        amongo: Mongo client used for the lookup by id.
        task_id: Id of the task to load, or ``None``.
        task: Already loaded task, used only when ``task_id`` is ``None``.

    Returns:
        The loaded task, or the ``task`` argument.

    Raises:
        keble_exceptions.ServerSideInvalidParams: when ``task_id`` is given
            but no task is found, or when both ``task_id`` and ``task`` are
            ``None``.
    """
    if task_id is not None:
        db_obj = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
        if db_obj is None:
            # The lookup is for a task, so the diagnostic names a task (the
            # previous wording referred to a "report").
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={"task_id": task_id},
                alert_admin=True,
                but_got="task is None",
                expected="task exists",
                invalid_params=f"task_id={task_id}",
            )
        return db_obj

    if task is not None:
        return task

    # Neither identifier was supplied: nothing to resolve.
    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={},
        alert_admin=True,
        but_got="no params",
        expected="at least one of task_id or task",
        invalid_params="task_id or task",
    )
|
|
2472
|
+
|
|
2473
|
+
async def _aemit_task_lifecycle_event(
    self,
    *,
    task: TaskMongoObject,
    event_emitter: AgenticEventEmitter | None,
) -> None:
    """Publish a single task-stage-changed event for an already persisted task.

    Step by step:
    1. return immediately when no event emitter was supplied;
    2. build the event from the given task row, using the root task id (or the
       task's own id when it has no root) as ``root_id`` and the task's own id
       as ``object_id``;
    3. emit the event, then drain the emitter before returning.
    """

    if event_emitter is None:
        return

    action_type = TaskEventType.TASK_STAGE_CHANGED.value
    event_status = _task_lifecycle_event_status(stage=task.stage)
    event_payload = TaskLifecycleEventPayload(task=task)
    room_id = str(task.root_task or task.id)
    changed_task_id = str(task.id)
    lifecycle_event = TaskLifecycleEvent(
        action_type=action_type,
        status=event_status,
        payload=event_payload,
        root_id=room_id,
        object_id=changed_task_id,
        correlation_id=task.progress_key,
    )
    await event_emitter.aemit(lifecycle_event)
    # Drain at the lifecycle boundary. The helpers emitter detaches its
    # callbacks, yet lifecycle callbacks carry DOMAIN side effects that the
    # workflow reads right afterwards (spawn-on-terminal, terminal-subtree
    # settle, chat-memory writes). The task-row mutation is persisted before
    # this emit, so joining the detached chain here makes each aon_task_*
    # atomic w.r.t. its own callbacks and ensures those side effects finish
    # before the lifecycle method returns.
    # Side effect if changed: removing this drain re-opens the
    # spawn-on-terminal / subtree-settle race in keble.backend
    # (subagent_workflow lifecycle callback).
    await event_emitter.adrain()
|
|
2507
|
+
|
|
2508
|
+
async def aon_task_success(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    consuming_token: int,  # how many token should be consume (actually)
    metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter | None = None,
):
    """Reconcile token usage, mark the task SUCCESS, and emit a lifecycle event.

    Step by step:
    1. resolve the task from ``task_id`` or ``task``;
    2. compare ``consuming_token`` with the task's stored ``consumed_token``:
       a lower value asks the token consumption handler to RECOVER the
       difference, a higher value asks it to CONSUME the difference, an equal
       value calls nothing;
    3. store ``consuming_token`` as the new ``consumed_token`` only when the
       handler returned a truthy result; otherwise keep the stored value;
    4. persist stage SUCCESS with ``success_ts``, reload the row, and emit the
       lifecycle event when an emitter is supplied.

    Returns:
        The task row reloaded after the update.
    """
    current = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    already_charged = current.consumed_token
    final_consumed_token = already_charged

    # Decide which adjustment (if any) brings the charge in line with actual usage.
    if consuming_token < already_charged:
        adjustment_type = TokenConsumptionType.RECOVER
        adjustment_amount = already_charged - consuming_token
    elif consuming_token > already_charged:
        adjustment_type = TokenConsumptionType.CONSUME
        adjustment_amount = consuming_token - already_charged
    else:
        adjustment_type = None
        adjustment_amount = 0

    if adjustment_type is not None:
        adjusted = await self._execute_token_consumption_handler(
            TokenConsumptionPayload(
                consumption_type=adjustment_type,
                owner=current.owner,
                task_id=current.id,
                token=adjustment_amount,
                amongo=amongo,
                extended_aredis=extended_aredis,
                aneo4j=aneo4j,
                metadata=metadata,
            )
        )
        if adjusted:
            final_consumed_token = consuming_token

    # Mark as done
    finished_at = utc_now()
    success_fields = {
        "stage": TaskStage.SUCCESS,
        "success_ts": int(finished_at.timestamp()),
        "consumed_token": final_consumed_token,
    }
    await self.crud_task.aupdate(amongo, _id=current.id, obj_in=success_fields)

    result = await self.crud_task.afirst_by_id(amongo, _id=current.id)
    assert result is not None, f"Task with ID {current.id} not found after update"
    await self._aemit_task_lifecycle_event(task=result, event_emitter=event_emitter)
    return result
|
|
2579
|
+
|
|
2580
|
+
async def aon_task_processing(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    event_emitter: AgenticEventEmitter | None = None,
) -> TaskMongoObject:
    """Move a task into PROCESSING and emit a lifecycle event.

    When the task's current stage does not report ``allow_to_start``, a
    warning is logged and the resolved task is returned untouched: nothing is
    written and no event is emitted.

    Returns:
        The task row reloaded after the update, or the unchanged resolved
        task when the stage transition is not allowed.
    """
    current = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    can_start = current.stage.allow_to_start
    if not can_start:
        logger.warning(
            f"[Task] Can not switch to processing for task: {current.id}, current stage = {current.stage}"
        )
        return current

    # Mark as processing
    started_at = utc_now()
    processing_fields = {
        "stage": TaskStage.PROCESSING,
        "started_ts": int(started_at.timestamp()),
    }
    await self.crud_task.aupdate(amongo, _id=current.id, obj_in=processing_fields)

    result = await self.crud_task.afirst_by_id(amongo, _id=current.id)
    assert result is not None, f"Task with ID {current.id} not found after update"
    await self._aemit_task_lifecycle_event(task=result, event_emitter=event_emitter)
    return result
|
|
2614
|
+
|
|
2615
|
+
async def aon_task_failure(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    task: Optional[TaskMongoObject] = None,
    task_id: Optional[ObjectId] = None,
    error: Optional[str] = None,
    exception_type: TaskExceptionType,
    metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter | None = None,
) -> TaskMongoObject:
    """Record a task failure, give back its tokens, and emit a lifecycle event.

    Step by step:
    1. resolve the task from ``task_id`` or ``task``;
    2. when the task has a positive ``consumed_token``, ask the token
       consumption handler to RECOVER that amount; ``consumed_token`` is reset
       to 0 only when the handler returns a truthy result;
    3. persist stage FAILURE with ``failure_ts``, ``error`` and
       ``exception_type``, reload the row, and emit the lifecycle event when an
       emitter is supplied.

    Returns:
        The task row reloaded after the update.
    """
    current = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )

    # The failure timestamp is taken before any token recovery runs.
    failure_fields = {
        "stage": TaskStage.FAILURE,
        "failure_ts": int(utc_now().timestamp()),
        "error": error,
        "exception_type": exception_type,
    }

    # Handle token recovery if needed
    if current.consumed_token > 0:
        recovery_request = TokenConsumptionPayload(
            consumption_type=TokenConsumptionType.RECOVER,
            owner=current.owner,
            token=current.consumed_token,
            task_id=current.id,
            amongo=amongo,
            extended_aredis=extended_aredis,
            aneo4j=aneo4j,
            metadata=metadata,
        )
        if await self._execute_token_consumption_handler(recovery_request):
            failure_fields["consumed_token"] = 0

    # Mark as failure
    await self.crud_task.aupdate(amongo, _id=current.id, obj_in=failure_fields)

    result = await self.crud_task.afirst_by_id(amongo, _id=current.id)
    assert result is not None, f"Task with ID {current.id} not found after update"
    await self._aemit_task_lifecycle_event(task=result, event_emitter=event_emitter)
    return result
|
|
2672
|
+
|
|
2673
|
+
async def aon_task_timeout(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task: Optional[TaskMongoObject] = None,
    task_id: Optional[ObjectId] = None,
    event_emitter: AgenticEventEmitter | None = None,
):
    """Record a task as failed because it timed out, then emit a lifecycle event.

    The task is resolved from ``task_id`` or ``task`` and updated to stage
    FAILURE with exception type TIMEOUT, a ``failure_ts``, and an error message
    built from the task's ``timeout_mins`` and ``created`` values. This method
    does not call the token consumption handler.

    Returns:
        The task row reloaded after the update.
    """
    current = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    timeout_error = (
        f"Task timeout after {current.timeout_mins} min. Created at {current.created}"
    )

    # Mark as failure
    failed_at = utc_now()
    timeout_fields = {
        "stage": TaskStage.FAILURE,
        "failure_ts": int(failed_at.timestamp()),
        "error": timeout_error,
        "exception_type": TaskExceptionType.TIMEOUT,
    }
    await self.crud_task.aupdate(amongo, _id=current.id, obj_in=timeout_fields)

    result = await self.crud_task.afirst_by_id(amongo, _id=current.id)
    assert result is not None, f"Task with ID {current.id} not found after update"
    await self._aemit_task_lifecycle_event(task=result, event_emitter=event_emitter)
    return result
|