mageflow 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. mageflow/__init__.py +30 -0
  2. mageflow/callbacks.py +72 -0
  3. mageflow/chain/__init__.py +0 -0
  4. mageflow/chain/consts.py +8 -0
  5. mageflow/chain/creator.py +73 -0
  6. mageflow/chain/messages.py +9 -0
  7. mageflow/chain/model.py +61 -0
  8. mageflow/chain/workflows.py +65 -0
  9. mageflow/client.py +140 -0
  10. mageflow/errors.py +22 -0
  11. mageflow/init.py +53 -0
  12. mageflow/invokers/__init__.py +0 -0
  13. mageflow/invokers/base.py +34 -0
  14. mageflow/invokers/hatchet.py +82 -0
  15. mageflow/models/__init__.py +0 -0
  16. mageflow/models/message.py +6 -0
  17. mageflow/signature/__init__.py +0 -0
  18. mageflow/signature/consts.py +3 -0
  19. mageflow/signature/creator.py +69 -0
  20. mageflow/signature/model.py +319 -0
  21. mageflow/signature/status.py +25 -0
  22. mageflow/signature/types.py +6 -0
  23. mageflow/startup.py +65 -0
  24. mageflow/swarm/__init__.py +0 -0
  25. mageflow/swarm/consts.py +12 -0
  26. mageflow/swarm/creator.py +34 -0
  27. mageflow/swarm/messages.py +7 -0
  28. mageflow/swarm/model.py +260 -0
  29. mageflow/swarm/workflows.py +120 -0
  30. mageflow/task/__init__.py +0 -0
  31. mageflow/task/model.py +19 -0
  32. mageflow/typing_support.py +8 -0
  33. mageflow/utils/__init__.py +0 -0
  34. mageflow/utils/models.py +19 -0
  35. mageflow/utils/pythonic.py +21 -0
  36. mageflow/visualizer/__init__.py +0 -0
  37. mageflow/visualizer/app.py +221 -0
  38. mageflow/visualizer/assets/cytoscape_styles.py +63 -0
  39. mageflow/visualizer/assets/styles.css +143 -0
  40. mageflow/visualizer/builder.py +497 -0
  41. mageflow/visualizer/data.py +65 -0
  42. mageflow/visualizer/utils.py +72 -0
  43. mageflow/workflows.py +128 -0
  44. mageflow-0.0.1.dist-info/METADATA +164 -0
  45. mageflow-0.0.1.dist-info/RECORD +48 -0
  46. mageflow-0.0.1.dist-info/WHEEL +4 -0
  47. mageflow-0.0.1.dist-info/entry_points.txt +3 -0
  48. mageflow-0.0.1.dist-info/licenses/LICENSE +21 -0
mageflow/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from mageflow.callbacks import register_task, handle_task_callback
2
+ from mageflow.chain.creator import chain
3
+ from mageflow.client import Mageflow
4
+ from mageflow.init import init_mageflow_hatchet_tasks
5
+ from mageflow.signature.creator import (
6
+ sign,
7
+ load_signature,
8
+ resume_task,
9
+ lock_task,
10
+ resume,
11
+ pause,
12
+ )
13
+ from mageflow.signature.status import TaskStatus
14
+ from mageflow.swarm.creator import swarm
15
+
16
+
17
+ __all__ = [
18
+ "load_signature",
19
+ "resume_task",
20
+ "lock_task",
21
+ "resume",
22
+ "pause",
23
+ "sign",
24
+ "init_mageflow_hatchet_tasks",
25
+ "register_task",
26
+ "handle_task_callback",
27
+ "Mageflow",
28
+ "chain",
29
+ "swarm",
30
+ ]
mageflow/callbacks.py ADDED
@@ -0,0 +1,72 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ from enum import Enum
5
+ from typing import Any
6
+
7
+ from hatchet_sdk import Context
8
+ from hatchet_sdk.runnables.types import EmptyModel
9
+ from pydantic import BaseModel
10
+
11
+ from mageflow.invokers.hatchet import HatchetInvoker
12
+ from mageflow.utils.pythonic import flexible_call
13
+
14
+
15
class AcceptParams(Enum):
    """Selects which arguments a wrapped task function is called with.

    - ``JUST_MESSAGE``: only the input message.
    - ``NO_CTX``: the message plus any extra positional/keyword arguments,
      without the Hatchet context.
    - ``ALL``: the message, the Hatchet context, and any extra arguments.
    """

    JUST_MESSAGE = 1
    NO_CTX = 2
    ALL = 3
19
+
20
+
21
class HatchetResult(BaseModel):
    """Envelope around a task function's return value.

    The value is JSON-dumped through this model before being handed to the
    task's success callbacks.
    """

    hatchet_results: Any
23
+
24
+
25
def handle_task_callback(
    expected_params: AcceptParams = AcceptParams.NO_CTX, wrap_res: bool = True
):
    """Build a decorator that runs an async task function under mageflow control.

    The produced wrapper:

    1. Creates a ``HatchetInvoker`` for the incoming message/context. If
       ``should_run_task()`` is false, it asks Hatchet to cancel the run.
    2. Otherwise it calls ``start_task()`` and then invokes ``func`` with the
       arguments selected by ``expected_params``.
    3. On any exception (including cancellation) it fires the error callbacks,
       removes the task signature with ``with_error=False``, and re-raises.
    4. On success it JSON-dumps the result through ``HatchetResult``, fires the
       success callbacks with the dumped value, and removes the task signature
       with ``with_success=False``.

    Args:
        expected_params: which arguments ``func`` receives (see ``AcceptParams``).
        wrap_res: return the ``HatchetResult`` envelope instead of the raw result.
    """

    def task_decorator(func):
        @functools.wraps(func)
        async def wrapper(message: EmptyModel, ctx: Context, *args, **kwargs):
            invoker = HatchetInvoker(message, ctx)
            allowed = await invoker.should_run_task()
            if not allowed:
                await ctx.aio_cancel()
                await asyncio.sleep(10)
                # Reaching this point means the cancellation did not take effect.
                return {"Error": "Task should have been canceled"}

            try:
                await invoker.start_task()
                if expected_params == AcceptParams.JUST_MESSAGE:
                    call_args, call_kwargs = (message,), {}
                elif expected_params == AcceptParams.NO_CTX:
                    call_args, call_kwargs = (message, *args), kwargs
                else:
                    call_args, call_kwargs = (message, ctx, *args), kwargs
                outcome = await flexible_call(func, *call_args, **call_kwargs)
            except (Exception, asyncio.CancelledError):
                await invoker.run_error()
                await invoker.remove_task(with_error=False)
                raise

            # Success path (the except branch above always re-raises).
            envelope = HatchetResult(hatchet_results=outcome)
            serialized = envelope.model_dump(mode="json")
            await invoker.run_success(serialized["hatchet_results"])
            await invoker.remove_task(with_success=False)
            return envelope if wrap_res else outcome

        # Expose the original function's signature to introspection.
        wrapper.__signature__ = inspect.signature(func)
        return wrapper

    return task_decorator
63
+
64
+
65
def register_task(register_name: str):
    """Return a decorator that records a task in mageflow's registry.

    The decorated object is appended to ``mageflow.startup.REGISTERED_TASKS``
    as ``(obj, register_name)`` and returned unchanged.
    """
    # Imported lazily, presumably to avoid an import cycle with mageflow.startup.
    from mageflow.startup import REGISTERED_TASKS

    def _record(target):
        entry = (target, register_name)
        REGISTERED_TASKS.append(entry)
        return target

    return _record
File without changes
@@ -0,0 +1,8 @@
1
+ # Params
2
+ from mageflow.signature.consts import MAGEFLOW_TASK_INITIALS
3
+
4
# Key in a chain callback's task data that holds the owning chain signature's key.
CHAIN_TASK_ID_NAME = "chain_task_id"

# Names of mageflow's internal chain callback tasks (share the mageflow prefix).
ON_CHAIN_ERROR = MAGEFLOW_TASK_INITIALS + "on_chain_error"
ON_CHAIN_END = MAGEFLOW_TASK_INITIALS + "on_chain_done"
@@ -0,0 +1,73 @@
1
+ import asyncio
2
+
3
+ from mageflow.chain.consts import ON_CHAIN_END, ON_CHAIN_ERROR
4
+ from mageflow.chain.messages import ChainSuccessTaskCommandMessage
5
+ from mageflow.chain.model import ChainTaskSignature
6
+ from mageflow.signature.creator import (
7
+ TaskSignatureConvertible,
8
+ resolve_signature_key,
9
+ )
10
+ from mageflow.signature.model import (
11
+ TaskIdentifierType,
12
+ TaskSignature,
13
+ TaskInputType,
14
+ )
15
+
16
+
17
async def chain(
    tasks: list[TaskSignatureConvertible],
    name: str | None = None,
    error: TaskInputType = None,
    success: TaskInputType = None,
) -> ChainTaskSignature:
    """Create and persist a chain that runs ``tasks`` one after another.

    A ``ChainTaskSignature`` is saved to track the whole chain. Each task is then
    linked to the next through success callbacks. The last task is linked to an
    internal ``ON_CHAIN_END`` task, and every task gets its own copy of an
    internal ``ON_CHAIN_ERROR`` task.

    Args:
        tasks: at least two tasks (signatures or anything resolvable to one).
        name: chain name; defaults to the first task's name.
        error: callback fired if the chain fails.
        success: callback fired when the chain completes.

    Returns:
        The saved chain signature.

    Raises:
        ValueError: if fewer than two tasks are given.
    """
    # Validate before resolving or saving anything, so an invalid call does not
    # leave an orphan chain signature behind (and an empty list does not fail
    # with an IndexError below).
    if len(tasks) < 2:
        raise ValueError(
            "Chained tasks must contain at least two tasks. "
            "If you want to run a single task, use `create_workflow` instead."
        )

    tasks = [await resolve_signature_key(task) for task in tasks]

    # The chain signature lives until the chain's end/error callback removes it.
    first_task = tasks[0]
    chain_task_signature = ChainTaskSignature(
        task_name=f"chain-task:{name or first_task.task_name}",
        success_callbacks=[success] if success else [],
        error_callbacks=[error] if error else [],
        tasks=tasks,
    )
    await chain_task_signature.save()

    callback_kwargs = dict(chain_task_id=chain_task_signature.key)
    # NOTE(review): the error callback reuses ChainSuccessTaskCommandMessage as its
    # validator while chain_error_task accepts EmptyModel - confirm this is intended.
    on_chain_error = TaskSignature(
        task_name=ON_CHAIN_ERROR,
        task_identifiers=callback_kwargs,
        model_validators=ChainSuccessTaskCommandMessage,
    )
    on_chain_success = TaskSignature(
        task_name=ON_CHAIN_END,
        task_identifiers=callback_kwargs,
        model_validators=ChainSuccessTaskCommandMessage,
    )
    await _chain_task_to_previous_success(tasks, on_chain_error, on_chain_success)
    return chain_task_signature
48
+
49
+
50
async def _chain_task_to_previous_success(
    tasks: list[TaskSignature], error: TaskSignature, success: TaskSignature
) -> TaskIdentifierType:
    """Link every task to its successor via success callbacks.

    Task ``i`` gets task ``i + 1`` as its success callback, and the last task
    gets ``success``. Each task also gets its own duplicate of ``error`` as its
    error callback. ``success`` and the error duplicates are saved first.

    Returns:
        The result of ``add_callbacks`` for the first task.

    Raises:
        ValueError: if fewer than two tasks are given.
    """
    if len(tasks) < 2:
        raise ValueError(
            "Chained tasks must contain at least two tasks. "
            "If you want to run a single task, use `create_workflow` instead."
        )

    error_copies = await error.duplicate_many(len(tasks))
    await asyncio.gather(success.save(), *[dup.save() for dup in error_copies])

    successors = [*tasks[1:], success]
    linked = await asyncio.gather(
        *[
            task.add_callbacks(success=[next_task], errors=[error_copy])
            for task, next_task, error_copy in zip(tasks, successors, error_copies)
        ]
    )
    return linked[0]
@@ -0,0 +1,9 @@
1
+ from typing import Any, Annotated
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from mageflow.models.message import ReturnValue
6
+
7
+
8
class ChainSuccessTaskCommandMessage(BaseModel):
    """Input of the internal chain-end task.

    ``chain_results`` is marked with ``ReturnValue`` and is forwarded to the
    chain's success callbacks by ``chain_end_task``.
    """

    chain_results: Annotated[Any, ReturnValue()]
@@ -0,0 +1,61 @@
1
+ import asyncio
2
+
3
+ from pydantic import field_validator, Field
4
+
5
+ from mageflow.errors import MissingSignatureError
6
+ from mageflow.signature.model import TaskSignature, TaskIdentifierType
7
+ from mageflow.signature.status import SignatureStatus
8
+
9
+
10
class ChainTaskSignature(TaskSignature):
    """Signature for a chain: an ordered list of member task keys.

    Status operations on the chain (change/suspend/interrupt/resume) are also
    applied to every member. Per-member failures are collected by
    ``asyncio.gather(..., return_exceptions=True)`` rather than raised.
    """

    tasks: list[TaskIdentifierType] = Field(default_factory=list)

    @field_validator("tasks", mode="before")
    @classmethod
    def validate_tasks(cls, v: list[TaskSignature]):
        """Normalize every incoming entry through ``validate_task_key``."""
        return list(map(cls.validate_task_key, v))

    async def workflow(self, **task_additional_params):
        """Start the chain by starting its first member task.

        Raises:
            MissingSignatureError: if the first member cannot be loaded.
        """
        head = await TaskSignature.get_safe(self.tasks[0])
        if head is None:
            raise MissingSignatureError(f"First task from chain {self.key} not found")
        return await head.workflow(**task_additional_params)

    async def delete_chain_tasks(self, with_errors=True, with_success=True):
        """Remove every member signature that can still be loaded."""
        loaded = await asyncio.gather(
            *(TaskSignature.get_safe(task_key) for task_key in self.tasks),
            return_exceptions=True,
        )
        removals = [
            member.remove(with_errors, with_success)
            for member in loaded
            if isinstance(member, TaskSignature)
        ]
        await asyncio.gather(*removals)

    async def change_status(self, status: SignatureStatus):
        """Set ``status`` on the chain and on every member concurrently."""
        member_updates = [
            TaskSignature.safe_change_status(task_key, status)
            for task_key in self.tasks
        ]
        own_update = super().change_status(status)
        await asyncio.gather(own_update, *member_updates, return_exceptions=True)

    async def suspend(self):
        """Suspend every member, then mark the chain SUSPENDED."""
        await asyncio.gather(
            *(TaskSignature.suspend_from_key(task_key) for task_key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.SUSPENDED)

    async def interrupt(self):
        """Interrupt every member, then mark the chain INTERRUPTED."""
        await asyncio.gather(
            *(TaskSignature.interrupt_from_key(task_key) for task_key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(SignatureStatus.INTERRUPTED)

    async def resume(self):
        """Resume every member, then restore the chain's last recorded status."""
        await asyncio.gather(
            *(TaskSignature.resume_from_key(task_key) for task_key in self.tasks),
            return_exceptions=True,
        )
        await super().change_status(self.task_status.last_status)
@@ -0,0 +1,65 @@
1
+ import asyncio
2
+
3
+ from hatchet_sdk import Context
4
+ from hatchet_sdk.runnables.types import EmptyModel
5
+
6
+ from mageflow.chain.consts import CHAIN_TASK_ID_NAME
7
+ from mageflow.chain.messages import ChainSuccessTaskCommandMessage
8
+ from mageflow.chain.model import ChainTaskSignature
9
+ from mageflow.invokers.hatchet import HatchetInvoker
10
+ from mageflow.signature.consts import TASK_ID_PARAM_NAME
11
+ from mageflow.signature.model import TaskSignature
12
+
13
+
14
async def chain_end_task(msg: ChainSuccessTaskCommandMessage, ctx: Context) -> None:
    """Internal task run after the last task of a chain succeeds.

    Steps:
    1. Loads the chain signature and this callback's own signature using the
       keys found in the task data.
    2. Fires the chain's success callbacks with ``msg.chain_results``.
    3. Removes both signatures. The chain is removed with
       ``with_success=False`` because its success callbacks were just activated.

    Raises:
        MissingSignatureError: if the chain signature cannot be loaded.

    Any exception is logged to the Hatchet context and re-raised.
    """
    from mageflow.errors import MissingSignatureError

    try:
        task_data = HatchetInvoker(msg, ctx).task_ctx
        chain_task_id = task_data[CHAIN_TASK_ID_NAME]
        current_task_id = task_data[TASK_ID_PARAM_NAME]

        chain_task_signature, current_task = await asyncio.gather(
            ChainTaskSignature.get_safe(chain_task_id),
            TaskSignature.get_safe(current_task_id),
        )
        if chain_task_signature is None:
            raise MissingSignatureError(f"Chain signature {chain_task_id} not found")
        ctx.log(f"Chain task done {chain_task_signature.task_name}")

        # Activate the success callbacks before deleting anything, so a
        # failure during cleanup cannot block the rest of the workflow.
        await chain_task_signature.activate_success(msg.chain_results)
        ctx.log(f"Chain task success {chain_task_signature.task_name}")

        # Remove the chain and, if it still exists, this callback's signature.
        removals = [chain_task_signature.remove(with_success=False)]
        if current_task is not None:
            removals.append(current_task.remove())
        await asyncio.gather(*removals)
    except Exception as e:
        ctx.log(f"MAJOR - infrastructure error in chain end task: {e}")
        raise
37
+
38
+
39
+ # This task needs to be added as a workflow
40
async def chain_error_task(msg: EmptyModel, ctx: Context) -> None:
    """Internal task run when a task inside a chain fails.

    Steps:
    1. Loads the chain signature and this callback's own signature.
    2. Fires the chain's error callbacks with ``msg``.
    3. Removes every member of the chain.
    4. Removes the chain signature (``with_error=False``, since its error
       callbacks were just activated) and this callback's signature.

    Raises:
        MissingSignatureError: if the chain signature cannot be loaded.

    Any exception is logged to the Hatchet context and re-raised.
    """
    from mageflow.errors import MissingSignatureError

    try:
        task_data = HatchetInvoker(msg, ctx).task_ctx
        chain_task_id = task_data[CHAIN_TASK_ID_NAME]
        current_task_id = task_data[TASK_ID_PARAM_NAME]
        chain_packed_task, current_task = await asyncio.gather(
            ChainTaskSignature.get_safe(chain_task_id),
            TaskSignature.get_safe(current_task_id),
        )
        if chain_packed_task is None:
            raise MissingSignatureError(f"Chain signature {chain_task_id} not found")
        ctx.log(
            f"Chain task failed {chain_packed_task.task_name} on task id - {current_task_id}"
        )

        # Fire the chain's error callbacks before any cleanup.
        await chain_packed_task.activate_error(msg)
        ctx.log(f"Chain task error {chain_packed_task.task_name}")

        # Remove the chain members, the chain itself and this callback.
        await chain_packed_task.delete_chain_tasks()
        removals = [chain_packed_task.remove(with_error=False)]
        if current_task is not None:
            removals.append(current_task.remove())
        await asyncio.gather(*removals)
        ctx.log(f"Clean redis from chain tasks {chain_packed_task.task_name}")
    except Exception as e:
        ctx.log(f"MAJOR - infrastructure error in chain error task: {e}")
        raise
mageflow/client.py ADDED
@@ -0,0 +1,140 @@
1
+ import functools
2
+ import os
3
+ from typing import TypeVar, Any, overload, Unpack
4
+
5
+ import redis
6
+ from hatchet_sdk import Hatchet, Worker
7
+ from hatchet_sdk.runnables.workflow import BaseWorkflow
8
+ from hatchet_sdk.worker.worker import LifespanFn
9
+ from redis.asyncio import Redis
10
+
11
+ from mageflow.callbacks import AcceptParams, register_task, handle_task_callback
12
+ from mageflow.chain.creator import chain
13
+ from mageflow.init import init_mageflow_hatchet_tasks
14
+ from mageflow.signature.creator import sign, TaskSignatureConvertible
15
+ from mageflow.signature.model import TaskSignature, TaskInputType
16
+ from mageflow.signature.types import HatchetTaskType
17
+ from mageflow.startup import (
18
+ lifespan_initialize,
19
+ mageflow_config,
20
+ init_mageflow,
21
+ teardown_mageflow,
22
+ )
23
+ from mageflow.swarm.creator import swarm, SignatureOptions
24
+
25
+
26
async def merge_lifespan(original_lifespan: LifespanFn):
    """Wrap a user worker lifespan with mageflow's init and teardown.

    ``init_mageflow`` runs first, then every value yielded by the user's
    lifespan is re-yielded. ``teardown_mageflow`` runs in a ``finally``, so it
    also runs if the wrapped lifespan raises or this generator is closed early.
    """
    await init_mageflow()
    try:
        async for res in original_lifespan():
            yield res
    finally:
        await teardown_mageflow()
31
+
32
+
33
class HatchetMageflow(Hatchet):
    """Hatchet client whose tasks run under mageflow's signature lifecycle.

    Tasks declared with ``task``/``durable_task`` are wrapped by
    ``handle_task_callback`` and recorded with ``register_task``. ``worker``
    also attaches mageflow's internal chain/swarm workflows and lifespan.
    """

    def __init__(
        self,
        hatchet: Hatchet,
        redis_client: Redis,
        param_config: AcceptParams = AcceptParams.NO_CTX,
    ):
        super().__init__(client=hatchet._client)
        self.hatchet = hatchet
        self.redis = redis_client
        self.param_config = param_config

    def _mageflow_decorator(self, hatchet_task, name: str | None):
        """Build the decorator shared by ``task`` and ``durable_task``.

        The function is wrapped with ``handle_task_callback`` and passed to the
        Hatchet decorator ``hatchet_task``. The resulting workflow is then
        registered under ``name``, or under the function's own name if
        ``name`` is not given.
        """

        def decorator(func):
            wrapped = handle_task_callback(self.param_config)(func)
            workflow = hatchet_task(wrapped)
            # functools.wraps in handle_task_callback preserves ``__name__``.
            return register_task(name or wrapped.__name__)(workflow)

        return decorator

    def task(self, *, name: str | None = None, **kwargs):
        """Hatchet ``task`` decorator that also applies mageflow handling."""
        return self._mageflow_decorator(super().task(name=name, **kwargs), name)

    def durable_task(self, *, name: str | None = None, **kwargs):
        """Hatchet ``durable_task`` decorator that also applies mageflow handling."""
        return self._mageflow_decorator(
            super().durable_task(name=name, **kwargs), name
        )

    def worker(
        self,
        *args,
        workflows: list[BaseWorkflow[Any]] | None = None,
        lifespan: LifespanFn | None = None,
        **kwargs,
    ) -> Worker:
        """Create a worker that also serves mageflow's internal workflows.

        ``workflows`` may be omitted, and the caller's list is never mutated.
        Without a ``lifespan``, mageflow's own lifespan is used. Otherwise
        mageflow's init/teardown is wrapped around the given one.
        """
        mageflow_flows = init_mageflow_hatchet_tasks(self.hatchet)
        # Build a fresh list: the default is None (``+=`` would raise), and
        # ``+=`` on a provided list would extend the caller's list in place.
        workflows = [*(workflows or []), *mageflow_flows]
        if lifespan is None:
            lifespan = lifespan_initialize
        else:
            lifespan = functools.partial(merge_lifespan, lifespan)

        return super().worker(*args, workflows=workflows, lifespan=lifespan, **kwargs)

    async def sign(self, task: str | HatchetTaskType, **options: Any) -> TaskSignature:
        """Create a task signature (delegates to ``mageflow.signature.creator.sign``)."""
        return await sign(task, **options)

    async def chain(
        self,
        tasks: list[TaskSignatureConvertible],
        name: str | None = None,
        error: TaskInputType = None,
        success: TaskInputType = None,
    ):
        """Create a chain (delegates to ``mageflow.chain.creator.chain``)."""
        return await chain(tasks, name, error, success)

    async def swarm(
        self,
        tasks: list[TaskSignatureConvertible] | None = None,
        task_name: str | None = None,
        **kwargs: Unpack[SignatureOptions],
    ):
        """Create a swarm (delegates to ``mageflow.swarm.creator.swarm``)."""
        return await swarm(tasks, task_name, **kwargs)
109
+
110
+
111
T = TypeVar("T")


def Mageflow(
    hatchet_client: Hatchet | None = None,
    redis_client: Redis | str | None = None,
    param_config: AcceptParams = AcceptParams.NO_CTX,
) -> HatchetMageflow:
    """Create a ``HatchetMageflow`` client and configure mageflow's globals.

    Args:
        hatchet_client: Hatchet client to wrap; a default ``Hatchet()`` if omitted.
        redis_client: a Redis client, or a Redis URL. If omitted, the
            ``REDIS_URL`` environment variable is used.
        param_config: which arguments wrapped tasks receive.

    Raises:
        ValueError: if no Redis client/URL is given and ``REDIS_URL`` is unset.
    """
    # Resolve Redis first so a missing URL fails clearly, before any global
    # configuration is touched (from_url(None) fails obscurely).
    if redis_client is None:
        redis_client = os.getenv("REDIS_URL")
        if not redis_client:
            raise ValueError(
                "No redis client was given and the REDIS_URL environment "
                "variable is not set"
            )
    if isinstance(redis_client, str):
        redis_client = redis.asyncio.from_url(redis_client, decode_responses=True)

    if hatchet_client is None:
        hatchet_client = Hatchet()

    # Separate Hatchet client with an empty namespace, used by mageflow when
    # creating workflows.
    config = hatchet_client._client.config.model_copy(deep=True)
    config.namespace = ""
    hatchet_caller = Hatchet(config=config, debug=hatchet_client._client.debug)

    mageflow_config.hatchet_client = hatchet_caller
    mageflow_config.redis_client = redis_client
    return HatchetMageflow(hatchet_client, redis_client, param_config)
mageflow/errors.py ADDED
@@ -0,0 +1,22 @@
1
class MageflowError(Exception):
    """Root of mageflow's exception hierarchy."""


class MissingSignatureError(MageflowError):
    """A task signature that was expected to exist could not be loaded."""


class MissingSwarmItemError(MissingSignatureError):
    """Missing-signature error specific to swarm items."""


class SwarmError(MageflowError):
    """Root of swarm-related errors."""


class TooManyTasksError(SwarmError, RuntimeError):
    """Swarm error signalling too many tasks; also a ``RuntimeError``."""


class SwarmIsCanceledError(SwarmError, RuntimeError):
    """Swarm error signalling the swarm is canceled; also a ``RuntimeError``."""
mageflow/init.py ADDED
@@ -0,0 +1,53 @@
1
+ from hatchet_sdk import Hatchet
2
+
3
+ from mageflow.callbacks import register_task
4
+ from mageflow.chain.consts import ON_CHAIN_END, ON_CHAIN_ERROR
5
+ from mageflow.chain.messages import ChainSuccessTaskCommandMessage
6
+ from mageflow.chain.workflows import chain_end_task, chain_error_task
7
+ from mageflow.swarm.consts import ON_SWARM_ERROR, ON_SWARM_END, ON_SWARM_START
8
+ from mageflow.swarm.messages import SwarmResultsMessage
9
+ from mageflow.swarm.workflows import (
10
+ swarm_item_failed,
11
+ swarm_item_done,
12
+ swarm_start_tasks,
13
+ )
14
+
15
+
16
def init_mageflow_hatchet_tasks(hatchet: Hatchet):
    """Declare and register mageflow's internal chain and swarm tasks.

    Each internal handler is declared as a Hatchet task under its mageflow name
    and recorded with ``register_task`` under the same name. Tasks are
    registered in this order: chain end, chain error, swarm start, swarm end,
    swarm error.

    Returns:
        The created workflows, ordered: chain error, chain end, swarm start,
        swarm end, swarm error.
    """

    def declare(task_name, handler, **task_options):
        # Declare the Hatchet task, bind the handler, then register it by name.
        hatchet_workflow = hatchet.task(name=task_name, **task_options)(handler)
        return register_task(task_name)(hatchet_workflow)

    chain_end = declare(
        ON_CHAIN_END, chain_end_task, input_validator=ChainSuccessTaskCommandMessage
    )
    chain_error = declare(ON_CHAIN_ERROR, chain_error_task)
    swarm_start = declare(ON_SWARM_START, swarm_start_tasks)
    swarm_end = declare(
        ON_SWARM_END, swarm_item_done, input_validator=SwarmResultsMessage
    )
    swarm_error = declare(ON_SWARM_ERROR, swarm_item_failed)

    return [chain_error, chain_end, swarm_start, swarm_end, swarm_error]
File without changes
@@ -0,0 +1,34 @@
1
+ import abc
2
+ from abc import ABC
3
+ from typing import Any
4
+
5
+ from mageflow.signature.model import TaskSignature
6
+
7
+
8
class BaseInvoker(ABC):
    """Interface between a task runtime and mageflow task signatures."""

    @property
    @abc.abstractmethod
    def task_ctx(self) -> dict:
        """Return the mageflow data attached to the current task run."""

    @abc.abstractmethod
    async def start_task(self):
        """Hook called right before the task function executes."""

    @abc.abstractmethod
    async def run_success(self, result: Any) -> bool:
        """Fire success callbacks with ``result``; return whether any ran."""

    @abc.abstractmethod
    async def run_error(self) -> bool:
        """Fire error callbacks; return whether any ran."""

    @abc.abstractmethod
    async def remove_task(
        self, with_success: bool = True, with_error: bool = True
    ) -> TaskSignature | None:
        """Remove the current task's signature; flags are forwarded to its removal."""

    @abc.abstractmethod
    async def should_run_task(self) -> bool:
        """Return False when the task function must not be executed."""
@@ -0,0 +1,82 @@
1
+ import asyncio
2
+ from typing import Any
3
+
4
+ from hatchet_sdk import Context
5
+ from hatchet_sdk.runnables.contextvars import ctx_additional_metadata
6
+ from pydantic import BaseModel
7
+
8
+ from mageflow.invokers.base import BaseInvoker
9
+ from mageflow.signature.consts import TASK_ID_PARAM_NAME
10
+ from mageflow.signature.model import TaskSignature
11
+ from mageflow.signature.status import SignatureStatus
12
+ from mageflow.workflows import TASK_DATA_PARAM_NAME
13
+
14
+
15
class HatchetInvoker(BaseInvoker):
    """``BaseInvoker`` backed by a Hatchet ``Context``.

    Mageflow's per-task data is read from the run's additional metadata under
    ``TASK_DATA_PARAM_NAME``. The signature key of the running task, if any,
    is stored there under ``TASK_ID_PARAM_NAME``.
    """

    def __init__(self, message: BaseModel, ctx: Context):
        self.message = message
        self.task_data = ctx.additional_metadata.get(TASK_DATA_PARAM_NAME, {})
        self.workflow_id = ctx.workflow_id
        # Strip mageflow's task data from Hatchet's metadata context variable
        # (presumably so it is not propagated to runs spawned from here - confirm).
        hatchet_ctx_metadata = ctx_additional_metadata.get() or {}
        hatchet_ctx_metadata.pop(TASK_DATA_PARAM_NAME, None)
        ctx_additional_metadata.set(hatchet_ctx_metadata)

    def _task_id(self):
        """Return the running task's signature key, or None if absent."""
        return self.task_data.get(TASK_ID_PARAM_NAME, None)

    @property
    def task_ctx(self) -> dict:
        return self.task_data

    async def start_task(self):
        """Mark the signature ACTIVE and record the Hatchet workflow id, under its lock."""
        task_id = self._task_id()
        if task_id:
            async with TaskSignature.lock_from_key(task_id) as signature:
                await signature.change_status(SignatureStatus.ACTIVE)
                await signature.task_status.aupdate(worker_task_id=self.workflow_id)

    async def run_success(self, result: Any) -> bool:
        """Fire the signature's success callbacks with ``result``.

        Returns False when there is no task id or the signature no longer exists.
        """
        task_id = self._task_id()
        if not task_id:
            return False
        current_task = await TaskSignature.get_safe(task_id)
        if current_task is None:
            return False
        await current_task.activate_success(result)
        return True

    async def run_error(self) -> bool:
        """Fire the signature's error callbacks with the original message.

        Returns False when there is no task id or the signature no longer exists.
        """
        task_id = self._task_id()
        if not task_id:
            return False
        current_task = await TaskSignature.get_safe(task_id)
        if current_task is None:
            return False
        await current_task.activate_error(self.message)
        return True

    async def remove_task(
        self, with_success: bool = True, with_error: bool = True
    ) -> TaskSignature | None:
        """Remove the running task's signature and return it (None if absent)."""
        task_id = self._task_id()
        if not task_id:
            return None
        signature = await TaskSignature.get_safe(task_id)
        if signature is not None:
            await signature.remove(with_error, with_success)
        return signature

    async def should_run_task(self) -> bool:
        """Decide whether the task function may run.

        Tasks without a signature id always run. A missing signature blocks the
        run. If the signature says not to run, ``last_status`` is set to ACTIVE,
        ``handle_inactive_task`` is called, and False is returned.
        """
        task_id = self._task_id()
        if not task_id:
            return True
        signature = await TaskSignature.get_safe(task_id)
        if signature is None:
            return False
        if await signature.should_run():
            return True
        await signature.task_status.aupdate(last_status=SignatureStatus.ACTIVE)
        await signature.handle_inactive_task(self.message)
        return False
File without changes
@@ -0,0 +1,6 @@
1
+ import dataclasses
2
+
3
+
4
@dataclasses.dataclass
class ReturnValue:
    """Marker placed in ``Annotated[...]`` metadata to tag a model field.

    It is used on ``ChainSuccessTaskCommandMessage.chain_results``. The marker
    itself carries no data.
    """
File without changes
@@ -0,0 +1,3 @@
1
# Key in mageflow's task data holding the running task's signature key.
TASK_ID_PARAM_NAME = "task_id"

# Prefix shared by the names of mageflow's internal Hatchet tasks.
MAGEFLOW_TASK_INITIALS = "mageflow_"