thirdmagic 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- thirdmagic/__init__.py +5 -0
- thirdmagic/chain/__init__.py +4 -0
- thirdmagic/chain/creator.py +34 -0
- thirdmagic/chain/model.py +91 -0
- thirdmagic/clients/__init__.py +3 -0
- thirdmagic/clients/base.py +150 -0
- thirdmagic/clients/lifecycle.py +24 -0
- thirdmagic/consts.py +4 -0
- thirdmagic/container.py +35 -0
- thirdmagic/errors.py +26 -0
- thirdmagic/message.py +26 -0
- thirdmagic/signature/__init__.py +4 -0
- thirdmagic/signature/model.py +191 -0
- thirdmagic/signature/status.py +32 -0
- thirdmagic/swarm/__init__.py +5 -0
- thirdmagic/swarm/consts.py +1 -0
- thirdmagic/swarm/creator.py +53 -0
- thirdmagic/swarm/model.py +222 -0
- thirdmagic/swarm/state.py +13 -0
- thirdmagic/task/__init__.py +22 -0
- thirdmagic/task/creator.py +87 -0
- thirdmagic/task/model.py +83 -0
- thirdmagic/task_def.py +12 -0
- thirdmagic/typing_support.py +15 -0
- thirdmagic/utils.py +57 -0
- thirdmagic-0.0.1.dist-info/METADATA +16 -0
- thirdmagic-0.0.1.dist-info/RECORD +28 -0
- thirdmagic-0.0.1.dist-info/WHEEL +4 -0
thirdmagic/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from thirdmagic.chain.model import ChainTaskSignature
|
|
2
|
+
from thirdmagic.signature.model import TaskInputType
|
|
3
|
+
from thirdmagic.task.creator import resolve_signatures, TaskSignatureConvertible
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
async def chain(
    tasks: list[TaskSignatureConvertible],
    name: str = None,
    error: TaskInputType = None,
    success: TaskInputType = None,
    **kwargs,
) -> ChainTaskSignature:
    """Create and persist a ``ChainTaskSignature`` that runs ``tasks`` in order.

    :param tasks: at least two task signatures (or values convertible to them).
    :param name: optional chain name; the first task's name is used otherwise.
    :param error: optional signature/key stored as the chain's error callback.
    :param success: optional signature/key stored as the chain's success callback.
    :param kwargs: extra kwargs stored on the chain signature.
    :returns: the saved chain signature.
    :raises ValueError: if fewer than two tasks are given.
    """
    if len(tasks) < 2:
        raise ValueError(
            "Chained tasks must contain at least two tasks. "
            "If you want to run a single task, use `create_workflow` instead."
        )
    resolved = await resolve_signatures(tasks)
    head = resolved[0]

    # The chain signature is meant to be deleted only at the end of the chain.
    chain_signature = ChainTaskSignature(
        task_name=f"chain-task:{name or head.task_name}",
        success_callbacks=[success] if success else [],
        error_callbacks=[error] if error else [],
        tasks=resolved,
        kwargs=kwargs,
    )

    # Point every sub-task at its container and save the chain, all inside the
    # first task's pipeline.
    async with head.apipeline():
        for sub_task in resolved:
            sub_task.signature_container_id = chain_signature.key
        await chain_signature.asave()

    return chain_signature
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import cast, Any
|
|
3
|
+
|
|
4
|
+
import rapyer
|
|
5
|
+
from pydantic import field_validator, Field, BaseModel
|
|
6
|
+
from rapyer.fields import RapyerKey
|
|
7
|
+
|
|
8
|
+
from thirdmagic.container import ContainerTaskSignature
|
|
9
|
+
from thirdmagic.errors import MissingSignatureError
|
|
10
|
+
from thirdmagic.signature.status import SignatureStatus
|
|
11
|
+
from thirdmagic.task.model import TaskSignature
|
|
12
|
+
from thirdmagic.utils import HAS_HATCHET
|
|
13
|
+
|
|
14
|
+
if HAS_HATCHET:
|
|
15
|
+
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ChainTaskSignature(ContainerTaskSignature):
    """Container signature whose sub-tasks run one after another.

    Each sub-task's result is passed to the next sub-task; after the last one
    the chain reports completion through ``ClientAdapter.acall_chain_done``.
    """

    # Ordered keys of the chained sub-task signatures.
    tasks: list[RapyerKey] = Field(default_factory=list)

    @field_validator("tasks", mode="before")
    @classmethod
    def validate_tasks(cls, v: list):
        """Normalise each entry (key, bytes or Signature) to a task key."""
        return [cls.validate_task_key(item) for item in v]

    @property
    def task_ids(self):
        return self.tasks

    async def on_sub_task_done(self, sub_task: TaskSignature, results: Any):
        """Advance the chain after ``sub_task`` finished with ``results``.

        For the last sub-task the adapter's ``acall_chain_done`` is invoked;
        otherwise the next sub-task is loaded and called with ``results`` as its
        return field, plus the chain's kwargs.

        :raises ValueError: if ``sub_task`` is not part of this chain.
        """
        sub_task_idx = self.tasks.index(sub_task.key)
        if sub_task_idx == len(self.tasks) - 1:
            await self.ClientAdapter.acall_chain_done(results, self)
        else:
            next_task_key = self.tasks[sub_task_idx + 1]
            next_task = cast(TaskSignature, await rapyer.aget(next_task_key))
            await next_task.acall(results, set_return_field=True, **self.kwargs)

    async def on_sub_task_error(
        self, sub_task: TaskSignature, error: BaseException, original_msg: dict
    ):
        """Report ``sub_task``'s failure to the adapter as a chain error."""
        await self.ClientAdapter.acall_chain_error(original_msg, error, self, sub_task)

    async def sub_tasks(self) -> list[TaskSignature]:
        """Load the chain's sub-task signatures, skipping missing ones."""
        if not self.tasks:
            # Nothing to load; do not call rapyer.afind without any keys.
            return []
        sub_tasks = await rapyer.afind(*self.tasks, skip_missing=True)
        return cast(list[TaskSignature], sub_tasks)

    async def acall(self, msg: Any, set_return_field: bool = True, **kwargs):
        """Start the chain by calling its first sub-task with ``msg``.

        The chain's stored kwargs are merged with ``kwargs`` (``kwargs`` wins).

        :raises MissingSignatureError: if the chain has no sub-tasks or its
            first sub-task cannot be found.
        """
        if not self.tasks:
            # Previously this surfaced as a bare IndexError on self.tasks[0].
            raise MissingSignatureError(f"Chain {self.key} has no tasks")
        first_task = await rapyer.afind_one(self.tasks[0])
        if first_task is None:
            raise MissingSignatureError(f"First task from chain {self.key} not found")

        full_kwargs = self.kwargs | kwargs
        return await first_task.acall(msg, set_return_field, **full_kwargs)

    if HAS_HATCHET:

        async def aio_run_no_wait(
            self, msg: BaseModel, options: TriggerWorkflowOptions = None
        ):
            """Start the chain with ``msg`` without setting the return field.

            ``options`` travels through ``acall``'s kwargs to the first sub-task.
            """
            return await self.acall(msg, options=options, set_return_field=False)

    async def change_status(self, status: SignatureStatus):
        """Set ``status`` on the chain and every sub-task; errors are ignored."""
        pause_chain_tasks = [
            TaskSignature.safe_change_status(task, status) for task in self.tasks
        ]
        pause_chain = super().change_status(status)
        await asyncio.gather(pause_chain, *pause_chain_tasks, return_exceptions=True)

    async def suspend(self):
        """Suspend every sub-task (best effort), then mark the chain SUSPENDED."""
        await asyncio.gather(
            *[TaskSignature.suspend_from_key(task_id) for task_id in self.tasks],
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.SUSPENDED)

    async def interrupt(self):
        """Interrupt every sub-task (best effort), then mark the chain INTERRUPTED."""
        await asyncio.gather(
            *[TaskSignature.interrupt_from_key(task_id) for task_id in self.tasks],
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.INTERRUPTED)

    async def resume(self):
        """Resume every sub-task (best effort), then restore the chain's last status."""
        await asyncio.gather(
            *[TaskSignature.resume_from_key(task_key) for task_key in self.tasks],
            return_exceptions=True,
        )
        await super().change_status(self.task_status.last_status)
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import asyncio
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from thirdmagic.clients.lifecycle import BaseLifecycle
|
|
9
|
+
from thirdmagic.task_def import MageflowTaskDefinition
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from thirdmagic.task import TaskSignature
|
|
13
|
+
from thirdmagic.signature import Signature
|
|
14
|
+
from thirdmagic.chain.model import ChainTaskSignature
|
|
15
|
+
from thirdmagic.swarm.model import SwarmTaskSignature
|
|
16
|
+
from thirdmagic.utils import HatchetTaskType
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseClientAdapter(ABC):
    """Abstract bridge between signatures and a concrete task-execution client.

    Signatures call into the adapter for everything engine-specific: reporting
    chain and swarm progress, dispatching signatures, retry decisions, task
    naming and building lifecycle objects.
    """

    @abc.abstractmethod
    def extract_validator(self, client_task) -> type[BaseModel]:
        """Return the pydantic model used to validate ``client_task`` input."""

    @abc.abstractmethod
    def extract_retries(self, client_task) -> int:
        """Return the retry count configured for ``client_task``."""

    @abc.abstractmethod
    async def acall_chain_done(self, results: Any, chain: "ChainTaskSignature"):
        """Handle completion of ``chain``; ``results`` come from its last sub-task."""

    @abc.abstractmethod
    async def acall_chain_error(
        self,
        original_msg: dict,
        error: BaseException,
        chain: "ChainTaskSignature",
        failed_task: "Signature",
    ):
        """Handle ``failed_task`` raising ``error`` inside ``chain``."""

    @abc.abstractmethod
    async def afill_swarm(
        self, swarm: "SwarmTaskSignature", max_tasks: int = None, **kwargs
    ):
        """Presumably launch pending tasks of ``swarm``; ``max_tasks`` semantics are implementation-defined."""

    @abc.abstractmethod
    async def acall_swarm_item_error(
        self, error: BaseException, swarm: "SwarmTaskSignature", swarm_item: "Signature"
    ):
        """Handle ``swarm_item`` of ``swarm`` failing with ``error``."""

    @abc.abstractmethod
    async def acall_swarm_item_done(
        self, results: Any, swarm: "SwarmTaskSignature", swarm_item: "Signature"
    ):
        """Handle ``swarm_item`` of ``swarm`` finishing with ``results``."""

    @abc.abstractmethod
    async def acall_signature(
        self, signature: "TaskSignature", msg: Any, set_return_field: bool, **kwargs
    ):
        """Dispatch a single task ``signature`` with ``msg``."""

    async def acall_signatures(
        self,
        signatures: list["Signature"],
        msg: Any,
        set_return_field: bool,
        **kwargs,
    ):
        """Call ``acall`` on every signature concurrently.

        Results are returned in input order; the first exception propagates
        because ``gather`` is used without ``return_exceptions``.
        """
        pending_calls = (
            sig.acall(msg, set_return_field, **kwargs) for sig in signatures
        )
        return await asyncio.gather(*pending_calls)

    @abc.abstractmethod
    def should_task_retry(
        self, task_definition: MageflowTaskDefinition, attempt_num: int, e: BaseException
    ) -> bool:
        """Decide whether a task should be retried after raising ``e``."""

    @abc.abstractmethod
    def task_name(self, task: "HatchetTaskType") -> str:
        """Return the name of the client task ``task``."""

    @abc.abstractmethod
    async def create_lifecycle(self, *args) -> BaseLifecycle:
        """Build a lifecycle object; arguments are implementation-defined."""

    @abc.abstractmethod
    async def lifecycle_from_signature(self, *args) -> BaseLifecycle:
        """Build a lifecycle object from a signature; arguments are implementation-defined."""
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class DefaultClientAdapter(BaseClientAdapter):
    """Placeholder adapter used until a real client adapter is configured.

    Every operation raises ``NotImplementedError`` except
    ``lifecycle_from_signature``, which returns ``None``.
    """

    async def acall_chain_done(self, results: Any, chain: "ChainTaskSignature"):
        raise NotImplementedError("Set a client before we start")

    async def acall_chain_error(
        self,
        original_msg: dict,
        error: BaseException,
        chain: "ChainTaskSignature",
        failed_task: "Signature",
    ):
        raise NotImplementedError("Set a client before we start")

    async def afill_swarm(
        self, swarm: "SwarmTaskSignature", max_tasks: int = None, **kwargs
    ):
        raise NotImplementedError("Set a client before we start")

    async def acall_swarm_item_error(
        self, error: BaseException, swarm: "SwarmTaskSignature", swarm_item: "Signature"
    ):
        raise NotImplementedError("Set a client before we start")

    async def acall_swarm_item_done(
        self, results: Any, swarm: "SwarmTaskSignature", swarm_item: "Signature"
    ):
        raise NotImplementedError("Set a client before we start")

    def extract_validator(self, client_task) -> type[BaseModel]:
        raise NotImplementedError("Set a client before we start")

    def extract_retries(self, client_task) -> int:
        raise NotImplementedError("Set a client before we start")

    # Same parameters as BaseClientAdapter.acall_signature (``msg`` before
    # ``set_return_field``), so positional calls reach the NotImplementedError
    # instead of failing with a TypeError about the argument count.
    async def acall_signature(
        self, signature: "TaskSignature", msg: Any, set_return_field: bool, **kwargs
    ):
        raise NotImplementedError("Set a client before we start")

    def should_task_retry(
        self, task_definition: MageflowTaskDefinition, attempt_num: int, e: BaseException
    ) -> bool:
        raise NotImplementedError("Set a client before we start")

    def task_name(self, task: "HatchetTaskType") -> str:
        raise NotImplementedError("Set a client before we start")

    # Declared ``async`` in BaseClientAdapter, so it is a coroutine here too;
    # awaiting it raises the NotImplementedError.
    async def create_lifecycle(self, *args) -> BaseLifecycle:
        raise NotImplementedError("Set a client before we start")

    async def lifecycle_from_signature(self, *args) -> BaseLifecycle:
        # NOTE(review): unlike every other method this returns None instead of
        # raising NotImplementedError; kept as-is in case callers rely on it —
        # confirm whether it should raise as well.
        return None
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseLifecycle(ABC):
    """Abstract hooks around a single task execution.

    Client adapters produce concrete instances
    (``create_lifecycle`` / ``lifecycle_from_signature``).
    """

    @abc.abstractmethod
    async def start_task(self):
        """Hook for the start of a task run."""

    @abc.abstractmethod
    async def task_success(self, result: Any):
        """Hook for a successful run that produced ``result``."""

    @abc.abstractmethod
    async def task_failed(self, message: dict, error: BaseException):
        """Hook for a run on ``message`` that raised ``error``."""

    @abc.abstractmethod
    async def should_run_task(self, message: dict) -> bool:
        """Return whether the task should run for ``message``."""

    def is_vanilla_run(self):
        """Return ``False``; subclasses may override."""
        return False
|
thirdmagic/consts.py
ADDED
thirdmagic/container.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import asyncio
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from thirdmagic.signature import Signature
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContainerTaskSignature(Signature, ABC):
    """Signature that owns other signatures (for example a chain or a swarm).

    Subclasses expose their children's keys (``task_ids``), load them
    (``sub_tasks``) and react when a child finishes or fails. Because
    ``Signature`` removal calls ``remove_references``, removing a container
    also removes its children.
    """

    @property
    @abc.abstractmethod
    def task_ids(self):
        """Keys of the contained signatures."""

    @abc.abstractmethod
    async def sub_tasks(self) -> list[Signature]:
        """Load and return the contained signatures."""

    async def remove_references(self):
        """Remove every contained signature; individual failures are ignored."""
        children = await self.sub_tasks()
        removals = [child.remove() for child in children]
        await asyncio.gather(*removals, return_exceptions=True)

    @abc.abstractmethod
    async def on_sub_task_error(
        self, sub_task: Signature, error: BaseException, original_msg: dict
    ):
        """React to ``sub_task`` failing with ``error`` on ``original_msg``."""

    @abc.abstractmethod
    async def on_sub_task_done(self, sub_task: Signature, results: Any):
        """React to ``sub_task`` completing with ``results``."""
|
thirdmagic/errors.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
class MageflowError(Exception):
    """Root of this package's exception hierarchy."""


class MissingSignatureError(MageflowError):
    """A referenced signature could not be found."""


class MissingSwarmItemError(MissingSignatureError):
    """Presumably raised when a swarm item's signature is missing (usage not visible here)."""


class SwarmError(MageflowError):
    """Base class for swarm-related errors."""


class TooManyTasksError(SwarmError, RuntimeError):
    """Adding tasks would exceed the swarm's ``max_task_allowed`` limit."""


class SwarmIsCanceledError(SwarmError, RuntimeError):
    """Tasks cannot be added because the swarm is canceled."""


class UnrecognizedTaskError(MageflowError):
    """A task could not be recognised (usage not visible here)."""
|
thirdmagic/message.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import Generic, TypeVar, Any, Annotated, TYPE_CHECKING, TypeAlias
|
|
3
|
+
|
|
4
|
+
DEFAULT_RESULT_NAME = "mageflow_results"
|
|
5
|
+
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclasses.dataclass(frozen=True)
|
|
9
|
+
class ReturnValueAnnotation:
|
|
10
|
+
pass
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _ReturnValue(Generic[T]):
|
|
14
|
+
def __new__(cls, typ: Any = None):
|
|
15
|
+
if typ is None:
|
|
16
|
+
return ReturnValueAnnotation()
|
|
17
|
+
return Annotated[typ, ReturnValueAnnotation()]
|
|
18
|
+
|
|
19
|
+
def __class_getitem__(cls, item):
|
|
20
|
+
return Annotated[item, ReturnValueAnnotation()]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
ReturnValue = _ReturnValue
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
ReturnValue: TypeAlias = Annotated[T, ReturnValueAnnotation()]
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import asyncio
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Optional, Self, Any, TypeAlias, ClassVar, cast
|
|
6
|
+
|
|
7
|
+
import rapyer
|
|
8
|
+
from pydantic import BaseModel, field_validator, Field
|
|
9
|
+
from rapyer import AtomicRedisModel
|
|
10
|
+
from rapyer.config import RedisConfig
|
|
11
|
+
from rapyer.errors.base import KeyNotFound
|
|
12
|
+
from rapyer.fields import RapyerKey
|
|
13
|
+
from rapyer.types import RedisDict, RedisList, RedisDatetime
|
|
14
|
+
|
|
15
|
+
from thirdmagic.clients import BaseClientAdapter, DefaultClientAdapter
|
|
16
|
+
from thirdmagic.consts import REMOVED_TASK_TTL
|
|
17
|
+
from thirdmagic.signature.status import TaskStatus, PauseActionTypes, SignatureStatus
|
|
18
|
+
from thirdmagic.utils import HAS_HATCHET
|
|
19
|
+
|
|
20
|
+
if HAS_HATCHET:
|
|
21
|
+
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Signature(AtomicRedisModel, ABC):
    """Base class for Redis-persisted task signatures.

    A signature records:
    - which task to run (``task_name``) and the kwargs to run it with;
    - keys of the signatures to trigger on success or on error;
    - its status;
    - optionally, the key of the container signature (e.g. chain or swarm)
      that owns it.

    Engine-specific operations go through ``ClientAdapter``.
    """

    task_name: str
    kwargs: RedisDict[Any] = Field(default_factory=dict)
    creation_time: RedisDatetime = Field(default_factory=datetime.now)
    # Keys of the signatures triggered by activate_success / activate_error.
    success_callbacks: RedisList[RapyerKey] = Field(default_factory=list)
    error_callbacks: RedisList[RapyerKey] = Field(default_factory=list)
    task_status: TaskStatus = Field(default_factory=TaskStatus)
    # Key of the owning container signature, if any.
    signature_container_id: Optional[RapyerKey] = None

    # ttl is 24 * 60 * 60 (presumably seconds, i.e. one day).
    Meta: ClassVar[RedisConfig] = RedisConfig(ttl=24 * 60 * 60, refresh_ttl=False)
    # Placeholder whose engine operations raise NotImplementedError until a
    # real adapter is assigned.
    ClientAdapter: ClassVar[BaseClientAdapter] = DefaultClientAdapter()

    @field_validator("success_callbacks", "error_callbacks", mode="before")
    @classmethod
    def validate_tasks_id(cls, v: list) -> list[str]:
        """Normalise callback entries (keys, bytes or signatures) to keys."""
        return [cls.validate_task_key(item) for item in v]

    @classmethod
    def validate_task_key(cls, v) -> str:
        """Reduce ``v`` to a signature key.

        Handling by input type:
        - ``bytes`` is decoded to a ``RapyerKey``;
        - ``str`` is returned unchanged;
        - a ``Signature`` yields its ``key``.

        :raises ValueError: for any other type.
        """
        if isinstance(v, bytes):
            return RapyerKey(v.decode())
        if isinstance(v, str):
            return v
        elif isinstance(v, Signature):
            return v.key
        else:
            raise ValueError(f"Expected task ID or Signature, got {type(v).__name__}")

    @abc.abstractmethod
    async def acall(self, msg: Any, set_return_field: bool = True, **kwargs):
        """Dispatch this signature with ``msg`` as input."""
        pass

    if HAS_HATCHET:

        @abc.abstractmethod
        async def aio_run_no_wait(
            self, msg: BaseModel, options: TriggerWorkflowOptions = None
        ):
            """Trigger this signature with ``msg`` (Hatchet-style entry point)."""
            pass

    async def on_pause_signature(self, msg: dict):
        """Merge ``msg`` into the stored kwargs."""
        await self.kwargs.aupdate(**msg)

    async def on_cancel_signature(self, msg: dict):
        """Remove this signature; ``msg`` is not used."""
        await self.remove()

    async def activate_success(self, msg):
        """Call every success callback with ``msg`` (``set_return_field=True``)."""
        # Skip the lookup entirely when there are no callbacks rather than
        # calling rapyer.afind without any keys.
        found = (
            await rapyer.afind(*self.success_callbacks) if self.success_callbacks else []
        )
        success_signatures = cast(list[Signature], found)
        return await self.ClientAdapter.acall_signatures(success_signatures, msg, True)

    async def activate_error(self, msg):
        """Call every error callback with ``msg`` (``set_return_field=False``)."""
        found = (
            await rapyer.afind(*self.error_callbacks) if self.error_callbacks else []
        )
        error_signatures = cast(list[Signature], found)
        return await self.ClientAdapter.acall_signatures(error_signatures, msg, False)

    async def remove_task(self):
        """'Remove' this record by shortening its TTL to REMOVED_TASK_TTL."""
        await self.aset_ttl(REMOVED_TASK_TTL)

    async def remove_branches(self, success: bool = True, errors: bool = True):
        """Remove the success and/or error callback signatures."""
        keys_to_remove = []
        if errors:
            keys_to_remove.extend(self.error_callbacks)
        if success:
            keys_to_remove.extend(self.success_callbacks)
        if not keys_to_remove:
            # Nothing to remove; do not call rapyer.afind without any keys.
            return

        signatures = cast(list[Signature], await rapyer.afind(*keys_to_remove))
        await asyncio.gather(*[signature.remove() for signature in signatures])

    async def remove_references(self):
        """Hook for removing owned signatures; containers override this."""
        pass

    async def remove(self, with_error: bool = True, with_success: bool = True):
        """Remove this signature, its callbacks (as selected) and its references."""
        return await self._remove(with_error, with_success)

    async def _remove(self, with_error: bool = True, with_success: bool = True):
        await self.remove_branches(with_success, with_error)
        await self.remove_references()
        await self.remove_task()

    @classmethod
    async def remove_from_key(cls, task_key: RapyerKey):
        """Lock the signature stored at ``task_key`` and remove it."""
        async with rapyer.alock_from_key(task_key) as task:
            task = cast(Signature, task)
            return await task.remove()

    async def should_run(self):
        return self.task_status.should_run()

    async def change_status(self, status: SignatureStatus):
        """Set ``status``, remembering the current one as ``last_status``."""
        await self.task_status.aupdate(
            last_status=self.task_status.status, status=status
        )

    # When pausing task from outside the task
    @classmethod
    async def safe_change_status(cls, task_id: RapyerKey, status: SignatureStatus):
        """Best-effort status change for the signature stored at ``task_id``.

        Any exception (missing key, lock failure, ...) is swallowed and
        ``False`` is returned; otherwise the result of ``change_status`` is
        returned.
        """
        try:
            async with rapyer.alock_from_key(task_id) as task:
                task = cast(Signature, task)
                return await task.change_status(status)
        except Exception:
            return False

    @classmethod
    async def resume_from_key(cls, task_key: RapyerKey):
        """Lock the signature stored at ``task_key`` and resume it."""
        async with rapyer.alock_from_key(task_key) as task:
            task = cast(Signature, task)
            await task.resume()

    @abc.abstractmethod
    async def resume(self):
        pass

    @classmethod
    async def suspend_from_key(cls, task_key: RapyerKey):
        """Lock the signature stored at ``task_key`` and suspend it."""
        async with rapyer.alock_from_key(task_key) as task:
            task = cast(Signature, task)
            await task.suspend()

    async def done(self):
        """Mark this signature DONE, remembering the previous status."""
        await self.task_status.aupdate(
            last_status=self.task_status.status, status=SignatureStatus.DONE
        )

    async def failed(self):
        """Mark this signature FAILED, remembering the previous status."""
        await self.task_status.aupdate(
            last_status=self.task_status.status, status=SignatureStatus.FAILED
        )

    async def suspend(self):
        """
        Task suspension will try and stop the task before it starts
        """
        await self.change_status(SignatureStatus.SUSPENDED)

    @classmethod
    async def interrupt_from_key(cls, task_key: RapyerKey):
        """Lock the signature stored at ``task_key`` and interrupt it."""
        async with rapyer.alock_from_key(task_key) as task:
            task = cast(Signature, task)
            # Must be awaited while the lock is held: returning the bare
            # coroutine meant the interrupt never actually ran.
            return await task.interrupt()

    async def interrupt(self):
        """
        Task interrupt will try to aggressively take hold of the async loop and stop the task
        """
        # TODO - not implemented yet - implement (currently same as suspend)
        await self.suspend()

    @classmethod
    async def pause_from_key(
        cls,
        task_key: RapyerKey,
        pause_type: PauseActionTypes = PauseActionTypes.SUSPEND,
    ):
        """Lock the signature stored at ``task_key`` and pause it."""
        async with rapyer.alock_from_key(task_key) as task:
            task = cast(Signature, task)
            await task.pause_task(pause_type)

    async def pause_task(self, pause_type: PauseActionTypes = PauseActionTypes.SUSPEND):
        """Suspend or interrupt according to ``pause_type``.

        :raises NotImplementedError: for any other pause type.
        """
        if pause_type == PauseActionTypes.SUSPEND:
            return await self.suspend()
        elif pause_type == PauseActionTypes.INTERRUPT:
            return await self.interrupt()
        raise NotImplementedError(f"Pause type {pause_type} not supported")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Anything that identifies a signature: its Redis key or the Signature itself
# (Signature.validate_task_key reduces either form to a key).
TaskInputType: TypeAlias = RapyerKey | Signature
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
|
|
4
|
+
from rapyer import AtomicRedisModel
|
|
5
|
+
from rapyer.config import RedisConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SignatureStatus(str, Enum):
    """States a signature can be in; members compare equal to their string values."""

    # TaskStatus.should_run accepts only these two.
    PENDING = "pending"
    ACTIVE = "active"
    # Set by Signature.failed() / Signature.done().
    FAILED = "failed"
    DONE = "done"
    # Set by suspend / interrupt.
    SUSPENDED = "suspended"
    INTERRUPTED = "interrupted"
    # Checked by TaskStatus.is_canceled.
    CANCELED = "canceled"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PauseActionTypes(str, Enum):
    """How a pause request is carried out (dispatched by Signature.pause_task)."""

    SUSPEND = "soft"  # handled by Signature.suspend()
    INTERRUPT = "hard"  # handled by Signature.interrupt()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TaskStatus(AtomicRedisModel):
    """Current status of a signature plus the status it held before that."""

    status: SignatureStatus = SignatureStatus.PENDING
    # Previous status; written by Signature.change_status and read when resuming.
    last_status: SignatureStatus = SignatureStatus.PENDING
    Meta: ClassVar[RedisConfig] = RedisConfig(ttl=24 * 60 * 60, refresh_ttl=False)

    def is_canceled(self):
        """Return True when the status is CANCELED."""
        return self.status == SignatureStatus.CANCELED

    def should_run(self):
        """Return True while the status is PENDING or ACTIVE."""
        return self.status in (SignatureStatus.PENDING, SignatureStatus.ACTIVE)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# kwargs key under which SwarmTaskSignature.acall stores the message passed to the swarm.
SWARM_MESSAGE_PARAM_NAME = "__mageflow_swarm_message__"
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from typing import overload, Any
|
|
3
|
+
|
|
4
|
+
import rapyer
|
|
5
|
+
|
|
6
|
+
from thirdmagic.task.creator import TaskSignatureOptions, TaskSignatureConvertible
|
|
7
|
+
from thirdmagic.swarm.model import SwarmConfig, SwarmTaskSignature
|
|
8
|
+
from thirdmagic.swarm.state import PublishState
|
|
9
|
+
from thirdmagic.typing_support import Unpack
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SignatureOptions(TaskSignatureOptions):
    """Keyword options accepted by ``swarm()``.

    ``swarm()`` passes options that name a ``SwarmTaskSignature`` field to the
    model constructor and merges all remaining options into the swarm's kwargs.
    """

    is_swarm_closed: bool
    config: SwarmConfig
    # NOTE(review): ``task_kwargs`` is not among the SwarmTaskSignature fields
    # visible here, so swarm() would store it inside the swarm's kwargs —
    # confirm that is intended.
    task_kwargs: dict
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@overload
async def swarm(
    tasks: list[TaskSignatureConvertible],
    task_name: str = None,
    **options: Unpack[TaskSignatureOptions],
) -> SwarmTaskSignature: ...
@overload
async def swarm(
    tasks: list[TaskSignatureConvertible], task_name: str = None, **options: Any
) -> SwarmTaskSignature: ...


async def swarm(
    tasks: list[TaskSignatureConvertible] = None,
    task_name: str = None,
    **options: Unpack[SignatureOptions],
) -> SwarmTaskSignature:
    """Create, persist and populate a new ``SwarmTaskSignature``.

    :param tasks: initial tasks to add; may be omitted or empty.
    :param task_name: swarm name; ``swarm-task-<uuid4>`` is used when not given.
    :param options: keys naming a ``SwarmTaskSignature`` field (such as
        ``config`` or ``is_swarm_closed``) configure the swarm model. The
        ``kwargs`` option, merged with every remaining option, becomes the
        swarm's kwargs.
    :returns: the saved swarm signature.
    :raises TooManyTasksError, SwarmIsCanceledError: propagated from
        ``add_tasks``. The swarm and its publish state are already saved at
        that point.
    """
    initial_tasks = tasks or []
    swarm_name = task_name or f"swarm-task-{uuid.uuid4()}"
    publish_state = PublishState()

    # An explicit "kwargs" option is kept apart and later merged with leftovers;
    # options naming model fields are handed to the SwarmTaskSignature itself.
    explicit_kwargs = options.pop("kwargs", {})
    model_options = {}
    for field_name in SwarmTaskSignature.model_fields:
        if field_name in options:
            model_options[field_name] = options.pop(field_name)

    swarm_signature = SwarmTaskSignature(
        **model_options,
        task_name=swarm_name,
        publishing_state_id=publish_state.key,
        kwargs=explicit_kwargs | options,
    )
    await rapyer.ainsert(publish_state, swarm_signature)
    await swarm_signature.add_tasks(initial_tasks)
    return swarm_signature
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Self, Any, Optional, cast
|
|
3
|
+
|
|
4
|
+
import rapyer
|
|
5
|
+
from pydantic import Field, field_validator, BaseModel
|
|
6
|
+
from rapyer import AtomicRedisModel
|
|
7
|
+
from rapyer.fields import RapyerKey
|
|
8
|
+
from rapyer.types import RedisList, RedisInt
|
|
9
|
+
|
|
10
|
+
from thirdmagic.consts import REMOVED_TASK_TTL
|
|
11
|
+
from thirdmagic.container import ContainerTaskSignature
|
|
12
|
+
from thirdmagic.errors import TooManyTasksError, SwarmIsCanceledError
|
|
13
|
+
from thirdmagic.signature import Signature
|
|
14
|
+
from thirdmagic.signature.status import SignatureStatus
|
|
15
|
+
from thirdmagic.swarm.consts import SWARM_MESSAGE_PARAM_NAME
|
|
16
|
+
from thirdmagic.swarm.state import PublishState
|
|
17
|
+
from thirdmagic.task.creator import TaskSignatureConvertible, resolve_signatures
|
|
18
|
+
from thirdmagic.task.model import TaskSignature
|
|
19
|
+
from thirdmagic.utils import HAS_HATCHET
|
|
20
|
+
|
|
21
|
+
if HAS_HATCHET:
|
|
22
|
+
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
|
|
23
|
+
from hatchet_sdk.runnables.types import EmptyModel
|
|
24
|
+
from hatchet_sdk.runnables.workflow import TaskRunRef
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SwarmConfig(AtomicRedisModel):
    """Per-swarm limits and message-delivery options."""

    # NOTE(review): not read in this file; presumably caps concurrently running
    # swarm items — confirm against the client adapter.
    max_concurrency: int = 30
    # Optional positive thresholds (validated gt=0).
    stop_after_n_failures: Optional[int] = Field(default=None, gt=0)
    # None means unlimited (see can_add_n_tasks).
    max_task_allowed: Optional[int] = Field(default=None, gt=0)
    # Written by SwarmTaskSignature.acall from its ``set_return_field`` argument.
    send_swarm_message_to_return_field: bool = False

    def can_add_task(self, swarm: "SwarmTaskSignature") -> bool:
        """Return True if one more task fits under ``max_task_allowed``."""
        return self.can_add_n_tasks(swarm, 1)

    def can_add_n_tasks(self, swarm: "SwarmTaskSignature", n: int) -> bool:
        """Return True if ``n`` more tasks fit under ``max_task_allowed``.

        Every request is allowed when no limit is configured.
        """
        limit = self.max_task_allowed
        return limit is None or len(swarm.tasks) + n <= limit
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SwarmTaskSignature(ContainerTaskSignature):
|
|
43
|
+
# TODO - TASKS list should be set once we enable this in rapyer
|
|
44
|
+
tasks: RedisList[RapyerKey] = Field(default_factory=list)
|
|
45
|
+
tasks_left_to_run: RedisList[RapyerKey] = Field(default_factory=list)
|
|
46
|
+
finished_tasks: RedisList[RapyerKey] = Field(default_factory=list)
|
|
47
|
+
failed_tasks: RedisList[RapyerKey] = Field(default_factory=list)
|
|
48
|
+
tasks_results: RedisList[Any] = Field(default_factory=list)
|
|
49
|
+
# This flag is raised when no more tasks can be added to the swarm
|
|
50
|
+
is_swarm_closed: bool = False
|
|
51
|
+
# How many tasks can be added to the swarm at a time
|
|
52
|
+
current_running_tasks: RedisInt = 0
|
|
53
|
+
publishing_state_id: str
|
|
54
|
+
config: SwarmConfig = Field(default_factory=SwarmConfig)
|
|
55
|
+
|
|
56
|
+
@field_validator(
    "tasks", "tasks_left_to_run", "finished_tasks", "failed_tasks", mode="before"
)
@classmethod
def validate_tasks(cls, v):
    """Normalise each entry (key, bytes or Signature) of the task lists to a key."""
    return list(map(cls.validate_task_key, v))
|
|
62
|
+
|
|
63
|
+
@property
def task_ids(self):
    """Keys of the swarm's tasks (the ``tasks`` list)."""
    return self.tasks
|
|
66
|
+
|
|
67
|
+
async def sub_tasks(self) -> list[TaskSignature]:
    """Load the swarm's task signatures from Redis.

    Missing signatures (e.g. records whose TTL expired) are skipped, as in
    ChainTaskSignature.sub_tasks, so that ``remove_references`` can still
    clean up the remaining ones. An empty swarm returns ``[]`` without
    querying Redis.
    """
    if not self.tasks:
        # Do not call rapyer.afind without any keys.
        return []
    tasks = await rapyer.afind(*self.tasks, skip_missing=True)
    return cast(list[TaskSignature], tasks)
|
|
70
|
+
|
|
71
|
+
async def on_sub_task_done(self, sub_task: TaskSignature, results: Any):
    """Report a finished swarm item and its ``results`` to the client adapter."""
    await self.ClientAdapter.acall_swarm_item_done(results, self, sub_task)
|
|
73
|
+
|
|
74
|
+
async def on_sub_task_error(
    self, sub_task: TaskSignature, error: BaseException, original_msg: dict
):
    """Report a failed swarm item and its ``error`` to the client adapter.

    ``original_msg`` is accepted to match ContainerTaskSignature's interface
    (which types it as ``dict``) but is not forwarded.
    """
    await self.ClientAdapter.acall_swarm_item_error(error, self, sub_task)
|
|
78
|
+
|
|
79
|
+
async def acall(self, msg: Any, set_return_field: bool = True, **kwargs):
    """Start the swarm with ``msg``.

    Steps:
    1. Store ``msg`` in the swarm's kwargs under SWARM_MESSAGE_PARAM_NAME.
    2. Record ``set_return_field`` in
       ``config.send_swarm_message_to_return_field``.
    3. Pass any extra ``kwargs`` to ``ClientAdapter.afill_swarm``.
    """
    # Update the kwargs shared by every swarm item, and record whether the
    # message should go into the return-value field or into the message itself.
    async with self.apipeline():
        self.kwargs.update(**{SWARM_MESSAGE_PARAM_NAME: msg})
        self.config.send_swarm_message_to_return_field = set_return_field
    return await self.ClientAdapter.afill_swarm(self, **kwargs)
|
|
85
|
+
|
|
86
|
+
if HAS_HATCHET:

    async def aio_run_no_wait(
        self, msg: BaseModel, options: "TriggerWorkflowOptions" = None
    ):
        """Start the swarm with ``msg`` via ``acall``.

        Behaviour:
        - ``msg`` is dumped in JSON mode with ``exclude_unset=True``.
        - ``set_return_field`` is False.
        - ``options`` is forwarded to ``afill_swarm``.
        """
        return await self.acall(
            msg.model_dump(mode="json", exclude_unset=True),
            set_return_field=False,
            options=options,
        )

    async def aio_run_in_swarm(
        self,
        task: TaskSignatureConvertible,
        msg: BaseModel,
        options: TriggerWorkflowOptions = None,
        close_on_max_task: bool = True,
    ) -> Optional["TaskRunRef"]:
        """Add ``task`` to the swarm, merge ``msg`` into its kwargs, then fill.

        ``afill_swarm`` is called with ``max_tasks=1`` (presumably launching at
        most one item).

        :raises TooManyTasksError, SwarmIsCanceledError: from ``add_task``.
        """
        sub_task = await self.add_task(task, close_on_max_task)
        await sub_task.kwargs.aupdate(**msg.model_dump(mode="json"))
        return await self.ClientAdapter.afill_swarm(
            self, max_tasks=1, options=options
        )
|
|
109
|
+
|
|
110
|
+
async def change_status(self, status: SignatureStatus):
    """Set *status* on every sub-task and on the swarm itself, concurrently.

    Sub-tasks are updated through ``TaskSignature.safe_change_status``. The
    swarm is updated through the base-class ``change_status``. Because
    ``return_exceptions=True`` is passed to ``asyncio.gather``, exceptions
    from individual updates are collected and not re-raised.
    """
    sub_task_updates = [
        TaskSignature.safe_change_status(task_key, status) for task_key in self.tasks
    ]
    swarm_update = super().change_status(status)
    await asyncio.gather(swarm_update, *sub_task_updates, return_exceptions=True)
|
|
116
|
+
|
|
117
|
+
async def add_tasks(
    self, tasks: list[TaskSignatureConvertible], close_on_max_task: bool = True
) -> list[Signature]:
    """Register several tasks in this swarm.

    Args:
        tasks: Task signatures (or anything ``resolve_signatures`` accepts)
            to add to the swarm.
        close_on_max_task: If true and a max task count is configured on the
            swarm, close the swarm once it reaches maximum capacity.

    Returns:
        The resolved signatures, in the order given.

    Raises:
        TooManyTasksError: The swarm config rejects adding ``len(tasks)`` tasks.
        SwarmIsCanceledError: The swarm status reports it is canceled.
    """
    if not self.config.can_add_n_tasks(self, len(tasks)):
        raise TooManyTasksError(
            f"Swarm {self.task_name} has reached max tasks limit"
        )
    if self.task_status.is_canceled():
        raise SwarmIsCanceledError(
            f"Swarm {self.task_name} is {self.task_status} - can't add task"
        )

    signatures = await resolve_signatures(tasks)
    new_keys = [signature.key for signature in signatures]

    # Link each signature to this swarm and record the keys in one apipeline() context.
    async with self.apipeline():
        for signature in signatures:
            signature.signature_container_id = self.key
        self.tasks.extend(new_keys)
        self.tasks_left_to_run.extend(new_keys)

    should_close = close_on_max_task and not self.config.can_add_task(self)
    if should_close:
        await self.close_swarm()

    return signatures
|
|
146
|
+
|
|
147
|
+
async def add_task(
    self, task: TaskSignatureConvertible, close_on_max_task: bool = True
) -> Signature:
    """Register a single task in this swarm.

    Args:
        task: The task signature (or anything ``resolve_signatures`` accepts)
            to add to the swarm.
        close_on_max_task: If true and a max task count is configured on the
            swarm, close the swarm once it reaches maximum capacity.

    Returns:
        The resolved signature of *task* (first element of ``add_tasks``).
    """
    return (await self.add_tasks([task], close_on_max_task))[0]
|
|
156
|
+
|
|
157
|
+
async def is_swarm_done(self):
    """Return whether the swarm is closed and every task has finished or failed.

    This is declared ``async`` even though it awaits nothing, so callers
    must ``await`` it. When the swarm is not closed, the falsy
    ``is_swarm_closed`` value itself is returned.
    """
    settled = {*self.finished_tasks, *self.failed_tasks}
    return self.is_swarm_closed and settled == set(self.tasks)
|
|
161
|
+
|
|
162
|
+
def has_published_callback(self):
    """Return True when the swarm status is ``SignatureStatus.DONE``."""
    current_status = self.task_status.status
    return current_status == SignatureStatus.DONE
|
|
164
|
+
|
|
165
|
+
def has_published_errors(self):
    """Return True when the swarm status is ``SignatureStatus.FAILED``."""
    current_status = self.task_status.status
    return current_status == SignatureStatus.FAILED
|
|
167
|
+
|
|
168
|
+
async def activate_success(self, msg):
    """Run the success path with the collected sub-task results, then clean up.

    The incoming *msg* is ignored. The base-class ``activate_success``
    receives the list loaded from ``tasks_results``. Afterwards the error
    branches are removed (``remove_branches(success=False)``) and the swarm
    itself is removed.
    """
    loaded_results = await self.tasks_results.aload()
    await super().activate_success(list(loaded_results))
    await self.remove_branches(success=False)
    await self.remove_task()
|
|
175
|
+
|
|
176
|
+
async def suspend(self):
    """Suspend every sub-task, then mark the swarm itself SUSPENDED.

    Errors from suspending individual sub-tasks are collected by
    ``asyncio.gather(return_exceptions=True)`` and not re-raised.
    """
    suspensions = [TaskSignature.suspend_from_key(task_key) for task_key in self.tasks]
    await asyncio.gather(*suspensions, return_exceptions=True)
    # Use the base implementation: this class's change_status would cascade
    # the status to the sub-tasks again.
    await super().change_status(SignatureStatus.SUSPENDED)
|
|
182
|
+
|
|
183
|
+
async def resume(self):
    """Resume every sub-task, then restore the swarm's previous status.

    Errors from resuming individual sub-tasks are collected by
    ``asyncio.gather(return_exceptions=True)`` and not re-raised.
    """
    resumptions = [TaskSignature.resume_from_key(task_key) for task_key in self.tasks]
    await asyncio.gather(*resumptions, return_exceptions=True)
    # Base implementation only: the sub-tasks were already handled above.
    previous_status = self.task_status.last_status
    await super().change_status(previous_status)
|
|
189
|
+
|
|
190
|
+
async def close_swarm(self) -> Self:
    """Mark the swarm closed, notify the client adapter, and return ``self``."""
    await self.aupdate(is_swarm_closed=True)
    # NOTE(review): max_tasks=0 presumably means "start nothing new, just
    # re-evaluate the swarm" — confirm against ClientAdapter.afill_swarm.
    adapter = self.ClientAdapter
    await adapter.afill_swarm(self, max_tasks=0)
    return self
|
|
194
|
+
|
|
195
|
+
def has_swarm_failed(self):
    """Return whether the number of failed tasks reached the configured limit.

    With no limit (``config.stop_after_n_failures is None``) this is always
    False. Note that a limit of 0 makes this True even when nothing failed.
    """
    failure_limit = self.config.stop_after_n_failures
    if failure_limit is None:
        return False
    return len(self.failed_tasks) >= failure_limit
|
|
200
|
+
|
|
201
|
+
async def finish_task(self, task_key: str, results: Any):
    """Record that *task_key* finished successfully with *results*.

    A key already in ``finished_tasks`` is ignored, so a repeated
    notification does not append a second result or decrement the running
    counter twice.
    """
    async with self.apipeline() as swarm_task:
        if task_key not in swarm_task.finished_tasks:
            swarm_task.finished_tasks.append(task_key)
            swarm_task.tasks_results.append(results)
            swarm_task.current_running_tasks -= 1
|
|
209
|
+
|
|
210
|
+
async def task_failed(self, task_key: str):
    """Record that *task_key* failed.

    A key already in ``failed_tasks`` is ignored, so a repeated
    notification does not decrement the running counter twice.
    """
    async with self.apipeline() as swarm_task:
        if task_key not in swarm_task.failed_tasks:
            swarm_task.failed_tasks.append(task_key)
            swarm_task.current_running_tasks -= 1
|
|
216
|
+
|
|
217
|
+
async def remove_task(self):
    """Expire this swarm's publish state and remove the swarm.

    The ``PublishState`` referenced by ``publishing_state_id`` is given
    ``REMOVED_TASK_TTL``. Then the base-class ``remove_task`` runs inside the
    same ``apipeline()`` context, and its result is returned.
    """
    state = await PublishState.aget(self.publishing_state_id)
    async with self.apipeline():
        # TODO - this should be removed once we use foreign key
        await state.aset_ttl(REMOVED_TASK_TTL)
        return await super().remove_task()
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import ClassVar
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from rapyer import AtomicRedisModel
|
|
5
|
+
from rapyer.config import RedisConfig
|
|
6
|
+
from rapyer.fields import RapyerKey
|
|
7
|
+
from rapyer.types import RedisList
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PublishState(AtomicRedisModel):
    """Redis-backed record holding a list of task keys.

    Swarms load it by their ``publishing_state_id``.
    """

    # Keys of tasks tracked by this state.
    task_ids: RedisList[RapyerKey] = Field(default_factory=list)

    # TTL of 24 * 60 * 60 (one day, assuming rapyer TTLs are in seconds).
    # NOTE(review): refresh_ttl=False presumably means access does not extend
    # the TTL — confirm against rapyer's RedisConfig.
    Meta: ClassVar[RedisConfig] = RedisConfig(ttl=24 * 60 * 60, refresh_ttl=False)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from thirdmagic.signature.model import TaskInputType
|
|
2
|
+
from thirdmagic.signature.status import PauseActionTypes, TaskStatus, SignatureStatus
|
|
3
|
+
from thirdmagic.task.creator import (
|
|
4
|
+
sign,
|
|
5
|
+
resolve_signatures,
|
|
6
|
+
resolve_signature,
|
|
7
|
+
TaskSignatureConvertible,
|
|
8
|
+
)
|
|
9
|
+
from thirdmagic.task.model import TaskSignature
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Public API re-exported by the ``thirdmagic.task`` package.
__all__ = [
    "sign",
    "resolve_signatures",
    "resolve_signature",
    "TaskSignatureConvertible",
    "TaskSignature",
    "SignatureStatus",
    "TaskStatus",
    "PauseActionTypes",
    "TaskInputType",
]
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import TypedDict, Any, overload, TypeAlias, Optional
|
|
3
|
+
|
|
4
|
+
import rapyer
|
|
5
|
+
from rapyer.fields import RapyerKey
|
|
6
|
+
|
|
7
|
+
from thirdmagic.signature import Signature
|
|
8
|
+
from thirdmagic.signature.status import TaskStatus
|
|
9
|
+
from thirdmagic.task.model import TaskSignature
|
|
10
|
+
from thirdmagic.typing_support import Unpack
|
|
11
|
+
from thirdmagic.utils import HatchetTaskType
|
|
12
|
+
|
|
13
|
+
# Anything ``resolve_signatures`` can turn into a signature: a stored key,
# an existing Signature, a Hatchet task, or a registered task name (str).
TaskSignatureConvertible: TypeAlias = RapyerKey | Signature | HatchetTaskType | str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def resolve_signatures(
    tasks: list[TaskSignatureConvertible],
) -> list[Signature]:
    """Turn a mixed list of task references into signatures, preserving order.

    Each entry is handled by type:

    * ``Signature`` instances are used as-is.
    * ``RapyerKey`` values are loaded together with one
      ``rapyer.afind(..., skip_missing=True)``. A key that is not found
      resolves to ``None``.
    * Other ``str`` values are treated as task names and signed with
      ``TaskSignature.from_task_name``.
    * Anything else is treated as a Hatchet task and signed with
      ``TaskSignature.from_task``.

    ``RapyerKey`` is checked before ``str``. NOTE(review): this ordering only
    matters if RapyerKey subclasses str — confirm.
    """
    resolved: list[Optional[Signature]] = [None] * len(tasks)
    keys_to_load: list[tuple[int, RapyerKey]] = []
    hatchet_tasks: list[tuple[int, HatchetTaskType]] = []
    names_to_sign: list[tuple[int, str]] = []

    for index, item in enumerate(tasks):
        if isinstance(item, Signature):
            resolved[index] = item
        elif isinstance(item, RapyerKey):
            keys_to_load.append((index, item))
        elif isinstance(item, str):
            names_to_sign.append((index, item))
        else:
            hatchet_tasks.append((index, item))

    if keys_to_load:
        loaded = await rapyer.afind(
            *[key for _, key in keys_to_load], skip_missing=True
        )
        signature_by_key = {signature.key: signature for signature in loaded}
        for index, key in keys_to_load:
            resolved[index] = signature_by_key.get(key)

    if hatchet_tasks:
        async with rapyer.apipeline():
            for index, hatchet_task in hatchet_tasks:
                resolved[index] = await TaskSignature.from_task(hatchet_task)

    if names_to_sign:
        async with rapyer.apipeline():
            for index, task_name in names_to_sign:
                resolved[index] = await TaskSignature.from_task_name(task_name)

    return resolved
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def resolve_signature(task: TaskSignatureConvertible) -> Signature:
    """Resolve a single task reference; see ``resolve_signatures``."""
    return (await resolve_signatures([task]))[0]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class TaskSignatureOptions(TypedDict, total=False):
    """Typed keyword options for ``sign``; every key is optional (``total=False``)."""

    # Keyword arguments stored on the signature.
    kwargs: dict
    creation_time: datetime
    model_validators: Any
    # Keys of callback signatures.
    success_callbacks: list[RapyerKey]
    error_callbacks: list[RapyerKey]
    task_status: TaskStatus
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@overload
async def sign(
    task: str | HatchetTaskType, **options: Unpack[TaskSignatureOptions]
) -> TaskSignature: ...
@overload
async def sign(task: str | HatchetTaskType, **options: Any) -> TaskSignature: ...


async def sign(task: str | HatchetTaskType, **options: Any) -> TaskSignature:
    """Create and save a ``TaskSignature`` for *task*.

    Options whose names are ``TaskSignature`` model fields are passed as
    model fields. All remaining options become the signature's task kwargs.
    An explicit ``kwargs=...`` option is merged with those remaining options;
    a leftover option wins on a key clash.

    Args:
        task: A registered task name (``str``), or a Hatchet task.
        **options: Model-field overrides and/or task keyword arguments.

    Returns:
        The saved signature.
    """
    signature_fields = {
        field_name: options.pop(field_name)
        for field_name in list(TaskSignature.model_fields.keys())
        if field_name in options
    }
    # "kwargs" is itself a model field. Passing it through **signature_fields
    # and also as kwargs=... raised "multiple values for keyword argument",
    # so merge it with the leftover options instead.
    task_kwargs = dict(signature_fields.pop("kwargs", None) or {})
    task_kwargs.update(options)

    if isinstance(task, str):
        return await TaskSignature.from_task_name(
            task, kwargs=task_kwargs, **signature_fields
        )
    return await TaskSignature.from_task(task, kwargs=task_kwargs, **signature_fields)
|
thirdmagic/task/model.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import Optional, Self, Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from rapyer.fields import SafeLoad, RapyerKey
|
|
5
|
+
|
|
6
|
+
from thirdmagic.errors import UnrecognizedTaskError
|
|
7
|
+
from thirdmagic.message import DEFAULT_RESULT_NAME
|
|
8
|
+
from thirdmagic.signature import Signature
|
|
9
|
+
from thirdmagic.signature.status import SignatureStatus
|
|
10
|
+
from thirdmagic.task_def import MageflowTaskDefinition
|
|
11
|
+
from thirdmagic.utils import return_value_field, HAS_HATCHET, HatchetTaskType
|
|
12
|
+
|
|
13
|
+
if HAS_HATCHET:
|
|
14
|
+
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TaskSignature(Signature):
    """Signature of a single task, saved with ``asave``.

    Build instances with ``from_task`` (from a client task object) or
    ``from_task_name`` (from a registered ``MageflowTaskDefinition``).
    """

    # Pydantic model used to validate the task's input, if known.
    model_validators: SafeLoad[Optional[type[BaseModel]]] = None
    # Field name computed by ``return_value_field`` from the validator
    # (falls back to DEFAULT_RESULT_NAME).
    return_field_name: str = DEFAULT_RESULT_NAME
    # NOTE(review): not set in this module; presumably filled by the worker.
    worker_task_id: str = ""

    @classmethod
    async def from_task(
        cls,
        task: HatchetTaskType,
        success_callbacks: list[RapyerKey | Self] = None,
        error_callbacks: list[RapyerKey | Self] = None,
        **kwargs,
    ) -> Self:
        """Create and save a signature for a client task object.

        Args:
            task: Task object; its name and validator come from ``cls.ClientAdapter``.
            success_callbacks: Success callbacks to store (``None`` -> empty list).
            error_callbacks: Error callbacks to store (``None`` -> empty list).
            **kwargs: Extra model fields forwarded to the constructor.

        Returns:
            The saved signature.
        """
        validator = cls.ClientAdapter.extract_validator(task)
        return_field_name = return_value_field(validator)
        signature = cls(
            task_name=cls.ClientAdapter.task_name(task),
            model_validators=validator,
            return_field_name=return_field_name,
            success_callbacks=success_callbacks or [],
            error_callbacks=error_callbacks or [],
            **kwargs,
        )
        await signature.asave()
        return signature

    @classmethod
    async def from_task_name(
        cls, task_name: str, model_validators: type[BaseModel] = None, **kwargs
    ) -> Self:
        """Create and save a signature for a task registered by name.

        Without *model_validators*, the ``MageflowTaskDefinition`` for
        *task_name* supplies both the validator and the canonical task name.
        With *model_validators*, *task_name* is used as given.

        Args:
            task_name: Registered task name.
            model_validators: Optional input validator overriding the lookup.
            **kwargs: Extra model fields forwarded to the constructor.

        Returns:
            The saved signature.

        Raises:
            UnrecognizedTaskError: No definition is stored for *task_name* and
                no validator was given.
        """
        if not model_validators:
            task_def = await MageflowTaskDefinition.afind_one(task_name)
            if not task_def:
                raise UnrecognizedTaskError(f"Task {task_name} was not initialized")
            model_validators = task_def.input_validator
            # Only read task_def where it is bound. Previously this line
            # referenced it even when a validator was supplied (UnboundLocalError).
            task_name = task_def.mageflow_task_name
        return_field_name = return_value_field(model_validators)

        signature = cls(
            task_name=task_name,
            return_field_name=return_field_name,
            model_validators=model_validators,
            **kwargs,
        )
        await signature.asave()
        return signature

    async def acall(self, msg: Any, set_return_field: bool = True, **kwargs):
        """Dispatch this signature through ``ClientAdapter.acall_signature``."""
        return await self.ClientAdapter.acall_signature(
            self, msg, set_return_field, **kwargs
        )

    if HAS_HATCHET:

        async def aio_run_no_wait(
            self, msg: BaseModel, options: TriggerWorkflowOptions = None
        ):
            """Call the task with *msg* without setting the return field.

            *options* is forwarded only when it is truthy.
            """
            params = dict(options=options) if options else {}
            return await self.acall(msg, set_return_field=False, **params)

    async def resume(self):
        """Restore the status saved before the task was paused/suspended.

        If the previous status was ACTIVE, the task is set to PENDING and
        dispatched again with no message. Otherwise the previous status is
        simply restored.
        """
        last_status = self.task_status.last_status
        if last_status == SignatureStatus.ACTIVE:
            await self.change_status(SignatureStatus.PENDING)
            await self.ClientAdapter.acall_signature(self, None, set_return_field=False)
        else:
            await self.change_status(last_status)
|
thirdmagic/task_def.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from rapyer import AtomicRedisModel
|
|
5
|
+
from rapyer.fields import Key
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MageflowTaskDefinition(AtomicRedisModel):
    """Stored definition of a task, looked up by its mageflow task name.

    ``TaskSignature.from_task_name`` loads it via ``afind_one`` and reads
    ``input_validator`` and ``mageflow_task_name``.
    """

    # Key field (rapyer ``Key[...]``) used for lookups by name.
    mageflow_task_name: Key[str]
    # NOTE(review): presumably the client/worker-side task name — confirm.
    task_name: str
    # Optional pydantic model used to validate the task's input.
    input_validator: Optional[type[BaseModel]] = None
    # NOTE(review): retry count; its semantics are not visible here.
    retries: Optional[int] = None
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Python 3.10 fallback: typing.Self was added in Python 3.11
|
|
2
|
+
try:
|
|
3
|
+
from typing import Self
|
|
4
|
+
except ImportError:
|
|
5
|
+
from typing_extensions import Self
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
# Python 3.11+
|
|
10
|
+
from typing import Unpack
|
|
11
|
+
except ImportError:
|
|
12
|
+
# Older Python versions
|
|
13
|
+
from typing_extensions import Unpack
|
|
14
|
+
|
|
15
|
+
# Version-independent typing names re-exported for the rest of the package.
__all__ = ["Self", "Unpack"]
|
thirdmagic/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import TypeVar, get_type_hints, Optional, Callable, Any
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from thirdmagic.message import ReturnValueAnnotation, DEFAULT_RESULT_NAME
|
|
7
|
+
|
|
8
|
+
# Type of the marker objects searched for by get_marked_fields. Markers are
# arbitrary Annotated metadata matched with isinstance. The previous bound,
# ``dataclasses.dataclass``, is a decorator function, not a type.
PropType = TypeVar("PropType")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_marked_fields(
    model: type[BaseModel], mark_type: type[PropType]
) -> list[tuple[PropType, str]]:
    """Find fields of *model* whose ``Annotated`` metadata includes a *mark_type*.

    Returns:
        ``(marker, field_name)`` pairs in annotation order. A field carrying
        several matching markers appears once per marker.
    """
    annotations = get_type_hints(model, include_extras=True)
    # Annotated[...] types keep their extras in __metadata__; others have none.
    return [
        (marker, field_name)
        for field_name, hint in annotations.items()
        for marker in getattr(hint, "__metadata__", ())
        if isinstance(marker, mark_type)
    ]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def return_value_field(model_validators: type[BaseModel]) -> Optional[str]:
    """Name of the first field marked with ``ReturnValueAnnotation``.

    Falls back to ``DEFAULT_RESULT_NAME`` when there is no marked field, or
    when *model_validators* cannot be inspected (a ``TypeError`` from
    ``get_type_hints``, e.g. for ``None``).
    """
    try:
        field_name = get_marked_fields(model_validators, ReturnValueAnnotation)[0][1]
    except (IndexError, TypeError):
        return DEFAULT_RESULT_NAME
    return field_name or DEFAULT_RESULT_NAME
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def deep_merge(base: dict, updates: dict) -> dict:
    """Return a copy of *base* with *updates* merged in recursively.

    When both sides hold a dict under the same key, those dicts are merged.
    Otherwise the value from *updates* replaces the one in *base*. *base* is
    not mutated; the result is ``base.copy()``, so it keeps base's dict type.
    """
    merged = base.copy()
    for key, new_value in updates.items():
        both_dicts = (
            key in base
            and isinstance(base[key], dict)
            and isinstance(new_value, dict)
        )
        merged[key] = deep_merge(base[key], new_value) if both_dicts else new_value
    return merged
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Which client is installed.
# HatchetTaskType is the type accepted wherever a Hatchet task can be signed.
# It stays plain ``Callable`` when hatchet_sdk is not importable.
HatchetTaskType = Callable
try:
    from hatchet_sdk.runnables.workflow import BaseWorkflow
except ImportError:
    HAS_HATCHET = False
else:
    HAS_HATCHET = True
    # Previously ``HatchetTaskType | Callable``, which is just Callable again and
    # left BaseWorkflow unused. Include the Hatchet workflow type in the alias.
    HatchetTaskType = BaseWorkflow | Callable
|
|
52
|
+
|
|
53
|
+
# try:
|
|
54
|
+
# HAS_TEMPORAL = True
|
|
55
|
+
# HatchetTaskType = None
|
|
56
|
+
# except ImportError:
|
|
57
|
+
# HAS_TEMPORAL = False
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: thirdmagic
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Core models and signatures for mageflow task orchestration
|
|
5
|
+
Author-email: imaginary-cherry <yedidyakfir@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: <3.14,>=3.10
|
|
8
|
+
Requires-Dist: pydantic<3.0.0,>=2.0.0
|
|
9
|
+
Requires-Dist: rapyer<1.3.0,>=1.2.3
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: black>=26.1.0; extra == 'dev'
|
|
12
|
+
Requires-Dist: coverage[toml]<8.0.0,>=7.0.0; extra == 'dev'
|
|
13
|
+
Requires-Dist: fakeredis[json,lua]<3.0.0,>=2.32.1; extra == 'dev'
|
|
14
|
+
Requires-Dist: hatchet-sdk>=1.23.0; extra == 'dev'
|
|
15
|
+
Requires-Dist: pytest-asyncio<2.0.0,>=1.2.0; extra == 'dev'
|
|
16
|
+
Requires-Dist: pytest<10.0.0,>=9.0.2; extra == 'dev'
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
thirdmagic/__init__.py,sha256=XcEm5efUwT6XeX85pKBo6mEia0Gc4nw5i7ft7Py8TAM,165
|
|
2
|
+
thirdmagic/consts.py,sha256=DvFDNOmi0LpHpi2IArv19NQvLcf7mTgnpqigLqw2P4c,95
|
|
3
|
+
thirdmagic/container.py,sha256=738Yx5aG_1Z7puDRMKkOg3AYFStE7KagjQXvIdz1Pbc,819
|
|
4
|
+
thirdmagic/errors.py,sha256=DfSAovxzKWiZ88eFga9CVD1zhkTOrLpIbeLIbDHFWBA,385
|
|
5
|
+
thirdmagic/message.py,sha256=PS-kbiWeX-BtGxDyokGD-e5pDoAhw-NrdTBtgOLLrsU,637
|
|
6
|
+
thirdmagic/task_def.py,sha256=Exm85y-FdIcKwehrZvHII3VkX-s-DgJLM4MLr1KitqI,316
|
|
7
|
+
thirdmagic/typing_support.py,sha256=sPcXb_BPNOhYwDclrJMv4avgHIKR-F8_LDKSMJwx2VI,284
|
|
8
|
+
thirdmagic/utils.py,sha256=HZaPmu6VWDUG6zhRKgCfNXIF9MpIFysr3777dKVlse0,1760
|
|
9
|
+
thirdmagic/chain/__init__.py,sha256=mLiGXAJtW8-XD_lPmvgCIonr5EAe_Lops0c-HwqH1LM,140
|
|
10
|
+
thirdmagic/chain/creator.py,sha256=8DobAR5qjyGrO3-8Vo-E3w7rFhQZ-euqsdWq7amvfg8,1202
|
|
11
|
+
thirdmagic/chain/model.py,sha256=7RYTN-jJyutZG4CiyryuLd6O7b5VFcA2oWX_8YuMdQc,3486
|
|
12
|
+
thirdmagic/clients/__init__.py,sha256=CQpt8cqfKp_pLzjpWSzwL5vWZy61cB0fJLaQxFtsrkY,133
|
|
13
|
+
thirdmagic/clients/base.py,sha256=1IctGFIkhWeDkdQlGchT3Lw3jTTQe_FKn4vMsz7OppM,4500
|
|
14
|
+
thirdmagic/clients/lifecycle.py,sha256=Kl-sHbXeg5b9jXRzRyqhhX7etZvihAa1kRzWGmUPmtk,494
|
|
15
|
+
thirdmagic/signature/__init__.py,sha256=klbJzw6XbxHb4DWxwcLW01MvdDex4eo5yYxbNbeduu8,213
|
|
16
|
+
thirdmagic/signature/model.py,sha256=K-Wi3bM_YOkTeVEhippcVGsUykybH-V2spREWvomjCw,6895
|
|
17
|
+
thirdmagic/signature/status.py,sha256=rK4ipp9bBtfWUFBCHhkvISZpBT28e502y8NqudACWz8,840
|
|
18
|
+
thirdmagic/swarm/__init__.py,sha256=G8KA4c0ciedzvV4JlRKQ5oA21oKp4J60Cywfw1AMLUU,232
|
|
19
|
+
thirdmagic/swarm/consts.py,sha256=O-bKrVnzCxw4tETvArmSxevLAfZjs3Mv-h4lnvoq2OY,56
|
|
20
|
+
thirdmagic/swarm/creator.py,sha256=QpggNG1R7UCXLq8SosVqU4FUD_z-qcxEfvemP40SOIo,1581
|
|
21
|
+
thirdmagic/swarm/model.py,sha256=4NncuSTN9BBozEUOtEgViZSNg6-1EsTK1NrOnNVkxkU,9042
|
|
22
|
+
thirdmagic/swarm/state.py,sha256=p2QjltycxQHBgaX7jDlG_10cJjZWlhKpKwgISKWvBY0,390
|
|
23
|
+
thirdmagic/task/__init__.py,sha256=JX74AHbxN3J-IK1Gv1KqV6Z-65OR6LQ98hFzp3xPcKU,532
|
|
24
|
+
thirdmagic/task/creator.py,sha256=d6pr0MJC3dUu3d7xv6pxEObO8QFYylSA8-iWW_XnYiQ,2843
|
|
25
|
+
thirdmagic/task/model.py,sha256=yGehcmFhe8pNa0t6bXd3qcCdPiKlV23XL6Wk_Yond58,3058
|
|
26
|
+
thirdmagic-0.0.1.dist-info/METADATA,sha256=bcT2ya9ug_8wjLUQHIeZk4qxMuCNzpKQTzW8mQXik70,649
|
|
27
|
+
thirdmagic-0.0.1.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
28
|
+
thirdmagic-0.0.1.dist-info/RECORD,,
|