keble-task 2.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keble_task/main.py ADDED
@@ -0,0 +1,2708 @@
1
+ import asyncio
2
+ import logging
3
+ import secrets
4
+ import string
5
+ import traceback
6
+ import re
7
+ from datetime import datetime, timedelta
8
+ from enum import Enum
9
+ from typing import (
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ List,
14
+ Literal,
15
+ Mapping,
16
+ Optional,
17
+ Union,
18
+ cast,
19
+ overload,
20
+ )
21
+
22
+ import keble_exceptions
23
+ import tenacity
24
+ from keble_db import AgentDbDeps, ExtendedAsyncRedis, QueryBase
25
+ from keble_helpers import (
26
+ AgenticActionEventStatus,
27
+ AgenticEventEmitter,
28
+ Currency,
29
+ ExchangeRateInUsd,
30
+ Language,
31
+ Money,
32
+ ObjectId,
33
+ PydanticModelConfig,
34
+ SharingScope,
35
+ UsageAccountingRecorderProtocol,
36
+ UsageAccountingSource,
37
+ utc_now,
38
+ )
39
+ from motor.motor_asyncio import AsyncIOMotorClient
40
+ from neo4j import AsyncDriver as Neo4jAsyncDriver
41
+ from pydantic import BaseModel, ConfigDict, SkipValidation
42
+ from pydantic_ai.usage import RunUsage
43
+ from pymongo import ASCENDING, DESCENDING
44
+ from pymongo.errors import DuplicateKeyError
45
+ from qdrant_client import AsyncQdrantClient
46
+
47
+ from keble_task.exceptions import (
48
+ TaskException,
49
+ TaskExceptionType,
50
+ TaskFailedToStartException,
51
+ )
52
+
53
+ from .actions import (
54
+ CreateRelatedTaskAction,
55
+ CreateRelatedTaskActionedResult,
56
+ TaskAction,
57
+ TaskActionCreatedRelation,
58
+ TaskActionCreatedTask,
59
+ TaskActionedResults,
60
+ TaskActionEvent,
61
+ TaskActions,
62
+ TaskActionType,
63
+ TaskEventType,
64
+ TaskLifecycleEvent,
65
+ TaskLifecycleEventPayload,
66
+ TaskUxContext,
67
+ )
68
+ from .crud import CRUDTask, CRUDTaskCost, CRUDTaskRelation
69
+ from .schemas import (
70
+ TaskBase,
71
+ TaskCostAggregateRequest,
72
+ TaskCostAggregateResponse,
73
+ TaskCostCreate,
74
+ TaskCostListRequest,
75
+ TaskCostListResponse,
76
+ TaskCostMetadata,
77
+ TaskCostMongoObject,
78
+ TaskCostTokenRates,
79
+ TaskMetadata,
80
+ TaskMongoObject,
81
+ TaskMongoObjectExtended,
82
+ TaskPublicRef,
83
+ TaskRoomGraphContext,
84
+ TaskRoomResolution,
85
+ TaskRelationCreate,
86
+ TaskRelationMongoObject,
87
+ TaskStage,
88
+ TaskUpdate,
89
+ )
90
+ from .task_tree import build_task_tree, find_task_in_tree
91
+
92
+
93
class TokenConsumptionType(str, Enum):
    """Kind of token-balance change requested from the host token handler.

    Carried as `TokenConsumptionPayload.consumption_type`. Each member's value
    equals its name, and the `str` mixin makes members compare equal to those
    plain strings.
    """

    CONSUME = "CONSUME"
    RECOVER = "RECOVER"
96
+
97
+
98
class TokenConsumptionPayload(BaseModel):
    """Host-owned token consumption/recovery request for one task.

    Step by step:
    1. carry the consumption type, owner, token amount, and task id;
    2. carry the minimal DB clients the host token handler needs to price and
       persist the consumption (replacing the retired `TaskResources` bag);
    3. carry optional handler metadata for pricing context.

    Side effects:
    - consumed by the host-provided `_token_consumption_handler` in keble.backend;
      that handler now reads `payload.amongo`/`extended_aredis`/`aneo4j` directly
      instead of `payload.resources.*`.
    """

    model_config = ConfigDict(
        **PydanticModelConfig.default_dict(),
        arbitrary_types_allowed=True,  # for database clients
    )
    # Whether the host should consume or recover tokens (see TokenConsumptionType).
    consumption_type: TokenConsumptionType
    # Owner identifier whose balance the host handler adjusts.
    owner: str
    # Number of tokens to consume or recover.
    token: int
    # Task the change is attributed to, when there is one.
    task_id: Optional[ObjectId] = None
    # Database clients handed straight to the host handler (no resources bag).
    amongo: AsyncIOMotorClient
    extended_aredis: ExtendedAsyncRedis
    aneo4j: Optional[Neo4jAsyncDriver] = None
    # Optional task metadata used as pricing context.
    metadata: Optional[TaskMetadata] = None
125
+
126
+
127
class TaskHandlerResponse(BaseModel):
    """Outcome a task handler returns so the task runtime can finalize the task.

    The runtime uses it to finalize success/failure and emit the terminal
    lifecycle event. A handler that delegated completion to another process
    returns `None` instead of this object.
    """

    model_config = PydanticModelConfig.default()
    # The task row the handler executed.
    task: TaskMongoObject
    # Whether the handler finished successfully.
    success: bool
    # Tokens the handler reports as consumed by this run.
    consuming_token: int
    # Failure classification; presumably only set when `success` is False — confirm.
    exception_type: Optional[TaskExceptionType] = None
    # Human-readable error detail; presumably only set on failure — confirm.
    error: Optional[str] = None
134
+
135
+
136
class TaskHandlerRequest(AgentDbDeps):
    """One task execution request: the DB-rooted agent deps plus the task itself.

    Step by step:
    1. inherits ALL DB clients + cross-cutting context (owner, owner_type,
       owner_scope, event_emitter, usage_recorder, marketplace, language) from
       `AgentDbDeps` — so a handler uses `request` directly as its deps object;
    2. carries the persisted task row and optional handler metadata;
    3. carries NO completion callback and NO resources bag — handlers RETURN a
       `TaskHandlerResponse` (runtime finalizes) or `None` (the handler delegated
       completion to another process; the task stays PROCESSING).

    Side effects:
    - every `atask_handler` (positioning, amz-product-report, backend handlers)
      now reads clients off `request` directly and returns a response or None.
    """

    # Arbitrary types are allowed because the inherited deps hold DB client objects.
    model_config = PydanticModelConfig.default(arbitrary_types_allowed=True)
    # The persisted task row the handler should execute.
    task: TaskMongoObject
    # Optional handler-specific metadata.
    metadata: Optional[TaskMetadata] = None
156
+
157
+
158
# Module-scoped logger, named after this module so log routing can target it.
logger = logging.getLogger(name=__name__)
159
+
160
+
161
+ def _merge_metadata(
162
+ base: Optional[Mapping[str, Any]], override: Optional[Mapping[str, Any]]
163
+ ) -> Optional[dict[str, Any]]:
164
+ if base is None and override is None:
165
+ return None
166
+ if base is None:
167
+ assert override is not None
168
+ return dict(override)
169
+ if override is None:
170
+ return dict(base)
171
+ return {**base, **override}
172
+
173
+
174
def _task_lifecycle_event_status(*, stage: TaskStage) -> AgenticActionEventStatus:
    """Translate a persisted `TaskStage` into the shared agentic event status.

    PROCESSING -> STARTED (lifecycle start signal for room UIs),
    SUCCESS -> SUCCEEDED, FAILURE -> FAILED. Anything else, e.g. PENDING,
    maps to the non-terminal PROGRESSED status. Stages are matched by
    identity, so only genuine `TaskStage` members hit the explicit rows.
    """
    for known_stage, event_status in (
        (TaskStage.PROCESSING, AgenticActionEventStatus.STARTED),
        (TaskStage.SUCCESS, AgenticActionEventStatus.SUCCEEDED),
        (TaskStage.FAILURE, AgenticActionEventStatus.FAILED),
    ):
        if stage is known_stage:
            return event_status
    return AgenticActionEventStatus.PROGRESSED
190
+
191
+
192
+ class TaskClient:
193
def __init__(
    self,
    *,
    token_consumption_handler: Callable[
        [TokenConsumptionPayload], Union[bool, Awaitable[bool]]
    ],
    task_handler: Callable[
        [TaskHandlerRequest], Awaitable[Optional[TaskHandlerResponse]]
    ],
    mongo_database: str = "__keble_task__",
    task_collection: str = "__keble_task__task__",
    task_relation_collection: str = "__keble_task__task_relation__",
    task_cost_collection: str = "__keble_task__task_cost__",
    public_id_prefix: str = "t",
    public_id_length: int = 10,
):
    """Configure a TaskClient without opening any database connection.

    Args:
        token_consumption_handler: host callback that consumes or recovers
            tokens; it may return a `bool` directly or an awaitable that
            resolves to a `bool`.
        task_handler: coroutine function that runs one task. It returns a
            `TaskHandlerResponse` (the runtime then finalizes the task), or
            `None` when completion was delegated to another process, in which
            case the task stays PROCESSING until that process finalizes it.
        mongo_database: database that holds every task-package collection.
        task_collection: collection for task rows.
        task_relation_collection: collection for extra lineage relations.
        task_cost_collection: collection for task-cost rows.
        public_id_prefix: default prefix for public short ids (shared URLs).
        public_id_length: default total public-id length, prefix included.

    Raises:
        keble_exceptions.ServerSideInvalidParams: invalid public-id defaults.

    Database clients are not kept on the instance; every method receives the
    clients it needs as arguments.
    """
    # Host-supplied callbacks.
    self._token_consumption_handler = token_consumption_handler
    self._task_handler = task_handler

    # Storage layout: one database, three collections.
    self._mongo_database = mongo_database
    self._task_collection = task_collection
    self._task_relation_collection = task_relation_collection
    self._task_cost_collection = task_cost_collection

    # Public short-id defaults, validated eagerly so a misconfiguration fails
    # at construction instead of on the first public task.
    self._public_id_prefix = public_id_prefix
    self._public_id_length = public_id_length
    self._public_id_alphabet = string.ascii_letters + string.digits
    self._validate_public_id_config(
        prefix=self._public_id_prefix, length=self._public_id_length
    )

    # One CRUD helper per collection, all bound to the same database.
    self.crud_task = CRUDTask(
        model=TaskMongoObject,
        collection=self._task_collection,
        database=self._mongo_database,
    )
    self.crud_task_relation = CRUDTaskRelation(
        model=TaskRelationMongoObject,
        collection=self._task_relation_collection,
        database=self._mongo_database,
    )
    self.crud_task_cost = CRUDTaskCost(
        model=TaskCostMongoObject,
        collection=self._task_cost_collection,
        database=self._mongo_database,
    )
257
+
258
+ def _validate_public_id_config(self, *, prefix: str, length: int) -> None:
259
+ if not isinstance(prefix, str):
260
+ raise keble_exceptions.ServerSideInvalidParams(
261
+ admin_note={"prefix": prefix},
262
+ alert_admin=True,
263
+ but_got="prefix is not str",
264
+ expected="prefix is str",
265
+ invalid_params="prefix",
266
+ )
267
+ if not isinstance(length, int):
268
+ raise keble_exceptions.ServerSideInvalidParams(
269
+ admin_note={"length": length},
270
+ alert_admin=True,
271
+ but_got="length is not int",
272
+ expected="length is int",
273
+ invalid_params="length",
274
+ )
275
+ if length <= len(prefix):
276
+ raise keble_exceptions.ServerSideInvalidParams(
277
+ admin_note={"prefix": prefix, "length": length},
278
+ alert_admin=True,
279
+ but_got="length <= len(prefix)",
280
+ expected="length > len(prefix)",
281
+ invalid_params="length,prefix",
282
+ )
283
+
284
async def _ensure_public_id_index(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the task collection's public-id index exists.

    The index definition is owned by `CRUDTask.aensure_public_id_index`; this
    wrapper only forwards the Mongo client to it.
    """
    task_crud = self.crud_task
    await task_crud.aensure_public_id_index(amongo)
286
+
287
async def _ensure_task_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the task collection has its tree/stage read indexes.

    The index definitions live in `CRUDTask.aensure_task_indexes`. This private
    wrapper is meant for startup-style setup (see `aensure_indexes`) rather than
    request handlers or worker loops.
    """
    task_crud = self.crud_task
    await task_crud.aensure_task_indexes(amongo)
297
+
298
+ def _generate_public_id_candidate(self, *, prefix: str, length: int) -> str:
299
+ suffix_len = length - len(prefix)
300
+ suffix = "".join(
301
+ secrets.choice(self._public_id_alphabet) for _ in range(suffix_len)
302
+ )
303
+ return f"{prefix}{suffix}"
304
+
305
async def _ensure_task_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    public_id_prefix: Optional[str],
    public_id_length: Optional[int],
) -> str:
    """Return the task's public short id, assigning a new one if it has none.

    Step by step:
    1. resolve prefix/length, using the client defaults when an argument is
       `None`, and validate them;
    2. ensure the public-id index exists, then read the task's `public_id`;
    3. return an existing non-empty string id unchanged;
    4. otherwise try up to 20 random candidates. Each is written with a
       conditional update that only matches while `public_id` is null or
       missing, so an id assigned concurrently by another caller is never
       overwritten.

    Args:
        amongo: Mongo client.
        task_id: id of the task that needs a public id.
        public_id_prefix: prefix override, or `None` for the client default.
        public_id_length: total-length override, or `None` for the default.

    Returns:
        The task's public id, either existing or newly assigned.

    Raises:
        keble_exceptions.ServerSideInvalidParams: invalid prefix/length; task
            not found; the conditional update matched nothing and no id is
            present afterwards; or all 20 candidates collided.
    """
    # `None` means "use the client-level default configured in __init__".
    prefix = (
        public_id_prefix if public_id_prefix is not None else self._public_id_prefix
    )
    length = (
        public_id_length if public_id_length is not None else self._public_id_length
    )
    self._validate_public_id_config(prefix=prefix, length=length)
    await self._ensure_public_id_index(amongo=amongo)

    collection = amongo[self._mongo_database][self._task_collection]
    existing = await collection.find_one({"_id": task_id}, {"public_id": 1})
    if existing is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task not found",
            expected="task exists",
            invalid_params="task_id",
        )
    # Already assigned: return it unchanged rather than generating a new one.
    # NOTE(review): a stored empty-string public_id is not returned here and is
    # not matched by the null filter below, so it ends in the "not updated"
    # error — confirm whether that state can occur.
    if isinstance(existing.get("public_id"), str) and existing.get("public_id"):
        return str(existing["public_id"])

    # Bounded retry: each attempt draws a fresh random candidate.
    for _ in range(20):
        candidate = self._generate_public_id_candidate(prefix=prefix, length=length)
        try:
            # `"public_id": None` matches only while the field is null/missing,
            # so this never overwrites an id another writer already set.
            result = await collection.update_one(
                {"_id": task_id, "public_id": None},
                {"$set": {"public_id": candidate}},
            )
        except DuplicateKeyError:
            # Candidate already used by another task (presumably rejected by the
            # unique public-id index); draw another one.
            continue

        if result.matched_count == 0:
            # Nothing matched: another writer assigned an id first, or the task
            # disappeared. Re-read and return the existing id if there is one.
            refreshed = await collection.find_one(
                {"_id": task_id}, {"public_id": 1}
            )
            if refreshed and refreshed.get("public_id"):
                return str(refreshed["public_id"])
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={"task_id": task_id, "candidate": candidate},
                alert_admin=True,
                but_got="task public_id not updated",
                expected="task public_id updated",
                invalid_params="task_id,public_id",
            )
        return candidate

    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={"task_id": task_id, "prefix": prefix, "length": length},
        alert_admin=True,
        but_got="exceeded retries to generate unique public_id",
        expected="unique public_id generated",
        invalid_params="public_id",
    )
367
+
368
async def _ensure_task_relation_indexes(
    self, *, amongo: AsyncIOMotorClient
) -> None:
    """Ensure the relation collection has its lookup and uniqueness indexes.

    The definitions live in `CRUDTaskRelation.aensure_relation_indexes`. This
    stays private so callers go through the relation APIs, not storage details.
    """
    relation_crud = self.crud_task_relation
    await relation_crud.aensure_relation_indexes(amongo)
379
+
380
async def _ensure_task_cost_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Ensure the task-cost collection has its query indexes.

    The definitions live in `CRUDTaskCost.aensure_cost_indexes`. This is meant
    for startup-style setup (see `aensure_indexes`) rather than request
    handlers or worker loops.
    """
    cost_crud = self.crud_task_cost
    await cost_crud.aensure_cost_indexes(amongo)
390
+
391
async def aensure_indexes(self, *, amongo: AsyncIOMotorClient) -> None:
    """Create every Mongo index owned by the task package, one after another.

    Order: public task id (shared URLs), task tree/stage (root, child, and
    retry reads), relation (root-scoped graph lookups), then cost (admin list
    and aggregate reports).
    """
    index_steps = (
        self._ensure_public_id_index,
        self._ensure_task_indexes,
        self._ensure_task_relation_indexes,
        self._ensure_task_cost_indexes,
    )
    for ensure_step in index_steps:
        await ensure_step(amongo=amongo)
405
+
406
async def ainit(self, *, amongo: AsyncIOMotorClient) -> None:
    """Startup hook for backend lifespan wiring.

    Forwards to `aensure_indexes`, so awaiting it once at startup creates all
    task-package indexes.
    """
    ensure_all = self.aensure_indexes
    await ensure_all(amongo=amongo)
415
+
416
def _task_workspace_root_id(self, task: TaskMongoObject) -> ObjectId:
    """Return the root-task id that scopes this task's workspace-local rows.

    Uses the stored `root_task` when it is set (truthy). Otherwise it falls back
    to the task's own id, which covers legacy root rows written before
    `root_task` was backfilled.
    """
    stored_root = task.root_task
    return stored_root if stored_root else task.id
425
+
426
async def _aget_relation_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    field_name: str,
) -> TaskMongoObject:
    """Load a task referenced by `field_name`, failing loudly if it is gone.

    The lookup ignores task type. `field_name` is only used to label the error
    so admins can tell which reference was stale.

    Raises:
        keble_exceptions.ServerSideInvalidParams: no task exists with `task_id`.
    """
    found = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if found is not None:
        return found
    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={field_name: task_id},
        alert_admin=True,
        but_got=f"{field_name} task not found",
        expected=f"{field_name} task exists",
        invalid_params=field_name,
    )
451
+
452
def _validate_relation_task_root(
    self,
    *,
    relation: TaskRelationCreate,
    task: TaskMongoObject,
    field_name: str,
) -> None:
    """Ensure a relation endpoint lives in the relation's root workspace.

    Compares the endpoint task's canonical root (stored `root_task`, or its own
    id for legacy roots) against `relation.root_task`.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the endpoint belongs to a
            different root, i.e. a cross-workspace lineage attempt.
    """
    endpoint_root = self._task_workspace_root_id(task)
    mismatched = endpoint_root != relation.root_task
    if not mismatched:
        return
    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={
            "root_task": relation.root_task,
            field_name: task.id,
            f"{field_name}_root_task": endpoint_root,
        },
        alert_admin=True,
        but_got=f"{field_name} belongs to a different root task",
        expected=f"{field_name} belongs to relation.root_task",
        invalid_params=f"root_task,{field_name}",
    )
480
+
481
@staticmethod
def _validate_unique_task_relation_edges(
    *,
    relations: list[TaskRelationCreate],
    invalid_params: str,
) -> None:
    """Reject a batch that repeats the same relation edge, before any write.

    An edge is identified by (root_task, from_task_id, to_task_id,
    relation_type value), the same identity the Mongo unique index protects.
    Every repeat occurrence is reported, so callers get one validation error
    instead of a partial insert that ends in a duplicate-key failure.

    Raises:
        keble_exceptions.ClientSideInvalidParams: at least one edge repeats.
    """
    seen: set[tuple[ObjectId, ObjectId, ObjectId, str]] = set()
    repeats: list[str] = []
    for rel in relations:
        kind = rel.relation_type.value
        key = (rel.root_task, rel.from_task_id, rel.to_task_id, kind)
        if key not in seen:
            seen.add(key)
            continue
        repeats.append(
            f"root_task={rel.root_task};from_task_id={rel.from_task_id};"
            f"to_task_id={rel.to_task_id};relation_type={kind}"
        )
    if not repeats:
        return
    raise keble_exceptions.ClientSideInvalidParams(
        invalid_params=invalid_params,
        expected="unique relation edge keys",
        but_got="; ".join(repeats),
        alert_admin=False,
    )
521
+
522
@staticmethod
def _validate_unique_create_related_source_ids(
    *,
    from_task_ids: list[ObjectId],
) -> None:
    """Reject repeated source ids before the related child task is inserted.

    Each source id becomes one relation edge, so a repeat would later collide.
    Every repeat occurrence is listed in the error, in input order.

    Raises:
        keble_exceptions.ClientSideInvalidParams: a source id appears twice.
    """
    already_seen: set[ObjectId] = set()
    repeated: list[str] = []
    for source_id in from_task_ids:
        if source_id not in already_seen:
            already_seen.add(source_id)
        else:
            repeated.append(str(source_id))
    if not repeated:
        return
    raise keble_exceptions.ClientSideInvalidParams(
        invalid_params="actions[].from_task_ids",
        expected="unique source task ids per created relation edge",
        but_got=",".join(repeated),
        alert_admin=False,
    )
549
+
550
async def _avalidate_task_relation_create(
    self,
    *,
    amongo: AsyncIOMotorClient,
    relation: TaskRelationCreate,
) -> None:
    """Check one relation payload against storage before it is inserted.

    For the declared root, then the source task, then the derived task (in
    that order): load the task and confirm its canonical workspace root equals
    `relation.root_task`. This proves the declared root really is a root and
    that both endpoints share that workspace.

    Raises:
        keble_exceptions.ServerSideInvalidParams: a referenced task is missing
            or belongs to another root.
    """
    endpoints = (
        ("root_task", relation.root_task),
        ("from_task_id", relation.from_task_id),
        ("to_task_id", relation.to_task_id),
    )
    for label, endpoint_id in endpoints:
        endpoint_task = await self._aget_relation_task(
            amongo=amongo,
            task_id=endpoint_id,
            field_name=label,
        )
        self._validate_relation_task_root(
            relation=relation,
            task=endpoint_task,
            field_name=label,
        )
594
+
595
+ async def _execute_token_consumption_handler(
596
+ self, payload: TokenConsumptionPayload
597
+ ) -> bool:
598
+ """Execute token consumption handler supporting both sync and async implementations.
599
+
600
+ Args:
601
+ payload: The token consumption payload
602
+
603
+ Returns:
604
+ bool: True if successful, False otherwise
605
+ """
606
+ result = self._token_consumption_handler(payload)
607
+ if asyncio.iscoroutine(result):
608
+ # If it's a coroutine, await it
609
+ return await cast(Awaitable[bool], result)
610
+ # If it's a boolean, return it directly
611
+ return cast(bool, result)
612
+
613
+ #
614
+ #
615
+ #
616
+ #
617
+ # Core running Task API
618
+ #
619
+ #
620
+ #
621
+ #
622
+ # Create
623
async def acreate(
    self,
    *,
    amongo: AsyncIOMotorClient,
    expected_token: int,
    owner: str,
    progress_key: Optional[str] = None,
    image: Optional[str] = None,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    metadata: Optional[TaskMetadata] = None,
    language: Optional[Language] = None,
    task_type: str,
    timeout_mins: int = 120,
    sharing_scope: SharingScope = SharingScope.PRIVATE,
    stage: TaskStage = TaskStage.PENDING,
    attempts: int = 0,
    consumed_token: int = 0,
    # task tree fields
    root_task: Optional[ObjectId] = None,
    parent_task: Optional[ObjectId] = None,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
) -> TaskMongoObject:
    """Create a new task asynchronously.

    Step by step:
    1. reject an explicit `root_task` for a root task (no `parent_task`);
    2. for a child task without `root_task`, load the parent and inherit its
       root (the parent's own id when the parent row has no `root_task`);
    3. for PUBLIC tasks, validate the public-id settings before inserting, so a
       bad configuration fails without writing a row;
    4. insert the task row. A root task then gets `root_task` set to its own
       id in a second write;
    5. for PUBLIC tasks, assign a public short id;
    6. reload and return the persisted task.

    Raises:
        keble_exceptions.ServerSideInvalidParams: `root_task` given without
            `parent_task`, parent task not found, invalid public-id settings,
            or public-id assignment failure.

    NOTE(review): when both `parent_task` and `root_task` are passed, nothing
    checks that `root_task` matches the parent's root. If public-id assignment
    fails after the insert, the inserted row is left in place.
    """
    if parent_task is None and root_task is not None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"root_task": root_task, "parent_task": parent_task},
            alert_admin=True,
            but_got="root_task provided for root task creation",
            expected="root_task is None when parent_task is None",
            invalid_params="root_task,parent_task",
        )
    if parent_task is not None and root_task is None:
        # Child task: inherit the workspace root from the parent row.
        parent = await self.aget(amongo=amongo, task_id=parent_task, task_type=None)
        if parent is None:
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={"parent_task": parent_task},
                alert_admin=True,
                but_got="parent task not found",
                expected="parent task exists",
                invalid_params="parent_task",
            )
        # Legacy root rows may lack `root_task`; the parent itself is then the root.
        root_task = parent.root_task or parent.id

    if sharing_scope == SharingScope.PUBLIC:
        # Validate up front so an invalid config raises before any insert.
        prefix = (
            public_id_prefix
            if public_id_prefix is not None
            else self._public_id_prefix
        )
        length = (
            public_id_length
            if public_id_length is not None
            else self._public_id_length
        )
        self._validate_public_id_config(prefix=prefix, length=length)

    result = await self.crud_task.acreate(
        amongo,
        obj_in=cast(
            TaskMongoObject,
            TaskBase(
                progress_key=progress_key,
                image=image,
                title=title,
                subtitle=subtitle,
                metadata=metadata,
                language=language,
                owner=str(owner),
                consumed_token=consumed_token,
                expected_token=expected_token,
                attempts=attempts,
                sharing_scope=sharing_scope,
                stage=stage,
                task_type=task_type,
                timeout_mins=timeout_mins,
                root_task=root_task,
                parent_task=parent_task,
            ),
        ),
    )
    _id = result.inserted_id
    # Root task: its id is only known after insert, so `root_task` is
    # back-filled with a second write.
    if parent_task is None and root_task is None:
        await self.crud_task.aupdate(amongo, _id=_id, obj_in={"root_task": _id})
    if sharing_scope == SharingScope.PUBLIC:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )
    # Reload so the returned object reflects the root/public-id writes above.
    task = await self.crud_task.afirst_by_id(amongo, _id=_id)
    assert task is not None
    return task
718
+
719
async def acreate_task_relation(
    self,
    *,
    amongo: AsyncIOMotorClient,
    obj_in: TaskRelationCreate,
) -> TaskRelationMongoObject:
    """Create a single extra lineage relation; task parentage is untouched.

    A one-item convenience wrapper around `acreate_task_relations`, so it gets
    the same validation (both endpoints in the declared root) and returns the
    persisted relation row.
    """
    created_rows = await self.acreate_task_relations(
        amongo=amongo, objs_in=[obj_in]
    )
    only_row = created_rows[0]
    return only_row
735
+
736
async def acreate_task_relations(
    self,
    *,
    amongo: AsyncIOMotorClient,
    objs_in: list[TaskRelationCreate],
) -> list[TaskRelationMongoObject]:
    """Persist many extra lineage relations as independent rows.

    Validation happens entirely before the first write: duplicate edges in
    the batch are rejected, relation indexes are ensured, and every relation's
    endpoints are checked against its root. Rows are then inserted one by one
    and read back as typed objects. `root_task`/`parent_task` on the tasks are
    never modified.

    Returns:
        The persisted relations, in input order (empty list for empty input).
    """
    if not objs_in:
        return []

    # Phase 1: validate the whole batch before writing anything.
    self._validate_unique_task_relation_edges(
        relations=objs_in,
        invalid_params="objs_in",
    )
    await self._ensure_task_relation_indexes(amongo=amongo)
    for pending in objs_in:
        await self._avalidate_task_relation_create(amongo=amongo, relation=pending)

    # Phase 2: insert each row and reload it.
    relation_crud = self.crud_task_relation
    persisted: list[TaskRelationMongoObject] = []
    for pending in objs_in:
        inserted = await relation_crud.acreate_relation(amongo, obj_in=pending)
        stored = await relation_crud.afirst_by_id(amongo, _id=inserted.inserted_id)
        assert stored is not None
        persisted.append(stored)
    return persisted
777
+
778
async def acreate_task_cost(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    tags: list[str] | None = None,
    run_usage: RunUsage,
    source: UsageAccountingSource | None = None,
    token_rates_per_million: TaskCostTokenRates | None = None,
    additional_cost: Money | None = None,
    seconds: float = 0,
    retry: int = 0,
    occurred_at: datetime | None = None,
    metadata: TaskCostMetadata | None = None,
) -> TaskCostMongoObject:
    """Record one cost row for a task, denormalized from the stored task.

    The task is loaded first, so owner, type, and workspace root come from
    storage rather than from the caller. The remaining arguments are passed
    through to `TaskCostCreate.from_task_usage`, and the inserted row is
    reloaded as a typed object.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the task does not exist.
    """
    task = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if task is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task not found",
            expected="task exists",
            invalid_params="task_id",
        )
    cost_in = TaskCostCreate.from_task_usage(
        task=task,
        root_task=self._task_workspace_root_id(task),
        tags=tags,
        run_usage=run_usage,
        source=source,
        token_rates_per_million=token_rates_per_million,
        additional_cost=additional_cost,
        seconds=seconds,
        retry=retry,
        occurred_at=occurred_at,
        metadata=metadata,
    )
    cost_crud = self.crud_task_cost
    inserted = await cost_crud.acreate_cost(amongo, obj_in=cost_in)
    stored = await cost_crud.afirst_by_id(amongo, _id=inserted.inserted_id)
    assert stored is not None
    return stored
832
+
833
async def alist_task_costs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    query: TaskCostListRequest,
) -> TaskCostListResponse:
    """Return one page of task-cost rows plus the total matching count.

    Both reads go through `CRUDTaskCost` with the same query: first the page,
    then the count. Index creation is not done here; it is expected to have
    run at startup via `ainit(...)`.
    """
    cost_crud = self.crud_task_cost
    page = await cost_crud.alist_costs(amongo, query=query)
    matched = await cost_crud.acount_costs(amongo, query=query)
    return TaskCostListResponse(
        costs=page,
        total=matched,
        skip=query.skip,
        limit=query.limit,
    )
855
+
856
async def aaggregate_task_costs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    query: TaskCostAggregateRequest,
    exchange_rates: list[ExchangeRateInUsd],
) -> TaskCostAggregateResponse:
    """Aggregate task costs using caller-supplied exchange rates.

    Rows are loaded with `CRUDTaskCost.alist_costs_for_aggregate`. Currency
    conversion, time bucketing, and optional tag bucketing are all done by
    `TaskCostAggregateResponse.build`.
    """
    rows = await self.crud_task_cost.alist_costs_for_aggregate(
        amongo,
        query=query,
    )
    aggregated = TaskCostAggregateResponse.build(
        costs=rows,
        query=query,
        exchange_rates=exchange_rates,
    )
    return aggregated
880
+
881
def _validate_task_action_root_context(
    self,
    *,
    current_task: TaskMongoObject,
    ux_context: TaskUxContext | None,
) -> None:
    """Reject a UI-selected root that differs from the task's actual root.

    This is a no-op when there is no `ux_context` or it carries no
    `selected_root_task_id`. It is intended to run before task actions create
    any child task or relation.

    Raises:
        keble_exceptions.ClientSideInvalidParams: the selected root differs
            from the current task's canonical workspace root.
    """
    selected_root = (
        ux_context.selected_root_task_id if ux_context is not None else None
    )
    if selected_root is None:
        return
    actual_root = self._task_workspace_root_id(current_task)
    if selected_root != actual_root:
        raise keble_exceptions.ClientSideInvalidParams(
            invalid_params="ux_context.selected_root_task_id",
            expected=str(actual_root),
            but_got=str(selected_root),
            alert_admin=False,
        )
905
+
906
async def _aprepare_create_related_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: CreateRelatedTaskAction,
) -> TaskMongoObject:
    """Resolve and load the parent task for a create-related-task action.

    The parent is `action.parent_task_id` when set (truthy), otherwise the
    current task. It is loaded so that `acreate(...)` can keep the root linkage.
    This method has no feature-specific side effects.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the parent task is missing.
    """
    requested_parent = action.parent_task_id
    resolved_parent_id = requested_parent if requested_parent else current_task.id
    parent = await self._aget_relation_task(
        amongo=amongo,
        task_id=resolved_parent_id,
        field_name="parent_task_id",
    )
    return parent
927
+
928
async def _acomplete_created_task_if_needed(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis | None,
    created_task: TaskMongoObject,
    complete_created_task: bool,
    consuming_token: int,
) -> TaskMongoObject:
    """Optionally finalize a freshly created task as successful.

    When `complete_created_task` is False, the task is returned untouched (a
    normal queued child). Otherwise Redis is required, and token/status
    accounting is delegated to `aon_task_success`.

    Raises:
        keble_exceptions.ServerSideMissingParams: completion was requested but
            `extended_aredis` is None.
    """
    if complete_created_task:
        if extended_aredis is None:
            raise keble_exceptions.ServerSideMissingParams(
                missing_params="extended_aredis",
                alert_admin=True,
            )
        return await self.aon_task_success(
            amongo=amongo,
            extended_aredis=extended_aredis,
            task=created_task,
            consuming_token=consuming_token,
        )
    return created_task
958
+
959
async def _arollback_created_related_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    created_task_id: ObjectId,
) -> None:
    """Compensate a failed create-related-task action.

    Step by step:
    1. delete any relation rows already inserted that point at the orphaned child
       (`to_task_id == created_task_id`);
    2. delete the orphaned child task itself, even if step 1 raised;
    3. leave the original failure to propagate from the caller.

    Used only on the relation-write failure path because standalone Mongo cannot run
    the child + relation inserts in one atomic transaction.

    Raises:
        Whatever the underlying CRUD deletes raise. If both deletes fail, the
        task-delete error propagates, with the relation-delete error as its
        ``__context__``.
    """

    try:
        await self.crud_task_relation.adelete_multi(
            amongo,
            query=QueryBase(filters={"to_task_id": created_task_id}),
        )
    finally:
        # Always attempt to remove the child: a failed relation cleanup must not
        # also leave the orphaned child task itself in the workspace.
        await self.crud_task.adelete(amongo, _id=created_task_id)
982
+
983
async def _aapply_create_related_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: CreateRelatedTaskAction,
    extended_aredis: ExtendedAsyncRedis | None,
) -> CreateRelatedTaskActionedResult:
    """Apply one generic create-related-task action.

    Step by step:
    1. validate request-level preconditions (a task type, and a Redis client
       when the child must be completed) before any read or write, so an
       invalid action never leaves a half-created child behind;
    2. resolve parent/root context;
    3. validate every source task belongs to the same root room;
    4. create the child task from generic task fields only;
    5. create relation rows from source tasks to the created task, rolling the
       child back if any relation write fails;
    6. optionally mark the child task successful for lightweight generic work.

    Raises:
        keble_exceptions.ClientSideMissingParams: ``action.task_type`` is None.
        keble_exceptions.ServerSideMissingParams: ``action.complete_created_task``
            is set but ``extended_aredis`` is None.
        keble_exceptions.ClientSideInvalidParams: a source task belongs to a
            different root workspace than the parent task.
    """

    if action.task_type is None:
        raise keble_exceptions.ClientSideMissingParams(
            missing_params="actions[].task_type",
            alert_admin=False,
        )
    if action.complete_created_task and extended_aredis is None:
        # Finalization needs Redis. Detecting this only after the child and its
        # relations are written (inside `_acomplete_created_task_if_needed`)
        # would raise with those rows persisted and never rolled back.
        raise keble_exceptions.ServerSideMissingParams(
            missing_params="extended_aredis",
            alert_admin=True,
        )

    parent_task = await self._aprepare_create_related_task_action(
        amongo=amongo,
        current_task=current_task,
        action=action,
    )
    from_task_ids = action.from_task_ids or [current_task.id]
    self._validate_unique_create_related_source_ids(from_task_ids=from_task_ids)
    parent_root_id = self._task_workspace_root_id(parent_task)
    for from_task_id in from_task_ids:
        from_task = await self._aget_relation_task(
            amongo=amongo,
            task_id=from_task_id,
            field_name="from_task_ids",
        )
        from_root_id = self._task_workspace_root_id(from_task)
        if from_root_id != parent_root_id:
            raise keble_exceptions.ClientSideInvalidParams(
                invalid_params="from_task_ids",
                expected=str(parent_root_id),
                but_got=f"from_task_id={from_task_id}; from_task_root_id={from_root_id}",
                alert_admin=False,
            )
    created_task = await self.acreate(
        amongo=amongo,
        expected_token=action.expected_token,
        owner=current_task.owner,
        progress_key=action.progress_key,
        image=action.image,
        title=action.title,
        subtitle=action.subtitle,
        metadata=action.metadata,
        task_type=action.task_type,
        timeout_mins=action.timeout_mins,
        sharing_scope=action.sharing_scope or current_task.sharing_scope,
        parent_task=parent_task.id,
    )
    # Saga compensation: standalone Mongo has no multi-document transactions, so the
    # child task and its relation rows are written as separate operations. If any
    # relation write fails (a concurrent duplicate hitting the unique edge index, or
    # an infra error mid-loop), roll back the just-created child and any partial
    # relation rows pointing at it, then re-raise unchanged so the workspace never
    # keeps an orphaned child without its intended relations.
    try:
        relations = await self.acreate_task_relations(
            amongo=amongo,
            objs_in=[
                TaskRelationCreate(
                    root_task=parent_root_id,
                    from_task_id=from_task_id,
                    to_task_id=created_task.id,
                    relation_type=action.relation_type,
                    metadata=action.metadata,
                )
                for from_task_id in from_task_ids
            ],
        )
    except Exception:
        try:
            await self._arollback_created_related_task(
                amongo=amongo, created_task_id=created_task.id
            )
        except Exception:
            # Compensation is best-effort. Record the rollback failure, but keep
            # the original relation-write error as the one that propagates below.
            logger.exception(
                "[Task] Failed to roll back created related task %s",
                created_task.id,
            )
        raise
    finalized_task = await self._acomplete_created_task_if_needed(
        amongo=amongo,
        extended_aredis=extended_aredis,
        created_task=created_task,
        complete_created_task=action.complete_created_task,
        consuming_token=action.consuming_token,
    )
    return CreateRelatedTaskActionedResult(
        created_task=TaskActionCreatedTask.build(task=finalized_task),
        created_relations=[
            TaskActionCreatedRelation.build(relation=relation)
            for relation in relations
        ],
    )
1081
+
1082
async def _aapply_one_task_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: TaskAction,
    extended_aredis: ExtendedAsyncRedis | None,
) -> tuple[CreateRelatedTaskActionedResult, TaskActionEvent]:
    """Run a single canonical task action and describe it as an event.

    Both ``aapply_action`` (agent, one action per call) and ``aapply_actions``
    (REST/programmatic batch) go through this method. Action semantics and
    event shape therefore cannot drift between the two entry points.

    Returns:
        The action result paired with a ``TaskActionEvent``. The event's
        ``root_id`` falls back to the created task's id when it has no root id.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the action type is not one this
            executor knows how to run.
    """

    # Only CREATE_RELATED_TASK exists today. Any other type is rejected loudly
    # instead of silently doing nothing.
    if action.action_type is TaskActionType.CREATE_RELATED_TASK:
        result = await self._aapply_create_related_task_action(
            amongo=amongo,
            current_task=current_task,
            action=action,
            extended_aredis=extended_aredis,
        )
    else:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"action_type": str(action.action_type)},
            alert_admin=True,
            but_got=f"unsupported task action type {action.action_type}",
            expected="a supported TaskActionType (CREATE_RELATED_TASK)",
            invalid_params="action_type",
        )

    created = result.created_task
    event = TaskActionEvent(
        action_type=result.action_type.value,
        payload=result,
        root_id=str(created.root_task_id or created.task_id),
        object_id=str(created.task_id),
        correlation_id=created.progress_key,
    )
    return result, event
1123
+
1124
async def aapply_action(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    action: TaskAction,
    ux_context: TaskUxContext | None = None,
    event_emitter: AgenticEventEmitter | None = None,
    extended_aredis: ExtendedAsyncRedis | None = None,
) -> CreateRelatedTaskActionedResult:
    """Apply exactly ONE canonical task action (agent one-action-per-call surface).

    The agent expresses several mutations by issuing several
    `mutate_task_workspace` tool calls, each shown as its own deferred
    approval card. This method therefore accepts a single
    `AgenticActionPayload` action and never a list wrapper.

    Flow: validate the optional UI root context, then run the action through the
    shared executor. If an emitter is supplied, emit the event and drain it.
    Finally, return the single result.
    """

    self._validate_task_action_root_context(
        current_task=current_task,
        ux_context=ux_context,
    )
    result, event = await self._aapply_one_task_action(
        amongo=amongo,
        current_task=current_task,
        action=action,
        extended_aredis=extended_aredis,
    )
    if event_emitter is None:
        return result

    await event_emitter.aemit(event)
    # Join detached event callbacks before returning, mirroring the REST batch
    # path, so the agent tool sees a flushed stream with errors surfaced.
    await event_emitter.adrain()
    return result
1165
+
1166
async def aapply_actions(
    self,
    *,
    amongo: AsyncIOMotorClient,
    current_task: TaskMongoObject,
    payload: TaskActions,
    ux_context: TaskUxContext | None = None,
    event_emitter: AgenticEventEmitter | None = None,
    extended_aredis: ExtendedAsyncRedis | None = None,
) -> TaskActionedResults:
    """Apply a REST/programmatic batch of canonical task actions, in order.

    This is the non-agent entry point (REST endpoint, internal callers). The agent
    chat surface uses `aapply_action` (one action per tool call). Both delegate to
    `_aapply_one_task_action`, so their behavior cannot diverge.

    Flow: validate the optional UI root context, then run each action in the
    caller's order. Each event is emitted as soon as its action completes (when
    an emitter is given), and the emitter is drained once after the whole batch.
    Returns the ordered results together with every event produced.
    """

    self._validate_task_action_root_context(
        current_task=current_task,
        ux_context=ux_context,
    )
    ordered_results: list[CreateRelatedTaskActionedResult] = []
    produced_events: list[TaskActionEvent] = []
    for next_action in payload.actions:
        outcome, action_event = await self._aapply_one_task_action(
            amongo=amongo,
            current_task=current_task,
            action=next_action,
            extended_aredis=extended_aredis,
        )
        ordered_results.append(outcome)
        produced_events.append(action_event)
        if event_emitter is not None:
            await event_emitter.aemit(action_event)

    # The helpers emitter detaches each event's publish/persist callbacks; join
    # them at the batch boundary so the caller sees a flushed stream with errors
    # surfaced, instead of racing callbacks that are still in flight.
    if event_emitter is not None:
        await event_emitter.adrain()
    return TaskActionedResults(
        actioned_results=ordered_results,
        events=produced_events,
    )
1216
+
1217
+ #
1218
+ #
1219
+ #
1220
+ #
1221
+ # Read
1222
+ #
1223
+ #
1224
+ #
1225
+ #
1226
@overload
async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def aget(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by id, optionally restricted to a task type.

    Without ``include_childs`` (or when nothing matches), the plain document
    is returned, or None. With ``include_childs``, the task is expanded into
    its subtree using the same ``task_type`` restriction for descendants.
    """
    # Shared by the lookup itself and the subtree expansion below.
    type_filter: dict = {} if task_type is None else {"task_type": task_type}
    task = await self.crud_task.afirst(
        amongo, query=QueryBase(filters={"_id": task_id, **type_filter})
    )
    if task is None or not include_childs:
        return task
    return await self._aget_task_tree_node(
        amongo=amongo,
        task=task,
        target_task_id=task_id,
        base_filter=type_filter,
    )
1273
+
1274
@overload
async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def aowner_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    task_type: Optional[str],
    sharing_scope: Optional[SharingScope],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one task by id that belongs to ``owner``.

    ``task_type`` and ``sharing_scope`` narrow the match when given. With
    ``include_childs``, the found task is expanded into its subtree under the
    same owner/type/scope restrictions. Returns None when nothing matches.
    """
    # Owner-scoped restrictions reused for the subtree expansion.
    scope_filter: dict = {"owner": owner}
    if task_type is not None:
        scope_filter["task_type"] = task_type
    if sharing_scope is not None:
        scope_filter["sharing_scope"] = sharing_scope

    task = await self.crud_task.afirst(
        amongo, query=QueryBase(filters={"_id": task_id, **scope_filter})
    )
    if task is None or not include_childs:
        return task
    return await self._aget_task_tree_node(
        amongo=amongo,
        task=task,
        target_task_id=task_id,
        base_filter=scope_filter,
    )
1329
+
1330
async def aresolve_task_room(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
) -> TaskRoomResolution:
    """Map an owner-visible task to the identity of its root task room.

    The task is loaded under ``owner`` permissions (any type, any sharing
    scope). Its durable root id is derived via ``_task_workspace_root_id``.
    Only the room identity is returned; the room-context shape is left to
    callers.

    Raises:
        keble_exceptions.NoObjectPermission: no task with ``task_id`` is visible
            to ``owner``.
    """

    requested_task = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        task_id=task_id,
        task_type=None,
        sharing_scope=None,
    )
    if requested_task is not None:
        return TaskRoomResolution(
            root_task_id=self._task_workspace_root_id(requested_task),
            requested_task_id=requested_task.id,
        )
    raise keble_exceptions.NoObjectPermission(
        object_id=str(task_id),
        object_type="task",
    )
1361
+
1362
async def aget_task_room_graph_context(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    root_task_id: ObjectId,
    focused_task_id: ObjectId | None = None,
) -> TaskRoomGraphContext:
    """Assemble the generic graph context of one owner-visible root task room.

    The root task is loaded with its full child tree under ``owner``
    permissions, and it must itself be a workspace root. An optional focused
    task (the root by default) must belong to that same root. Relation rows
    scoped to the root are then attached through ``TaskRoomGraphContext.build``.
    No feature-specific fields are added.

    Raises:
        keble_exceptions.NoObjectPermission: the root or focused task is not
            visible to ``owner``.
        keble_exceptions.ServerSideInvalidParams: ``root_task_id`` is not a root
            task, or the focused task belongs to a different root.
    """

    root_task = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        task_id=root_task_id,
        task_type=None,
        sharing_scope=None,
        include_childs=True,
    )
    if root_task is None:
        raise keble_exceptions.NoObjectPermission(
            object_id=str(root_task_id),
            object_type="task",
        )

    resolved_root_id = self._task_workspace_root_id(root_task)
    if resolved_root_id != root_task.id:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={
                "root_task_id": root_task_id,
                "resolved_root_task_id": resolved_root_id,
            },
            alert_admin=True,
            but_got="requested task is not a root task",
            expected="root_task_id points to the root task room",
            invalid_params="root_task_id",
        )

    focus_id = focused_task_id or root_task.id
    if focus_id != root_task.id:
        # A non-root focus must be visible to the owner AND live in this room.
        focused_task = await self.aowner_get(
            amongo=amongo,
            owner=owner,
            task_id=focus_id,
            task_type=None,
            sharing_scope=None,
        )
        if focused_task is None:
            raise keble_exceptions.NoObjectPermission(
                object_id=str(focus_id),
                object_type="task",
            )
        focused_root_id = self._task_workspace_root_id(focused_task)
        if focused_root_id != root_task.id:
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={
                    "root_task_id": root_task.id,
                    "focused_task_id": focus_id,
                    "focused_root_task_id": focused_root_id,
                },
                alert_admin=True,
                but_got="focused task belongs to another root task room",
                expected="focused task belongs to root_task_id",
                invalid_params="focused_task_id",
            )

    room_relations = await self.alist_task_relations_by_root(
        amongo=amongo,
        root_task=root_task.id,
    )
    return TaskRoomGraphContext.build(
        root_task=root_task,
        focused_task_id=focus_id,
        relations=room_relations,
    )
1441
+
1442
@overload
async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def apublic_get(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: ObjectId,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one PUBLIC-scoped task by id.

    Only tasks whose ``sharing_scope`` is ``SharingScope.PUBLIC`` match, and
    ``task_type`` narrows further when given. With ``include_childs``, the
    subtree is expanded under the same public/type restriction. Returns None
    when nothing matches.
    """
    # Public-visibility restrictions reused for the subtree expansion.
    visibility: dict = {"sharing_scope": SharingScope.PUBLIC}
    if task_type is not None:
        visibility["task_type"] = task_type

    task = await self.crud_task.afirst(
        amongo, query=QueryBase(filters={"_id": task_id, **visibility})
    )
    if task is None or not include_childs:
        return task
    return await self._aget_task_tree_node(
        amongo=amongo,
        task=task,
        target_task_id=task_id,
        base_filter=visibility,
    )
1487
+
1488
@overload
async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: Literal[False] = False,
) -> TaskMongoObject | None: ...

@overload
async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: Literal[True],
) -> TaskMongoObjectExtended | None: ...

async def apublic_get_by_public_id(
    self,
    *,
    amongo: AsyncIOMotorClient,
    public_id: str,
    task_type: Optional[str],
    include_childs: bool = False,
) -> TaskMongoObject | TaskMongoObjectExtended | None:
    """Fetch one PUBLIC-scoped task by its ``public_id``.

    Matching is restricted to ``SharingScope.PUBLIC`` and, when given, to
    ``task_type``. With ``include_childs``, the matched task (by its Mongo id)
    is expanded into its subtree under the same restriction. Returns None when
    nothing matches.
    """
    # Public-visibility restrictions reused for the subtree expansion.
    visibility: dict = {"sharing_scope": SharingScope.PUBLIC}
    if task_type is not None:
        visibility["task_type"] = task_type

    task = await self.crud_task.afirst(
        amongo, query=QueryBase(filters={"public_id": public_id, **visibility})
    )
    if task is None or not include_childs:
        return task
    return await self._aget_task_tree_node(
        amongo=amongo,
        task=task,
        target_task_id=task.id,
        base_filter=visibility,
    )
1533
+
1534
async def _aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: Optional[int],
    limit: Optional[int],
    filter_: dict,
) -> List[TaskMongoObject]:
    """Fetch tasks matching ``filter_`` newest-first, repairing timed-out ones.

    Tasks are read via ``crud_task.aget_multi``, ordered by ``created``
    descending with ``skip``/``limit`` as offset/limit. Callers in this class
    pass ``None`` for both when they want every matching row.

    For every task flagged ``unfinshed_timeout``:
    - if ``allow_to_retry`` is set, the task is restarted with ``astart``;
    - otherwise it is marked as timed out via ``aon_task_timeout``.
    The task is then re-read so the returned list reflects its new state.

    Raises:
        AssertionError: a repaired task can no longer be found by its id.
    """
    tasks = await self.crud_task.aget_multi(
        amongo,
        query=QueryBase(
            filters=filter_,
            order_by=[("created", DESCENDING)],  # by created time
            offset=skip,
            limit=limit,
        ),
    )
    for i, task in enumerate(tasks):
        if not task.unfinshed_timeout:
            continue
        logger.warning(
            f"[Task] Found an unfinished timeout task: {task.id}. Task stage = {task.stage}, created = {task.created}, task type = {task.task_type}, attempts = {task.attempts}"
        )
        if task.allow_to_retry:
            # try to restart the task
            logger.warning(
                f"[Task] Found an unfinished timeout task: {task.id}. Start to retry."
            )
            await self.astart(
                amongo=amongo,
                extended_aredis=extended_aredis,
                task=task,
            )
        else:
            # mark it as timeout failure
            logger.warning(
                f"[Task] Found an unfinished timeout task: {task.id}. Make it as Timeout."
            )
            await self.aon_task_timeout(
                amongo=amongo,
                task=task,
            )

        updated_task = await self.crud_task.afirst_by_id(amongo, _id=task.id)
        if updated_task is None:
            # Explicit raise instead of `assert`: assertions are stripped under
            # `python -O`, which would silently put None into the result list.
            raise AssertionError(
                f"[Task] Internal error, expected not none object for id {task.id}, but got none."
            )
        tasks[i] = updated_task

    return tasks
1586
+
1587
async def _aget_root_tasks(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Page through top-level tasks (``parent_task`` is None) matching ``base_filter``.

    ``base_filter`` is not mutated; a copy with ``parent_task`` forced to None
    is handed to ``_aget_multi``.
    """
    roots_only = {**base_filter, "parent_task": None}
    return await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        filter_=roots_only,
    )
1605
+
1606
async def _aget_descendant_tasks(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    root_ids: List[ObjectId],
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Load every task under the given roots, excluding the roots themselves.

    Matches tasks whose ``root_task`` is one of ``root_ids`` and whose own id
    is not, on top of ``base_filter`` (which is left unmodified). No
    pagination is applied. An empty ``root_ids`` short-circuits to ``[]``
    without querying.
    """
    if not root_ids:
        return []
    descendants_filter = {
        **base_filter,
        "root_task": {"$in": root_ids},
        "_id": {"$nin": root_ids},
    }
    return await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=None,
        limit=None,
        filter_=descendants_filter,
    )
1626
+
1627
async def _aget_tasks_in_root_tree(
    self,
    *,
    amongo: AsyncIOMotorClient,
    root_id: ObjectId,
    base_filter: dict,
) -> List[TaskMongoObject]:
    """Return the tasks of one root tree, delegating to ``crud_task.aget_tasks_in_root_tree``."""
    crud = self.crud_task
    return await crud.aget_tasks_in_root_tree(
        amongo,
        root_id=root_id,
        base_filter=base_filter,
    )
1637
+
1638
async def alist_task_relations_by_root(
    self,
    *,
    amongo: AsyncIOMotorClient,
    root_task: ObjectId,
) -> list[TaskRelationMongoObject]:
    """Return the extra lineage relation rows of one root-task workspace.

    ``_ensure_task_relation_indexes`` is awaited first. Rows are then selected
    by their ``root_task`` workspace id and sorted by ``created`` ascending, so
    canvas rendering stays stable. Only relation rows are returned; combining
    them with the task tree is left to callers.
    """

    await self._ensure_task_relation_indexes(amongo=amongo)
    by_root_oldest_first = QueryBase(
        filters={"root_task": root_task},
        order_by=[("created", ASCENDING)],
    )
    return await self.crud_task_relation.aget_multi(
        amongo,
        query=by_root_oldest_first,
    )
1660
+
1661
async def _aget_task_tree_node(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task: TaskMongoObject,
    target_task_id: ObjectId,
    base_filter: dict,
) -> TaskMongoObjectExtended:
    """Expand ``task`` into a tree node and return the node for ``target_task_id``.

    All tasks of the surrounding root tree are loaded (the root is
    ``task.root_task``, or ``task.id`` when unset). A tree is built starting
    from ``task`` and the target is located in it. When the target is not
    found, ``task`` is returned as an extended node with no children.
    """
    tree_root_id = task.root_task or task.id
    tree_members = await self._aget_tasks_in_root_tree(
        amongo=amongo, root_id=tree_root_id, base_filter=base_filter
    )
    forest = self._build_task_tree(roots=[task], tasks=tree_members)
    located = self._find_task_in_tree(nodes=forest, task_id=target_task_id)
    if located is None:
        return TaskMongoObjectExtended(**task.model_dump(), childs=[])
    return located
1678
+
1679
def _build_task_tree(
    self,
    *,
    roots: List[TaskMongoObject],
    tasks: List[TaskMongoObject],
) -> List[TaskMongoObjectExtended]:
    """Delegate to the module-level ``build_task_tree`` helper."""
    tree = build_task_tree(roots=roots, tasks=tasks)
    return tree
1683
+
1684
def _find_task_in_tree(
    self,
    *,
    nodes: List[TaskMongoObjectExtended],
    task_id: ObjectId,
) -> Optional[TaskMongoObjectExtended]:
    """Delegate to the module-level ``find_task_in_tree`` helper (None when absent)."""
    found = find_task_in_tree(nodes=nodes, task_id=task_id)
    return found
1688
+
1689
async def _aget_multi_with_optional_childs(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    base_filter: dict,
    include_childs: bool,
    root_task: Optional[ObjectId],
    parent_task: Optional[ObjectId],
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """Serve either a flat task list or paginated task trees from one filter.

    - Flat (``include_childs`` false): without ``parent_task``, only root
      tasks are listed, so ordinary listings stay workspace-level. With
      ``parent_task``, the filter as given (already scoped to that parent by
      callers) is listed.
    - Tree, no ``parent_task``: root tasks are paginated and their descendants
      attached.
    - Tree, with ``parent_task``: the filtered page is loaded first. The full
      trees containing those tasks are then loaded (restricted to
      ``root_task`` when given) and subtrees are built from the page entries.
    """

    if not include_childs:
        flat_filter = (
            {**base_filter, "parent_task": None}
            if parent_task is None
            else dict(base_filter)
        )
        return await self._aget_multi(
            amongo=amongo,
            extended_aredis=extended_aredis,
            skip=skip,
            limit=limit,
            filter_=flat_filter,
        )

    if parent_task is None:
        page_roots = await self._aget_root_tasks(
            amongo=amongo,
            extended_aredis=extended_aredis,
            skip=skip,
            limit=limit,
            base_filter=base_filter,
        )
        descendants = await self._aget_descendant_tasks(
            amongo=amongo,
            extended_aredis=extended_aredis,
            root_ids=[root.id for root in page_roots],
            base_filter=base_filter,
        )
        return self._build_task_tree(roots=page_roots, tasks=page_roots + descendants)

    page_nodes = await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        filter_=base_filter,
    )
    if not page_nodes:
        return []

    # Workspace roots of every task on this page (a root task is its own root).
    owning_root_ids = list(
        {node.id if node.root_task is None else node.root_task for node in page_nodes}
    )
    # Drop the parent/root scoping so whole trees can be fetched, then re-scope
    # by root only.
    tree_filter = {
        key: value
        for key, value in base_filter.items()
        if key not in {"parent_task", "root_task"}
    }
    tree_filter["root_task"] = (
        root_task if root_task is not None else {"$in": owning_root_ids}
    )
    tree_members = await self._aget_multi(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=None,
        limit=None,
        filter_=tree_filter,
    )
    return self._build_task_tree(roots=page_nodes, tasks=tree_members)
1769
+
1770
@overload
async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def aget_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """List tasks (any owner), optionally as trees.

    Every provided argument becomes a filter:
    - ``task_types`` and ``stages`` match via ``$in``;
    - ``root_task`` and ``parent_task`` match by equality;
    - ``title_contains``, once stripped, must be non-empty and matches the
      title as an escaped, case-insensitive substring.

    Side effect if changes:
    - `aowner_get_multi` and agent query registry tests mirror this generic
      read shape; keep stage/title filters aligned across owner/public
      variants.
    - Tree/list endpoints rely on `_aget_multi_with_optional_childs` for
      the `include_childs` return-type split.
    """
    criteria: dict = {}
    if task_types is not None:
        criteria["task_type"] = {"$in": task_types}
    if root_task is not None:
        criteria["root_task"] = root_task
    if parent_task is not None:
        criteria["parent_task"] = parent_task
    if stages is not None:
        criteria["stage"] = {"$in": stages}
    title_needle = (title_contains or "").strip()
    if title_needle:
        # Escape so user text is matched literally, not interpreted as regex.
        criteria["title"] = {
            "$regex": re.escape(title_needle),
            "$options": "i",
        }
    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=criteria,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
1849
+
1850
@overload
async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def aowner_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    owner: str,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    sharing_scopes: Optional[List[SharingScope]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """Return a page of `owner`'s tasks, optionally expanded with child tasks.

    Filters are AND-ed onto the mandatory `owner` equality. Each one applies
    only when the caller supplies it:
    - `task_types`, `sharing_scopes`, `stages`: `$in` membership. An empty
      list matches nothing.
    - `root_task`, `parent_task`: exact equality.
    - `title_contains`: case-insensitive substring match. The input is
      whitespace-stripped and regex-escaped; blank input is ignored.

    Index notes: `owner` equality rides the `owner + parent_task + created`
    index prefix. `stages` (added in 2.10.0) and `title_contains` are
    residual filters inside that owner-bounded shape. They are meant for
    bounded agent/list reads (limit <= 50), never unbounded sweeps.

    Paging (`skip`/`limit`) and child expansion (`include_childs`, which also
    receives `root_task`/`parent_task`) are delegated to
    `_aget_multi_with_optional_childs`.
    """
    # None marks "not requested"; those entries are dropped below.
    candidate_filters = {
        "task_type": None if task_types is None else {"$in": task_types},
        "sharing_scope": (
            None if sharing_scopes is None else {"$in": sharing_scopes}
        ),
        "root_task": root_task,
        "parent_task": parent_task,
        "stage": None if stages is None else {"$in": stages},
    }
    query_filter: dict = {
        "owner": owner,
        **{
            field: condition
            for field, condition in candidate_filters.items()
            if condition is not None
        },
    }

    needle = (title_contains or "").strip()
    if needle:
        # Escape so user text is matched literally, never as a pattern.
        query_filter["title"] = {"$regex": re.escape(needle), "$options": "i"}

    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=query_filter,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
1938
+
1939
@overload
async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[False] = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject]: ...

@overload
async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: Literal[True],
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObjectExtended]: ...

async def apublic_get_multi(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    skip: int,
    limit: int,
    task_types: Optional[List[str]],
    include_childs: bool = False,
    root_task: Optional[ObjectId] = None,
    stages: Optional[List[str]] = None,
    title_contains: Optional[str] = None,
    parent_task: Optional[ObjectId] = None,
) -> List[TaskMongoObject] | List[TaskMongoObjectExtended]:
    """Return a page of PUBLIC tasks, optionally expanded with child tasks.

    The mandatory filter is `sharing_scope == PUBLIC`. Optional filters are
    AND-ed onto it, mirroring `aowner_get_multi`:
    - `task_types`, `stages`: `$in` membership. An empty list matches
      nothing.
    - `root_task`, `parent_task`: exact equality.
    - `title_contains`: case-insensitive substring match. The input is
      whitespace-stripped and regex-escaped; blank input is ignored.

    `stages` and `title_contains` are residual filters inside the
    PUBLIC-bounded shape. Both are honored here, as the overloads advertise.

    Paging and child expansion are delegated to
    `_aget_multi_with_optional_childs`.
    """
    # None marks "not requested"; those entries are dropped below.
    candidate_filters = {
        "task_type": None if task_types is None else {"$in": task_types},
        "root_task": root_task,
        "parent_task": parent_task,
        "stage": None if stages is None else {"$in": stages},
    }
    query_filter: dict = {"sharing_scope": SharingScope.PUBLIC}
    query_filter.update(
        (field, condition)
        for field, condition in candidate_filters.items()
        if condition is not None
    )

    needle = (title_contains or "").strip()
    if needle:
        # Escape so user text is matched literally, never as a pattern.
        query_filter["title"] = {"$regex": re.escape(needle), "$options": "i"}

    return await self._aget_multi_with_optional_childs(
        amongo=amongo,
        extended_aredis=extended_aredis,
        skip=skip,
        limit=limit,
        base_filter=query_filter,
        include_childs=include_childs,
        root_task=root_task,
        parent_task=parent_task,
    )
2016
+
2017
async def apublic_list_indexable(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_types: List[str],
    stages: List[str],
    limit: int,
) -> List[TaskPublicRef]:
    """Return lean `(id, updated)` refs of PUBLIC tasks, newest `updated` first.

    This backs the dynamic-page sitemap. It differs from `apublic_get_multi`
    in several ways:
    - it projects only `_id` and `updated`;
    - it needs no `extended_aredis` and never loads child tasks;
    - it returns `TaskPublicRef` rather than `TaskMongoObject`.
    That keeps a crawl-time listing of many public reports cheap.

    Query shape, served by the `sharing_scope + task_type + stage + updated
    DESC` index: equality on `sharing_scope=PUBLIC`, `$in` on `task_type` and
    `stage`, sorted by `updated` DESC, capped at `limit`.
    """
    public_of_kind = {
        "sharing_scope": SharingScope.PUBLIC,
        "task_type": {"$in": task_types},
        "stage": {"$in": stages},
    }
    newest_first = QueryBase(
        filters=public_of_kind,
        order_by=[("updated", DESCENDING)],
        limit=limit,
    )
    id_and_updated_only = {"_id": 1, "updated": 1}
    raw_refs = await self.crud_task.aget_multi_docs(
        amongo, query=newest_first, project=id_and_updated_only
    )
    return [TaskPublicRef(**raw) for raw in raw_refs]
2052
+
2053
+ #
2054
+ #
2055
+ #
2056
+ #
2057
+ # Update
2058
+ #
2059
+ #
2060
+ #
2061
+ #
2062
async def aupdate(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    task_id: ObjectId,
    obj_in: TaskUpdate,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
):
    """Apply the explicitly-set fields of `obj_in` to task `task_id`.

    If the update makes the task PUBLIC, a public id is ensured first via
    `_ensure_task_public_id`, using the given prefix and length. Only the
    fields set on `obj_in` are written (`exclude_unset=True`).

    Raises:
        keble_exceptions.ServerSideInvalidParams: no task with `task_id`.
    """
    existing = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
    if existing is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id},
            alert_admin=True,
            but_got="task is None",
            expected="task exists",
            invalid_params=f"task_id={task_id}",
        )

    becoming_public = obj_in.sharing_scope == SharingScope.PUBLIC
    if becoming_public:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=task_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )

    changes = obj_in.model_dump(exclude_unset=True)
    await self.crud_task.aupdate(amongo, _id=task_id, obj_in=changes)
2094
+
2095
async def aowner_update(
    self,
    *,
    amongo: AsyncIOMotorClient,
    owner: str,
    task_id: ObjectId,
    obj_in: TaskUpdate,
    public_id_prefix: Optional[str] = None,
    public_id_length: Optional[int] = None,
):
    """Apply `obj_in` to task `task_id`, but only if `owner` owns it.

    Ownership is checked through `aowner_get`, with no sharing-scope or
    task-type restriction. A transition to PUBLIC ensures a public id before
    the write. Only the fields set on `obj_in` are persisted.

    Raises:
        keble_exceptions.ServerSideInvalidParams: the task does not exist or
            does not belong to `owner`.
    """
    owned = await self.aowner_get(
        amongo=amongo,
        owner=owner,
        sharing_scope=None,
        task_id=task_id,
        task_type=None,
    )
    if owned is None:
        raise keble_exceptions.ServerSideInvalidParams(
            admin_note={"task_id": task_id, "owner": owner},
            alert_admin=True,
            but_got="task is None",
            expected="task exists",
            invalid_params=f"task_id={task_id}",
        )

    becoming_public = obj_in.sharing_scope == SharingScope.PUBLIC
    if becoming_public:
        await self._ensure_task_public_id(
            amongo=amongo,
            task_id=task_id,
            public_id_prefix=public_id_prefix,
            public_id_length=public_id_length,
        )

    changes = obj_in.model_dump(exclude_unset=True)
    await self.crud_task.aupdate(amongo, _id=task_id, obj_in=changes)
2133
+
2134
+ #
2135
+ #
2136
+ #
2137
+ #
2138
+ # Event Handler
2139
+ #
2140
+ #
2141
+ #
2142
+ #
2143
def _get_task_start_lock_key(self, task: TaskMongoObject) -> str:
    """Build the Redis key that marks `task` as just started."""
    return "just_started:{}".format(task.id)
2146
+
2147
async def _acheck_start_lock(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
) -> bool:
    """Return True when no start lock is held for `task`, i.e. it may start."""
    held_value = await extended_aredis.get(self._get_task_start_lock_key(task))
    return held_value is None
2152
+
2153
async def _alock_start(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
):
    """Set the start lock for `task`; it auto-expires after 60 seconds."""
    lock_key = self._get_task_start_lock_key(task)
    lock_ttl_seconds = 60  # expire after 1 min
    await extended_aredis.set(lock_key, "1", ex=lock_ttl_seconds)
2162
+
2163
async def _aclear_lock(
    self, extended_aredis: ExtendedAsyncRedis, task: TaskMongoObject
):
    """Delete the start lock for `task`; a missing key is not an error here."""
    lock_key = self._get_task_start_lock_key(task)
    await extended_aredis.delete(lock_key)
2168
+
2169
# `before_sleep` logs every failed attempt with its full traceback.
# Without it, `reraise=True` surfaces only the LAST attempt's exception.
# In a live incident the real attempt-1 failure (a crashed spawn tool) was
# silently discarded, and only a misleading downstream error reached the
# task row, which cost the diagnosis.
@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    reraise=True,
    before_sleep=tenacity.before_sleep_log(logger, logging.ERROR, exc_info=True),
)
async def _ahandle_task(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    qdrant_client: Optional[AsyncQdrantClient] = None,
    task: TaskMongoObject,
    task_handler_metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter,
):
    """Run the configured task handler once and finalize from its response.

    Retried up to 3 attempts by the decorator. Each attempt:
    1. bumps the persisted attempt counter, then merges the task's stored
       metadata with `task_handler_metadata`;
    2. hands the handler a `TaskHandlerRequest` (an `AgentDbDeps`). It carries
       the DB clients, the owner, the task language (defaulting to English),
       the shared room `event_emitter`, the task row and the merged metadata;
    3. if the handler returns None, completion was delegated to another
       process. The task stays PROCESSING and nothing is finalized here;
    4. otherwise the response goes to `aon_task_success` or
       `aon_task_failure`. Those own persistence, token accounting and the
       terminal lifecycle event. This replaced the old `aon_task_completed`
       callback inversion-of-control.
    """
    # Every (re)try counts as a fresh attempt on the persisted row.
    attempted_task = await self.crud_task.aincrement_attempts(amongo, _id=task.id)
    metadata = _merge_metadata(attempted_task.metadata, task_handler_metadata)

    # `ExtendedAsyncRedis` subclasses `AsyncRedis`, so the same client fills
    # both the base `aredis` slot and the `extended_aredis` slot.
    response = await self._task_handler(
        TaskHandlerRequest(
            amongo=amongo,
            aredis=extended_aredis,
            extended_aredis=extended_aredis,
            aneo4j=aneo4j,
            aqdrant=qdrant_client,
            owner=attempted_task.owner,
            language=attempted_task.language or Language.ENGLISH,
            event_emitter=event_emitter,
            task=attempted_task,
            metadata=metadata,
        )
    )

    if response is None:
        # Delegated completion: another worker finalizes this task later.
        return

    if not response.success:
        await self.aon_task_failure(
            amongo=amongo,
            extended_aredis=extended_aredis,
            error=response.error
            or f"Task handled response has a status = {response.success}",
            exception_type=response.exception_type or TaskExceptionType.UNKNOWN,
            task=response.task,
            aneo4j=aneo4j,
            metadata=metadata,
            event_emitter=event_emitter,
        )
        return

    await self.aon_task_success(
        amongo=amongo,
        task=response.task,
        consuming_token=response.consuming_token,
        extended_aredis=extended_aredis,
        aneo4j=aneo4j,
        metadata=metadata,
        event_emitter=event_emitter,
    )
2253
+
2254
async def astart(
    self,
    *,
    get_amongo: Optional[Callable[[], AsyncIOMotorClient]] = None,
    get_extended_aredis: Optional[Callable[[], ExtendedAsyncRedis]] = None,
    get_aneo4j: Optional[Callable[[], Optional[Neo4jAsyncDriver]]] = None,
    get_qdrant_client: Optional[Callable[[], Optional[AsyncQdrantClient]]] = None,
    amongo: Optional[AsyncIOMotorClient] = None,
    extended_aredis: Optional[ExtendedAsyncRedis] = None,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    qdrant_client: Optional[AsyncQdrantClient] = None,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    # metadata to pass into handler request
    task_handler_metadata: Optional[TaskMetadata] = None,
    task_lifecycle_event_emitter: AgenticEventEmitter | None = None,
):
    """Start one task and always finalize uncaught handler failures into task state.

    Step by step this entrypoint:
    1. resolves runtime resources. Explicit clients win over getters. Mongo
       and Redis are required; neo4j and qdrant are optional;
    2. checks the single-start Redis lock as a cheap fast path and skips if
       it is held;
    3. refuses to start when the task stage does not allow starting;
    4. acquires the single-start lock atomically (`SET NX EX 60`). If another
       worker won the race after step 2, this call backs off;
    5. moves the task into `PROCESSING`;
    6. runs the configured task handler with retry support;
    7. if the handler still raises, marks the task `FAILURE` unless the
       handler already finalized it;
    8. emits lifecycle events after persisted stage transitions when an
       emitter is provided;
    9. always clears the lock and drains the emitter before exiting once the
       lock is held, even when the processing transition itself fails.

    Returns:
        The reloaded task row. If the lock was already held, returns the
        un-started task as resolved in step 1.

    Raises:
        keble_exceptions.ServerSideMissingParams: neither a Mongo/Redis client
            nor its getter was given.
        TaskFailedToStartException: the task stage does not allow starting.
    """
    if amongo is None:
        if get_amongo is None:
            raise keble_exceptions.ServerSideMissingParams(
                admin_note={"get_amongo": get_amongo},
                alert_admin=True,
                missing_params="get_amongo",
            )
        amongo = get_amongo()
    if extended_aredis is None:
        if get_extended_aredis is None:
            raise keble_exceptions.ServerSideMissingParams(
                admin_note={"get_extended_aredis": get_extended_aredis},
                alert_admin=True,
                missing_params="get_extended_aredis",
            )
        extended_aredis = get_extended_aredis()
    # neo4j + getter can be none
    if aneo4j is None and get_aneo4j is not None:
        aneo4j = get_aneo4j()
    # qdrant + getter can be none
    if qdrant_client is None and get_qdrant_client is not None:
        qdrant_client = get_qdrant_client()

    db_obj = await self._aget_task_for_event(
        amongo=amongo, task_id=task_id, task=task
    )
    # Fast path: skip without writing anything when a lock is visibly held.
    allow_proceed = await self._acheck_start_lock(
        extended_aredis=extended_aredis, task=db_obj
    )
    if not allow_proceed:
        logger.warning(f"[Task] Task {db_obj.id} already started, skip to start.")
        return db_obj

    if not db_obj.stage.allow_to_start:
        logger.warning(
            f"[Task] Can not start task: {db_obj.id}, current stage = {db_obj.stage}"
        )
        raise TaskFailedToStartException(
            f"Can not start task: {db_obj.id}, current stage = {db_obj.stage}"
        )

    # The GET above and a plain SET are two separate round-trips, so two
    # workers could both observe "unlocked" and both start the task. With
    # `nx=True` only one caller creates the key. The loser gets a falsy reply
    # and backs off without touching (or later clearing) the winner's lock.
    lock_acquired = await extended_aredis.set(
        self._get_task_start_lock_key(db_obj),
        "1",
        ex=60,  # expire after 1 min
        nx=True,
    )
    if not lock_acquired:
        logger.warning(f"[Task] Task {db_obj.id} already started, skip to start.")
        return db_obj

    try:
        await self.aon_task_processing(
            amongo=amongo,
            task=db_obj,
            event_emitter=task_lifecycle_event_emitter,
        )
        try:
            await self._ahandle_task(
                amongo=amongo,
                extended_aredis=extended_aredis,
                task=db_obj,
                aneo4j=aneo4j,
                qdrant_client=qdrant_client,
                task_handler_metadata=task_handler_metadata,
                event_emitter=task_lifecycle_event_emitter
                or AgenticEventEmitter(),
            )
        except Exception as exc:
            logger.critical(
                f"[Task] Task failed during execution: {db_obj.id}, with error: {str(exc)}\n{traceback.format_exc()}"
            )
            # Re-read: the handler may already have finalized the row itself.
            latest_task = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
            assert latest_task is not None, (
                f"Task with ID {db_obj.id} not found after handler failure"
            )
            if latest_task.stage not in {TaskStage.SUCCESS, TaskStage.FAILURE}:
                failure_error = str(exc)
                failure_exception_type = TaskExceptionType.UNKNOWN
                if isinstance(exc, TaskException):
                    failure_error = exc.error or failure_error
                    failure_exception_type = exc.exception_type
                await self.aon_task_failure(
                    amongo=amongo,
                    extended_aredis=extended_aredis,
                    aneo4j=aneo4j,
                    task=latest_task,
                    error=failure_error,
                    exception_type=failure_exception_type,
                    metadata=task_handler_metadata,
                    event_emitter=task_lifecycle_event_emitter,
                )
    finally:
        await self._aclear_lock(extended_aredis=extended_aredis, task=db_obj)
        # Backstop drain. Every aon_task_* lifecycle method already self-drains,
        # but join any residual detached emit here. That way the worker envelope
        # never returns (and releases the Celery lease) while spawn-on-terminal
        # or terminal-subtree-settle callbacks are still in flight.
        if task_lifecycle_event_emitter is not None:
            await task_lifecycle_event_emitter.adrain()

    result = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
    assert result is not None, f"Task with ID {db_obj.id} not found after start"
    return result
2380
+
2381
async def aretry_all_undone(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    get_aneo4j: Optional[Callable[[], Optional[Neo4jAsyncDriver]]] = None,
    task_handler_metadata: Optional[TaskMetadata] = None,
    include_tasks_created_within_minutes: int = 30,
    qdrant_client: Optional[AsyncQdrantClient] = None,
    get_qdrant_client: Optional[Callable[[], Optional[AsyncQdrantClient]]] = None,
    task_lifecycle_event_emitter: AgenticEventEmitter | None = None,
):
    """Re-drive recent PENDING/PROCESSING tasks; time out the ones that may not retry.

    Step by step:
    1. load every PENDING or PROCESSING task created within the last
       `include_tasks_created_within_minutes` minutes, newest first;
    2. run each task whose `allow_to_retry` is true through `astart` again.
       It gets the same optional resources a normal start can receive:
       neo4j, qdrant, handler metadata and lifecycle emitter;
    3. mark every other task FAILURE via `aon_task_timeout`.

    If `astart` refuses a task with `TaskFailedToStartException`, the task is
    logged and skipped so the rest of the sweep still runs. Any other
    exception still propagates.

    Args:
        qdrant_client / get_qdrant_client: forwarded to `astart`, so retried
            handlers see the same vector client a first start would.
        task_lifecycle_event_emitter: forwarded to `astart` and
            `aon_task_timeout`. The default None keeps the previous
            no-event behavior.
    """
    # fetch all undone
    after = utc_now() - timedelta(minutes=include_tasks_created_within_minutes)
    db_objs = await self.crud_task.aget_multi(
        amongo,
        query=QueryBase(
            filters={
                "stage": {"$in": [TaskStage.PENDING, TaskStage.PROCESSING]},
                "created": {"$gt": after},
            },
            order_by=[("created", DESCENDING)],  # by created time
        ),
    )
    for db_obj in db_objs:
        if not db_obj.allow_to_retry:
            # set it as failure due to timeout
            await self.aon_task_timeout(
                amongo=amongo,
                task=db_obj,
                event_emitter=task_lifecycle_event_emitter,
            )
            continue
        try:
            # restart it
            await self.astart(
                amongo=amongo,
                extended_aredis=extended_aredis,
                aneo4j=aneo4j,
                get_aneo4j=get_aneo4j,
                qdrant_client=qdrant_client,
                get_qdrant_client=get_qdrant_client,
                task_handler_metadata=task_handler_metadata,
                task=db_obj,
                task_lifecycle_event_emitter=task_lifecycle_event_emitter,
            )
        except TaskFailedToStartException as exc:
            # One unstartable task must not abort the remaining retries/timeouts.
            logger.warning(f"[Task] Skip retrying task {db_obj.id}: {exc}")
2420
+
2421
async def aget_tasks_by_stage_since(
    self,
    *,
    amongo: AsyncIOMotorClient,
    stages: Optional[List[TaskStage]] = None,
    hours: int = 1,
) -> List[TaskMongoObject]:
    """Return tasks created in the last `hours` hours whose stage is in `stages`.

    A falsy `stages` (None or an empty list) falls back to PENDING and
    PROCESSING. Results are ordered newest `created` first.
    """
    created_after = utc_now() - timedelta(hours=hours)
    wanted_stages = stages or [TaskStage.PENDING, TaskStage.PROCESSING]
    recent_in_stage = QueryBase(
        filters={
            "stage": {"$in": wanted_stages},
            "created": {"$gt": created_after},
        },
        order_by=[("created", DESCENDING)],
    )
    return await self.crud_task.aget_multi(amongo, query=recent_in_stage)
2442
+
2443
async def _aget_task_for_event(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: Optional[ObjectId],
    task: Optional[TaskMongoObject],
) -> TaskMongoObject:
    """Resolve the task that a lifecycle/event method operates on.

    `task_id` takes precedence. When it is given, the row is read from Mongo
    via `aget` and `task` is ignored. Otherwise the supplied `task` object is
    returned as-is, without a reload.

    Raises:
        keble_exceptions.ServerSideInvalidParams: `task_id` was given but no
            such task exists, or neither `task_id` nor `task` was given.
    """
    if task_id is not None:
        db_obj = await self.aget(amongo=amongo, task_id=task_id, task_type=None)
        if db_obj is None:
            # Wording matches the other "task not found" errors in this module
            # (it previously said "report", a copy-paste leftover).
            raise keble_exceptions.ServerSideInvalidParams(
                admin_note={"task_id": task_id},
                alert_admin=True,
                but_got="task is None",
                expected="task exists",
                invalid_params=f"task_id={task_id}",
            )
        return db_obj
    if task is not None:
        return task
    raise keble_exceptions.ServerSideInvalidParams(
        admin_note={},
        alert_admin=True,
        but_got="no params",
        expected="at least one of task_id or task",
        invalid_params="task_id or task",
    )
2472
+
2473
async def _aemit_task_lifecycle_event(
    self,
    *,
    task: TaskMongoObject,
    event_emitter: AgenticEventEmitter | None,
) -> None:
    """Publish a TASK_STAGE_CHANGED event for an already-persisted task row.

    This is a no-op when `event_emitter` is None. Otherwise it builds a
    `TaskLifecycleEvent`:
    - the event is scoped to the root task's room (`root_task`, or the task
      itself when it is a root);
    - it is tagged with the changed task's id and `progress_key`;
    - it carries the row as its payload.
    The event is emitted and then the emitter is drained.
    """
    if event_emitter is None:
        return None

    room_id = str(task.root_task or task.id)
    lifecycle_event = TaskLifecycleEvent(
        action_type=TaskEventType.TASK_STAGE_CHANGED.value,
        status=_task_lifecycle_event_status(stage=task.stage),
        payload=TaskLifecycleEventPayload(task=task),
        root_id=room_id,
        object_id=str(task.id),
        correlation_id=task.progress_key,
    )
    await event_emitter.aemit(lifecycle_event)
    # Drain at the lifecycle boundary. The helpers emitter detaches callbacks,
    # but lifecycle callbacks carry DOMAIN side effects that the workflow reads
    # right after (spawn-on-terminal, terminal-subtree settle, chat-memory
    # writes). The task-row mutation is already persisted before this emit, so
    # joining the detached chain here makes each aon_task_* atomic with respect
    # to its own callbacks. Those side effects are guaranteed complete before
    # the lifecycle method returns.
    # Side effect if changed: removing this re-opens the spawn-on-terminal /
    # subtree-settle race in keble.backend (subagent_workflow lifecycle callback).
    await event_emitter.adrain()
2507
+
2508
async def aon_task_success(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    consuming_token: int,  # how many token should be consume (actually)
    metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter | None = None,
):
    """Mark a task SUCCESS, settling its token charge to `consuming_token`.

    The already-charged `consumed_token` is compared with `consuming_token`.
    A surplus is RECOVERed and a shortfall is CONSUMEd through
    `_execute_token_consumption_handler`. `consumed_token` moves to
    `consuming_token` only when that handler reports success; otherwise the
    old value is kept. The row is then written as SUCCESS with `success_ts`
    (epoch seconds), reloaded, announced via a lifecycle event (if an emitter
    is given) and returned.
    """
    db_obj = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    already_consumed = db_obj.consumed_token
    token_delta = consuming_token - already_consumed
    settled_consumed_token = already_consumed

    if token_delta != 0:
        # Positive delta: charge the difference. Negative delta: refund it.
        consumption_type = (
            TokenConsumptionType.CONSUME
            if token_delta > 0
            else TokenConsumptionType.RECOVER
        )
        settled = await self._execute_token_consumption_handler(
            TokenConsumptionPayload(
                consumption_type=consumption_type,
                owner=db_obj.owner,
                task_id=db_obj.id,
                token=abs(token_delta),
                amongo=amongo,
                extended_aredis=extended_aredis,
                aneo4j=aneo4j,
                metadata=metadata,
            )
        )
        if settled:
            settled_consumed_token = consuming_token

    finished_at = utc_now()
    await self.crud_task.aupdate(
        amongo,
        _id=db_obj.id,
        obj_in={
            "stage": TaskStage.SUCCESS,
            "success_ts": int(finished_at.timestamp()),
            "consumed_token": settled_consumed_token,
        },
    )

    refreshed = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
    assert refreshed is not None, f"Task with ID {db_obj.id} not found after update"
    await self._aemit_task_lifecycle_event(task=refreshed, event_emitter=event_emitter)
    return refreshed
2579
+
2580
async def aon_task_processing(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task_id: Optional[ObjectId] = None,
    task: Optional[TaskMongoObject] = None,
    event_emitter: AgenticEventEmitter | None = None,
) -> TaskMongoObject:
    """Move a task to PROCESSING, or leave it alone if its stage forbids starting.

    If the stage allows starting, the row is written as PROCESSING with
    `started_ts` (epoch seconds), reloaded and announced via a lifecycle event
    (if an emitter is given). The reloaded row is returned. Otherwise a
    warning is logged and the task is returned unchanged; no event is emitted.
    """
    db_obj = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    if db_obj.stage.allow_to_start:
        started_at = utc_now()
        await self.crud_task.aupdate(
            amongo,
            _id=db_obj.id,
            obj_in={
                "stage": TaskStage.PROCESSING,
                "started_ts": int(started_at.timestamp()),
            },
        )
        refreshed = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
        assert refreshed is not None, f"Task with ID {db_obj.id} not found after update"
        await self._aemit_task_lifecycle_event(
            task=refreshed, event_emitter=event_emitter
        )
        return refreshed

    logger.warning(
        f"[Task] Can not switch to processing for task: {db_obj.id}, current stage = {db_obj.stage}"
    )
    return db_obj
2614
+
2615
async def aon_task_failure(
    self,
    *,
    amongo: AsyncIOMotorClient,
    extended_aredis: ExtendedAsyncRedis,
    aneo4j: Optional[Neo4jAsyncDriver] = None,
    task: Optional[TaskMongoObject] = None,
    task_id: Optional[ObjectId] = None,
    error: Optional[str] = None,
    exception_type: TaskExceptionType,
    metadata: Optional[TaskMetadata] = None,
    event_emitter: AgenticEventEmitter | None = None,
) -> TaskMongoObject:
    """Mark a task FAILURE and refund any tokens it had been charged.

    The row records `failure_ts` (epoch seconds), `error` and
    `exception_type`. If `consumed_token` is positive, the whole amount is
    RECOVERed through `_execute_token_consumption_handler`. `consumed_token`
    is reset to 0 only when that refund succeeds. The updated row is
    reloaded, announced via a lifecycle event (if an emitter is given) and
    returned.
    """
    db_obj = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )

    changes: dict = {
        "stage": TaskStage.FAILURE,
        "failure_ts": int(utc_now().timestamp()),
        "error": error,
        "exception_type": exception_type,
    }

    charged_tokens = db_obj.consumed_token
    refunded = charged_tokens > 0 and await self._execute_token_consumption_handler(
        TokenConsumptionPayload(
            consumption_type=TokenConsumptionType.RECOVER,
            owner=db_obj.owner,
            token=charged_tokens,
            task_id=db_obj.id,
            amongo=amongo,
            extended_aredis=extended_aredis,
            aneo4j=aneo4j,
            metadata=metadata,
        )
    )
    if refunded:
        changes["consumed_token"] = 0

    await self.crud_task.aupdate(amongo, _id=db_obj.id, obj_in=changes)

    refreshed = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
    assert refreshed is not None, f"Task with ID {db_obj.id} not found after update"
    await self._aemit_task_lifecycle_event(task=refreshed, event_emitter=event_emitter)
    return refreshed
2672
+
2673
async def aon_task_timeout(
    self,
    *,
    amongo: AsyncIOMotorClient,
    task: Optional[TaskMongoObject] = None,
    task_id: Optional[ObjectId] = None,
    event_emitter: AgenticEventEmitter | None = None,
):
    """Mark a task FAILURE with `exception_type=TIMEOUT`.

    The error text records the task's `timeout_mins` and creation time. The
    updated row is reloaded, announced via a lifecycle event (if an emitter
    is given) and returned.

    NOTE(review): unlike `aon_task_failure`, this path does not recover
    `consumed_token`. Confirm whether timed-out tasks should be refunded.
    """
    db_obj = await self._aget_task_for_event(
        amongo=amongo, task=task, task_id=task_id
    )
    timed_out_at = utc_now()
    timeout_update = {
        "stage": TaskStage.FAILURE,
        "failure_ts": int(timed_out_at.timestamp()),
        "error": f"Task timeout after {db_obj.timeout_mins} min. Created at {db_obj.created}",
        "exception_type": TaskExceptionType.TIMEOUT,
    }
    await self.crud_task.aupdate(amongo, _id=db_obj.id, obj_in=timeout_update)

    refreshed = await self.crud_task.afirst_by_id(amongo, _id=db_obj.id)
    assert refreshed is not None, f"Task with ID {db_obj.id} not found after update"
    await self._aemit_task_lifecycle_event(task=refreshed, event_emitter=event_emitter)
    return refreshed