mageflow 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mageflow/__init__.py +30 -0
- mageflow/callbacks.py +72 -0
- mageflow/chain/__init__.py +0 -0
- mageflow/chain/consts.py +8 -0
- mageflow/chain/creator.py +73 -0
- mageflow/chain/messages.py +9 -0
- mageflow/chain/model.py +61 -0
- mageflow/chain/workflows.py +65 -0
- mageflow/client.py +140 -0
- mageflow/errors.py +22 -0
- mageflow/init.py +53 -0
- mageflow/invokers/__init__.py +0 -0
- mageflow/invokers/base.py +34 -0
- mageflow/invokers/hatchet.py +82 -0
- mageflow/models/__init__.py +0 -0
- mageflow/models/message.py +6 -0
- mageflow/signature/__init__.py +0 -0
- mageflow/signature/consts.py +3 -0
- mageflow/signature/creator.py +69 -0
- mageflow/signature/model.py +319 -0
- mageflow/signature/status.py +25 -0
- mageflow/signature/types.py +6 -0
- mageflow/startup.py +65 -0
- mageflow/swarm/__init__.py +0 -0
- mageflow/swarm/consts.py +12 -0
- mageflow/swarm/creator.py +34 -0
- mageflow/swarm/messages.py +7 -0
- mageflow/swarm/model.py +260 -0
- mageflow/swarm/workflows.py +120 -0
- mageflow/task/__init__.py +0 -0
- mageflow/task/model.py +19 -0
- mageflow/typing_support.py +8 -0
- mageflow/utils/__init__.py +0 -0
- mageflow/utils/models.py +19 -0
- mageflow/utils/pythonic.py +21 -0
- mageflow/visualizer/__init__.py +0 -0
- mageflow/visualizer/app.py +221 -0
- mageflow/visualizer/assets/cytoscape_styles.py +63 -0
- mageflow/visualizer/assets/styles.css +143 -0
- mageflow/visualizer/builder.py +497 -0
- mageflow/visualizer/data.py +65 -0
- mageflow/visualizer/utils.py +72 -0
- mageflow/workflows.py +128 -0
- mageflow-0.0.1.dist-info/METADATA +164 -0
- mageflow-0.0.1.dist-info/RECORD +48 -0
- mageflow-0.0.1.dist-info/WHEEL +4 -0
- mageflow-0.0.1.dist-info/entry_points.txt +3 -0
- mageflow-0.0.1.dist-info/licenses/LICENSE +21 -0
mageflow/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from mageflow.callbacks import register_task, handle_task_callback
|
|
2
|
+
from mageflow.chain.creator import chain
|
|
3
|
+
from mageflow.client import Mageflow
|
|
4
|
+
from mageflow.init import init_mageflow_hatchet_tasks
|
|
5
|
+
from mageflow.signature.creator import (
|
|
6
|
+
sign,
|
|
7
|
+
load_signature,
|
|
8
|
+
resume_task,
|
|
9
|
+
lock_task,
|
|
10
|
+
resume,
|
|
11
|
+
pause,
|
|
12
|
+
)
|
|
13
|
+
from mageflow.signature.status import TaskStatus
|
|
14
|
+
from mageflow.swarm.creator import swarm
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Public API of the mageflow package.
# NOTE(review): TaskStatus is imported above but not listed here — confirm
# whether it is meant to be part of the public API.
__all__ = [
    # Signature helpers
    "load_signature",
    "resume_task",
    "lock_task",
    "resume",
    "pause",
    "sign",
    # Hatchet integration
    "init_mageflow_hatchet_tasks",
    "register_task",
    "handle_task_callback",
    "Mageflow",
    # Workflow composition
    "chain",
    "swarm",
]
|
mageflow/callbacks.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from hatchet_sdk import Context
|
|
8
|
+
from hatchet_sdk.runnables.types import EmptyModel
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from mageflow.invokers.hatchet import HatchetInvoker
|
|
12
|
+
from mageflow.utils.pythonic import flexible_call
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AcceptParams(Enum):
    """Which arguments ``handle_task_callback`` forwards to the user task.

    - ``JUST_MESSAGE``: only the input message.
    - ``NO_CTX``: the message plus any extra positional/keyword arguments.
    - ``ALL``: the message, the Hatchet ``Context`` and any extra arguments.
    """

    JUST_MESSAGE = 1
    NO_CTX = 2
    ALL = 3
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class HatchetResult(BaseModel):
    """Envelope around a task's return value.

    ``handle_task_callback`` dumps it with ``model_dump(mode="json")`` to obtain
    a JSON-compatible form of the result before passing it to success callbacks.
    """

    # Raw value returned by the user's task function.
    hatchet_results: Any
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def handle_task_callback(
    expected_params: AcceptParams = AcceptParams.NO_CTX, wrap_res: bool = True
):
    """Build a decorator that runs a Hatchet task inside mageflow's lifecycle.

    The wrapped coroutine:

    1. asks the invoker whether the task may run; if not, cancels the Hatchet
       run and returns an error dict as a fallback;
    2. calls ``invoker.start_task()``;
    3. calls ``func`` with the arguments selected by ``expected_params``;
    4. on any exception (including cancellation) calls ``invoker.run_error()``,
       then ``invoker.remove_task(with_error=False)``, and re-raises;
    5. on success passes the JSON-dumped result to ``invoker.run_success``,
       calls ``invoker.remove_task(with_success=False)`` and returns the
       ``HatchetResult`` envelope (``wrap_res=True``) or the raw result.

    Args:
        expected_params: Which arguments to forward to ``func``.
        wrap_res: Whether to return the ``HatchetResult`` wrapper.
    """

    def task_decorator(func):
        @functools.wraps(func)
        async def wrapper(message: EmptyModel, ctx: Context, *args, **kwargs):
            invoker = HatchetInvoker(message, ctx)
            if not await invoker.should_run_task():
                await ctx.aio_cancel()
                await asyncio.sleep(10)
                # NOTE: This should not run, the task should cancel, but just in case
                return {"Error": "Task should have been canceled"}

            try:
                await invoker.start_task()
                if expected_params == AcceptParams.JUST_MESSAGE:
                    call_args, call_kwargs = (message,), {}
                elif expected_params == AcceptParams.NO_CTX:
                    call_args, call_kwargs = (message, *args), kwargs
                else:
                    call_args, call_kwargs = (message, ctx, *args), kwargs
                result = await flexible_call(func, *call_args, **call_kwargs)
            except (Exception, asyncio.CancelledError):
                await invoker.run_error()
                await invoker.remove_task(with_error=False)
                raise
            else:
                envelope = HatchetResult(hatchet_results=result)
                json_ready = envelope.model_dump(mode="json")
                await invoker.run_success(json_ready["hatchet_results"])
                await invoker.remove_task(with_success=False)
                return envelope if wrap_res else result

        # Expose the user function's signature rather than the wrapper's.
        wrapper.__signature__ = inspect.signature(func)
        return wrapper

    return task_decorator
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def register_task(register_name: str):
    """Return a decorator that records a task under ``register_name``.

    The decorator appends ``(func, register_name)`` to
    ``mageflow.startup.REGISTERED_TASKS`` and returns ``func`` unchanged.
    """
    # Imported when register_task is called rather than at module import time.
    from mageflow.startup import REGISTERED_TASKS

    def record(func):
        REGISTERED_TASKS.append((func, register_name))
        return func

    return record
|
|
File without changes
|
mageflow/chain/consts.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from mageflow.chain.consts import ON_CHAIN_END, ON_CHAIN_ERROR
|
|
4
|
+
from mageflow.chain.messages import ChainSuccessTaskCommandMessage
|
|
5
|
+
from mageflow.chain.model import ChainTaskSignature
|
|
6
|
+
from mageflow.signature.creator import (
|
|
7
|
+
TaskSignatureConvertible,
|
|
8
|
+
resolve_signature_key,
|
|
9
|
+
)
|
|
10
|
+
from mageflow.signature.model import (
|
|
11
|
+
TaskIdentifierType,
|
|
12
|
+
TaskSignature,
|
|
13
|
+
TaskInputType,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def chain(
    tasks: list[TaskSignatureConvertible],
    name: str | None = None,
    error: TaskInputType = None,
    success: TaskInputType = None,
) -> ChainTaskSignature:
    """Create and persist a chain that runs ``tasks`` one after another.

    Each task's success callback is the next task. The last task's success
    callback is an internal ``ON_CHAIN_END`` signature, and every task gets its
    own copy of an internal ``ON_CHAIN_ERROR`` signature as error callback.

    Args:
        tasks: At least two tasks (anything ``resolve_signature_key`` accepts).
        name: Optional chain name; defaults to the first task's ``task_name``.
        error: Optional callback stored in the chain's ``error_callbacks``.
        success: Optional callback stored in the chain's ``success_callbacks``.

    Returns:
        The saved ``ChainTaskSignature``.

    Raises:
        ValueError: If fewer than two tasks are given.
    """
    tasks = list(tasks)
    # Validate before anything is resolved or saved. An empty list used to
    # fail with IndexError on ``tasks[0]``, and a single task left an orphaned
    # chain signature in storage before the linking helper rejected it.
    if len(tasks) < 2:
        raise ValueError(
            "Chained tasks must contain at least two tasks. "
            "If you want to run a single task, use `create_workflow` instead."
        )
    tasks = [await resolve_signature_key(task) for task in tasks]

    # Create a chain task that will be deleted only at the end of the chain
    first_task = tasks[0]
    chain_task_signature = ChainTaskSignature(
        task_name=f"chain-task:{name or first_task.task_name}",
        success_callbacks=[success] if success else [],
        error_callbacks=[error] if error else [],
        tasks=tasks,
    )
    await chain_task_signature.save()

    # Both internal callbacks receive the chain key as ``chain_task_id``.
    callback_kwargs = dict(chain_task_id=chain_task_signature.key)
    # NOTE(review): the error callback also uses ChainSuccessTaskCommandMessage
    # as its validator, while chain_error_task accepts EmptyModel — confirm.
    on_chain_error = TaskSignature(
        task_name=ON_CHAIN_ERROR,
        task_identifiers=callback_kwargs,
        model_validators=ChainSuccessTaskCommandMessage,
    )
    on_chain_success = TaskSignature(
        task_name=ON_CHAIN_END,
        task_identifiers=callback_kwargs,
        model_validators=ChainSuccessTaskCommandMessage,
    )
    await _chain_task_to_previous_success(tasks, on_chain_error, on_chain_success)
    return chain_task_signature
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def _chain_task_to_previous_success(
    tasks: list[TaskSignature], error: TaskSignature, success: TaskSignature
) -> TaskIdentifierType:
    """Link each task to the next one and give every task an error callback.

    Task ``i`` gets task ``i + 1`` as success callback; the last task gets
    ``success``. Each task gets its own duplicate of ``error`` (created with
    ``error.duplicate_many``) as error callback. ``success`` and the error
    duplicates are saved before the links are added.

    Returns:
        Whatever ``add_callbacks`` returned for the first task.

    Raises:
        ValueError: If fewer than two tasks are given.
    """
    if len(tasks) < 2:
        raise ValueError(
            "Chained tasks must contain at least two tasks. "
            "If you want to run a single task, use `create_workflow` instead."
        )

    error_tasks = await error.duplicate_many(len(tasks))
    # Store tasks (a distinct loop name avoids shadowing the ``error`` parameter).
    await asyncio.gather(success.save(), *[copy.save() for copy in error_tasks])

    next_tasks = [*tasks[1:], success]
    # strict=True keeps the original failure if duplicate_many returns too few.
    update_tasks = [
        task.add_callbacks(success=[next_task], errors=[error_task])
        for task, next_task, error_task in zip(
            tasks, next_tasks, error_tasks, strict=True
        )
    ]
    chained_tasks = await asyncio.gather(*update_tasks)
    return chained_tasks[0]
|
mageflow/chain/model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from pydantic import field_validator, Field
|
|
4
|
+
|
|
5
|
+
from mageflow.errors import MissingSignatureError
|
|
6
|
+
from mageflow.signature.model import TaskSignature, TaskIdentifierType
|
|
7
|
+
from mageflow.signature.status import SignatureStatus
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ChainTaskSignature(TaskSignature):
    """Signature that represents a whole chain of tasks.

    ``tasks`` stores the keys of the chained signatures in execution order.
    Lifecycle operations (status change, suspend, interrupt, resume) are fanned
    out to every member and then applied to the chain signature itself.
    """

    tasks: list[TaskIdentifierType] = Field(default_factory=list)

    @field_validator("tasks", mode="before")
    @classmethod
    def validate_tasks(cls, v: list[TaskSignature]):
        # Normalise every entry through the inherited validate_task_key.
        return [cls.validate_task_key(item) for item in v]

    async def workflow(self, **task_additional_params):
        """Start the chain by starting its first task."""
        head = await TaskSignature.get_safe(self.tasks[0])
        if head is None:
            raise MissingSignatureError(f"First task from chain {self.key} not found")
        return await head.workflow(**task_additional_params)

    async def delete_chain_tasks(self, with_errors=True, with_success=True):
        """Remove every member signature that can still be loaded."""
        fetched = await asyncio.gather(
            *(TaskSignature.get_safe(key) for key in self.tasks),
            return_exceptions=True,
        )
        # Lookups that raised or returned None are skipped.
        alive = [item for item in fetched if isinstance(item, TaskSignature)]
        await asyncio.gather(
            *(item.remove(with_errors, with_success) for item in alive)
        )

    async def change_status(self, status: SignatureStatus):
        """Set ``status`` on every member and on the chain itself."""
        member_updates = [
            TaskSignature.safe_change_status(key, status) for key in self.tasks
        ]
        own_update = super().change_status(status)
        await asyncio.gather(own_update, *member_updates, return_exceptions=True)

    async def suspend(self):
        """Suspend every member, then mark the chain as SUSPENDED."""
        await asyncio.gather(
            *(TaskSignature.suspend_from_key(key) for key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.SUSPENDED)

    async def interrupt(self):
        """Interrupt every member, then mark the chain as INTERRUPTED."""
        await asyncio.gather(
            *(TaskSignature.interrupt_from_key(key) for key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.INTERRUPTED)

    async def resume(self):
        """Resume every member, then restore the chain's last status."""
        await asyncio.gather(
            *(TaskSignature.resume_from_key(key) for key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(self.task_status.last_status)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from hatchet_sdk import Context
|
|
4
|
+
from hatchet_sdk.runnables.types import EmptyModel
|
|
5
|
+
|
|
6
|
+
from mageflow.chain.consts import CHAIN_TASK_ID_NAME
|
|
7
|
+
from mageflow.chain.messages import ChainSuccessTaskCommandMessage
|
|
8
|
+
from mageflow.chain.model import ChainTaskSignature
|
|
9
|
+
from mageflow.invokers.hatchet import HatchetInvoker
|
|
10
|
+
from mageflow.signature.consts import TASK_ID_PARAM_NAME
|
|
11
|
+
from mageflow.signature.model import TaskSignature
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def chain_end_task(msg: ChainSuccessTaskCommandMessage, ctx: Context) -> None:
    """Finish a chain after its last task succeeded.

    Reads the chain key and this callback's own signature key from the invoker
    task context, then activates the chain's success callbacks with
    ``msg.chain_results``. Afterwards it removes the chain signature (with
    ``with_success=False``) and this callback's own signature. Any error is
    logged and re-raised.

    Raises:
        MissingSignatureError: If the chain signature no longer exists.
    """
    from mageflow.errors import MissingSignatureError

    try:
        task_data = HatchetInvoker(msg, ctx).task_ctx
        chain_task_id = task_data[CHAIN_TASK_ID_NAME]
        current_task_id = task_data[TASK_ID_PARAM_NAME]

        chain_task_signature, current_task = await asyncio.gather(
            ChainTaskSignature.get_safe(chain_task_id),
            TaskSignature.get_safe(current_task_id),
        )
        if chain_task_signature is None:
            raise MissingSignatureError(f"Chain task {chain_task_id} not found")
        ctx.log(f"Chain task done {chain_task_signature.task_name}")

        # Calling success callback from a chain task - This is done before deletion because a deletion error should not disturb the workflow
        await chain_task_signature.activate_success(msg.chain_results)
        ctx.log(f"Chain task success {chain_task_signature.task_name}")

        # Remove tasks. This callback's own signature may already be gone;
        # crashing here after the success callbacks fired could re-fire them on retry.
        cleanup = [chain_task_signature.remove(with_success=False)]
        if current_task is not None:
            cleanup.append(current_task.remove())
        await asyncio.gather(*cleanup)
    except Exception as e:
        ctx.log(f"MAJOR - infrastructure error in chain end task: {e}")
        raise
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# This task needs to be added as a workflow
async def chain_error_task(msg: EmptyModel, ctx: Context) -> None:
    """Handle the failure of a task inside a chain.

    Activates the chain's error callbacks with ``msg`` and deletes all member
    tasks via ``delete_chain_tasks``. It then removes the chain signature (with
    ``with_error=False``) and this callback's own signature. Any error is
    logged and re-raised.

    Raises:
        MissingSignatureError: If the chain signature no longer exists.
    """
    from mageflow.errors import MissingSignatureError

    try:
        task_data = HatchetInvoker(msg, ctx).task_ctx
        chain_task_id = task_data[CHAIN_TASK_ID_NAME]
        current_task_id = task_data[TASK_ID_PARAM_NAME]
        chain_packed_task, current_task = await asyncio.gather(
            ChainTaskSignature.get_safe(chain_task_id),
            TaskSignature.get_safe(current_task_id),
        )
        if chain_packed_task is None:
            raise MissingSignatureError(f"Chain task {chain_task_id} not found")
        ctx.log(
            f"Chain task failed {chain_packed_task.task_name} on task id - {current_task_id}"
        )

        # Calling error callback from chain task
        await chain_packed_task.activate_error(msg)
        ctx.log(f"Chain task error {chain_packed_task.task_name}")

        # Remove tasks; this callback's own signature may already be gone.
        await chain_packed_task.delete_chain_tasks()
        cleanup = [chain_packed_task.remove(with_error=False)]
        if current_task is not None:
            cleanup.append(current_task.remove())
        await asyncio.gather(*cleanup)
        ctx.log(f"Clean redis from chain tasks {chain_packed_task.task_name}")
    except Exception as e:
        ctx.log(f"MAJOR - infrastructure error in chain error task: {e}")
        raise
|
mageflow/client.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
from typing import TypeVar, Any, overload, Unpack
|
|
4
|
+
|
|
5
|
+
import redis
|
|
6
|
+
from hatchet_sdk import Hatchet, Worker
|
|
7
|
+
from hatchet_sdk.runnables.workflow import BaseWorkflow
|
|
8
|
+
from hatchet_sdk.worker.worker import LifespanFn
|
|
9
|
+
from redis.asyncio import Redis
|
|
10
|
+
|
|
11
|
+
from mageflow.callbacks import AcceptParams, register_task, handle_task_callback
|
|
12
|
+
from mageflow.chain.creator import chain
|
|
13
|
+
from mageflow.init import init_mageflow_hatchet_tasks
|
|
14
|
+
from mageflow.signature.creator import sign, TaskSignatureConvertible
|
|
15
|
+
from mageflow.signature.model import TaskSignature, TaskInputType
|
|
16
|
+
from mageflow.signature.types import HatchetTaskType
|
|
17
|
+
from mageflow.startup import (
|
|
18
|
+
lifespan_initialize,
|
|
19
|
+
mageflow_config,
|
|
20
|
+
init_mageflow,
|
|
21
|
+
teardown_mageflow,
|
|
22
|
+
)
|
|
23
|
+
from mageflow.swarm.creator import swarm, SignatureOptions
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def merge_lifespan(original_lifespan: LifespanFn):
    """Wrap a user worker lifespan with mageflow's setup and teardown.

    Runs ``init_mageflow()`` first, then re-yields everything the user lifespan
    yields. ``teardown_mageflow()`` runs in a ``finally`` block, so mageflow
    resources are released even if the user lifespan raises.
    """
    await init_mageflow()
    try:
        async for res in original_lifespan():
            yield res
    finally:
        await teardown_mageflow()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class HatchetMageflow(Hatchet):
    """Hatchet client whose task decorators add mageflow signature handling.

    Functions declared with :meth:`task` or :meth:`durable_task` are wrapped by
    ``handle_task_callback`` and recorded with ``register_task`` under their
    task name. :meth:`worker` also serves the tasks returned by
    ``init_mageflow_hatchet_tasks``.
    """

    def __init__(
        self,
        hatchet: Hatchet,
        redis_client: Redis,
        param_config: AcceptParams = AcceptParams.NO_CTX,
    ):
        # Reuse the given Hatchet instance's internal client.
        super().__init__(client=hatchet._client)
        self.hatchet = hatchet
        self.redis = redis_client
        self.param_config = param_config

    def _mageflow_decorator(self, hatchet_decorator, name: str | None):
        """Return a decorator that wraps, declares and registers a function.

        The function is wrapped with ``handle_task_callback``, declared via
        ``hatchet_decorator``, and registered under ``name`` (or its
        ``__name__``, which ``functools.wraps`` preserves on the wrapper).
        """

        def decorator(func):
            wrapped = handle_task_callback(self.param_config)(func)
            declared = hatchet_decorator(wrapped)
            task_name = name or wrapped.__name__
            return register_task(task_name)(declared)

        return decorator

    def task(self, *, name: str | None = None, **kwargs):
        """Like ``Hatchet.task`` but with mageflow handling and registration."""
        return self._mageflow_decorator(super().task(name=name, **kwargs), name)

    def durable_task(self, *, name: str | None = None, **kwargs):
        """Like ``Hatchet.durable_task`` but with mageflow handling and registration."""
        return self._mageflow_decorator(
            super().durable_task(name=name, **kwargs), name
        )

    def worker(
        self,
        *args,
        workflows: list[BaseWorkflow[Any]] | None = None,
        lifespan: LifespanFn | None = None,
        **kwargs,
    ) -> Worker:
        """Create a worker that also serves mageflow's internal tasks.

        If no ``lifespan`` is given, ``lifespan_initialize`` is used; otherwise
        the given lifespan is wrapped by ``merge_lifespan``.
        """
        mageflow_flows = init_mageflow_hatchet_tasks(self.hatchet)
        # Build a new list: ``workflows`` defaults to None (``+=`` raised
        # TypeError), and in-place ``+=`` also mutated the caller's list.
        all_workflows = [*(workflows or []), *mageflow_flows]
        if lifespan is None:
            lifespan = lifespan_initialize
        else:
            lifespan = functools.partial(merge_lifespan, lifespan)

        return super().worker(
            *args, workflows=all_workflows, lifespan=lifespan, **kwargs
        )

    async def sign(self, task: str | HatchetTaskType, **options: Any) -> TaskSignature:
        """Shortcut for the module-level ``sign`` helper."""
        return await sign(task, **options)

    async def chain(
        self,
        tasks: list[TaskSignatureConvertible],
        name: str | None = None,
        error: TaskInputType = None,
        success: TaskInputType = None,
    ):
        """Shortcut for the module-level ``chain`` helper."""
        return await chain(tasks, name, error, success)

    async def swarm(
        self,
        tasks: list[TaskSignatureConvertible] = None,
        task_name: str = None,
        **kwargs: Unpack[SignatureOptions],
    ):
        """Shortcut for the module-level ``swarm`` helper."""
        return await swarm(tasks, task_name, **kwargs)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
T = TypeVar("T")


@overload
def Mageflow(
    hatchet_client: Hatchet,
    redis_client: Redis | str = None,
    param_config: AcceptParams = AcceptParams.NO_CTX,
) -> HatchetMageflow: ...


def Mageflow(
    hatchet_client: T = None,
    redis_client: Redis | str = None,
    param_config: AcceptParams = AcceptParams.NO_CTX,
) -> T:
    """Configure mageflow and return a :class:`HatchetMageflow` client.

    Args:
        hatchet_client: Hatchet client to wrap; a default ``Hatchet()`` is
            created when omitted.
        redis_client: A ``redis.asyncio.Redis`` instance or a Redis URL. When
            omitted, the ``REDIS_URL`` environment variable is used.
        param_config: Default ``AcceptParams`` for decorated tasks.

    Side effects:
        Sets ``mageflow_config.hatchet_client`` (a copy of the client with an
        empty namespace) and ``mageflow_config.redis_client``.

    Raises:
        ValueError: If no Redis client is given and ``REDIS_URL`` is unset.
    """
    if hatchet_client is None:
        hatchet_client = Hatchet()

    # Create a hatchet client with empty namespace for creating wf
    config = hatchet_client._client.config.model_copy(deep=True)
    config.namespace = ""
    hatchet_caller = Hatchet(config=config, debug=hatchet_client._client.debug)
    mageflow_config.hatchet_client = hatchet_caller

    if redis_client is None:
        redis_client = os.getenv("REDIS_URL")
        if redis_client is None:
            # Previously None reached redis.asyncio.from_url and failed obscurely.
            raise ValueError(
                "No redis_client was given and the REDIS_URL environment variable is not set"
            )
    if isinstance(redis_client, str):
        redis_client = redis.asyncio.from_url(redis_client, decode_responses=True)
    mageflow_config.redis_client = redis_client
    return HatchetMageflow(hatchet_client, redis_client, param_config)
|
mageflow/errors.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
class MageflowError(Exception):
    """Root of the mageflow exception hierarchy."""


class MissingSignatureError(MageflowError):
    """A required task signature could not be loaded."""


class MissingSwarmItemError(MissingSignatureError):
    """``MissingSignatureError`` specialised for swarm items."""


class SwarmError(MageflowError):
    """Base class for swarm-related errors."""


class TooManyTasksError(SwarmError, RuntimeError):
    """``SwarmError`` that is also a ``RuntimeError`` (name indicates too many tasks)."""


class SwarmIsCanceledError(SwarmError, RuntimeError):
    """``SwarmError`` that is also a ``RuntimeError`` (name indicates a canceled swarm)."""
|
mageflow/init.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from hatchet_sdk import Hatchet
|
|
2
|
+
|
|
3
|
+
from mageflow.callbacks import register_task
|
|
4
|
+
from mageflow.chain.consts import ON_CHAIN_END, ON_CHAIN_ERROR
|
|
5
|
+
from mageflow.chain.messages import ChainSuccessTaskCommandMessage
|
|
6
|
+
from mageflow.chain.workflows import chain_end_task, chain_error_task
|
|
7
|
+
from mageflow.swarm.consts import ON_SWARM_ERROR, ON_SWARM_END, ON_SWARM_START
|
|
8
|
+
from mageflow.swarm.messages import SwarmResultsMessage
|
|
9
|
+
from mageflow.swarm.workflows import (
|
|
10
|
+
swarm_item_failed,
|
|
11
|
+
swarm_item_done,
|
|
12
|
+
swarm_start_tasks,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def init_mageflow_hatchet_tasks(hatchet: Hatchet):
    """Declare and register mageflow's internal chain and swarm tasks.

    Each handler is declared with ``hatchet.task`` under its mageflow name and
    recorded via ``register_task`` under the same name.

    Returns:
        The declared tasks in the order: chain error, chain end, swarm start,
        swarm end, swarm error.
    """

    def declare(task_name, handler, **task_options):
        # hatchet.task(...) yields a decorator; apply it, then register the result.
        declared = hatchet.task(name=task_name, **task_options)(handler)
        return register_task(task_name)(declared)

    # Chain tasks
    chain_done = declare(
        ON_CHAIN_END, chain_end_task, input_validator=ChainSuccessTaskCommandMessage
    )
    chain_error = declare(ON_CHAIN_ERROR, chain_error_task)

    # Swarm tasks
    swarm_start = declare(ON_SWARM_START, swarm_start_tasks)
    swarm_done = declare(
        ON_SWARM_END, swarm_item_done, input_validator=SwarmResultsMessage
    )
    swarm_error = declare(ON_SWARM_ERROR, swarm_item_failed)

    return [chain_error, chain_done, swarm_start, swarm_done, swarm_error]
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from mageflow.signature.model import TaskSignature
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseInvoker(ABC):
    """Bridge between a running worker task and its mageflow signature.

    ``handle_task_callback`` drives an invoker through these hooks around
    the user task function.
    """

    @property
    @abc.abstractmethod
    def task_ctx(self) -> dict:
        """Mageflow data attached to the current run."""

    @abc.abstractmethod
    async def start_task(self):
        """Hook awaited right before the task function is called."""

    @abc.abstractmethod
    async def run_success(self, result: Any) -> bool:
        """Hook awaited with the JSON-dumped result after the task succeeded."""

    @abc.abstractmethod
    async def run_error(self) -> bool:
        """Hook awaited after the task function raised."""

    @abc.abstractmethod
    async def remove_task(
        self, with_success: bool = True, with_error: bool = True
    ) -> TaskSignature | None:
        """Remove the current task's signature once callbacks were handled."""

    @abc.abstractmethod
    async def should_run_task(self) -> bool:
        """Return False if the task must be cancelled instead of run."""
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from hatchet_sdk import Context
|
|
5
|
+
from hatchet_sdk.runnables.contextvars import ctx_additional_metadata
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from mageflow.invokers.base import BaseInvoker
|
|
9
|
+
from mageflow.signature.consts import TASK_ID_PARAM_NAME
|
|
10
|
+
from mageflow.signature.model import TaskSignature
|
|
11
|
+
from mageflow.signature.status import SignatureStatus
|
|
12
|
+
from mageflow.workflows import TASK_DATA_PARAM_NAME
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HatchetInvoker(BaseInvoker):
    """:class:`BaseInvoker` backed by a Hatchet ``Context``.

    Mageflow data is read from the run's additional metadata under
    ``TASK_DATA_PARAM_NAME``; the signature key is stored in it under
    ``TASK_ID_PARAM_NAME``. Runs without a signature key are always allowed
    to run and have no callbacks to activate.
    """

    def __init__(self, message: BaseModel, ctx: Context):
        self.message = message
        self.task_data = ctx.additional_metadata.get(TASK_DATA_PARAM_NAME, {})
        self.workflow_id = ctx.workflow_id
        # Drop mageflow's data from Hatchet's context-var metadata (presumably
        # so runs spawned from inside this task do not inherit it — confirm).
        hatchet_ctx_metadata = ctx_additional_metadata.get() or {}
        hatchet_ctx_metadata.pop(TASK_DATA_PARAM_NAME, None)
        ctx_additional_metadata.set(hatchet_ctx_metadata)

    @property
    def task_ctx(self) -> dict:
        return self.task_data

    def _task_id(self):
        """Return the signature key of this run, or None for plain Hatchet runs."""
        return self.task_data.get(TASK_ID_PARAM_NAME, None)

    async def start_task(self):
        """Mark the signature ACTIVE and record the Hatchet workflow id, under lock."""
        task_id = self._task_id()
        if task_id:
            async with TaskSignature.lock_from_key(task_id) as signature:
                await signature.change_status(SignatureStatus.ACTIVE)
                await signature.task_status.aupdate(worker_task_id=self.workflow_id)

    async def run_success(self, result: Any) -> bool:
        """Activate the signature's success callbacks with ``result``.

        Returns:
            True if callbacks were activated; False if the run has no signature
            key or the signature no longer exists (e.g. removed while running).
        """
        task_id = self._task_id()
        if not task_id:
            return False
        current_task = await TaskSignature.get_safe(task_id)
        if current_task is None:
            # Previously an AttributeError here failed an already-successful task.
            return False
        await current_task.activate_success(result)
        return True

    async def run_error(self) -> bool:
        """Activate the signature's error callbacks with the input message.

        Returns:
            True if callbacks were activated; False if the run has no signature
            key or the signature no longer exists.
        """
        task_id = self._task_id()
        if not task_id:
            return False
        current_task = await TaskSignature.get_safe(task_id)
        if current_task is None:
            return False
        await current_task.activate_error(self.message)
        return True

    async def remove_task(
        self, with_success: bool = True, with_error: bool = True
    ) -> TaskSignature | None:
        """Remove the signature (if it still exists), forwarding both flags."""
        task_id = self._task_id()
        if task_id:
            signature = await TaskSignature.get_safe(task_id)
            if signature:
                await signature.remove(with_error, with_success)

    async def should_run_task(self) -> bool:
        """Return whether the task may run now.

        Returns False if the signature is gone. If ``signature.should_run()`` is
        false, records ACTIVE as ``last_status``, calls ``handle_inactive_task``
        and returns False. Runs without a signature key always return True.
        """
        task_id = self._task_id()
        if task_id:
            signature = await TaskSignature.get_safe(task_id)
            if signature is None:
                return False
            should_task_run = await signature.should_run()
            if should_task_run:
                return True
            await signature.task_status.aupdate(last_status=SignatureStatus.ACTIVE)
            await signature.handle_inactive_task(self.message)
            return False
        return True
|
|
File without changes
|
|
File without changes
|