ttflow 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ttflow/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .core.context import Context
2
+ from .core.run_context import RunContext
3
+ from .core.workflow import WorkflowRunResult
4
+ from .ttflow import Client, event_trigger, every_trigger, setup, state_trigger
5
+
6
# Public API of the package: the names re-exported by the imports above.
__all__ = [
    # entry point / client
    "Client",
    # per-run objects handed to workflows
    "Context",
    "RunContext",
    "WorkflowRunResult",
    # decorators and setup helpers
    "event_trigger",
    "every_trigger",
    "setup",
    "state_trigger",
]
ttflow/constants.py ADDED
@@ -0,0 +1,8 @@
1
# System keys reserved in the StateRepository.
STATE_KEY_EVENTS: str = "_events"
STATE_KEY_SYSTEM_LOCK: str = "_system_lock"
STATE_KEY_WORKFLOWS_HASH: str = "workflows_hash"

# Prefixes prepended to generated state keys.
STATE_PREFIX_LOGS: str = "_logs"
STATE_PREFIX_RUN_STATE: str = "_run_state"
@@ -0,0 +1 @@
1
+ from .event import _enque_event # noqa
ttflow/core/context.py ADDED
@@ -0,0 +1,28 @@
1
+ import uuid
2
+
3
+
4
+ class Context:
5
+ """
6
+ ワークフローの実行時に渡されるコンテキスト
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ workflow_name: str,
12
+ run_id: str | None = None,
13
+ paused_info: dict | None = None,
14
+ ):
15
+ self.workflow_name = workflow_name
16
+ self.paused_info = (
17
+ paused_info # このrunがpauseからの再開だった場合、そのpause情報が入る
18
+ )
19
+ if run_id is None:
20
+ run_id = uuid.uuid4().hex
21
+ self.run_id = run_id
22
+ self.used_count = 0
23
+
24
+ def _use(self):
25
+ self.used_count += 1
26
+
27
+ def get_run_state_token(self):
28
+ return f"{self.run_id}:{self.used_count}"
ttflow/core/event.py ADDED
@@ -0,0 +1,71 @@
1
+ from collections.abc import Generator
2
+ from dataclasses import asdict
3
+ from typing import Any
4
+
5
+ from dacite import from_dict
6
+
7
+ from ..constants import STATE_KEY_EVENTS
8
+ from .global_env import Event, Global
9
+
10
+ # _global.eventsはオンメモリのイベントキューである
11
+ # ttflowの一回の実行は、イベントキューが空になるまで続く
12
+
13
+
14
def _enque_event(
    g: Global,
    event_name: str,
    args: Any,
    process_immediately: bool = True,
) -> None:
    """Create an Event and put it on one of the two queues held by ``g``.

    Args:
        g: global environment that owns the queues.
        event_name: name of the new event.
        args: payload stored on the event.
        process_immediately: True appends to ``g.events`` so it is handled in
            the current run; False appends to ``g.events_for_next_run`` so it
            is handled in the next run.
    """
    target_queue = g.events if process_immediately else g.events_for_next_run
    target_queue.append(Event(event_name=event_name, args=args))
33
+
34
+
35
# A trigger is, in practice, just an event named "_trigger_<name>".
def _enque_trigger(g: Global, name: str, args: Any) -> None:
    """Enqueue the trigger event for ``name``; triggers always run in the current run."""
    trigger_event = f"_trigger_{name}"
    _enque_event(g, trigger_event, args, process_immediately=True)
39
+
40
+
41
def load_events_from_state(g: Global) -> None:
    """Append the events persisted in state to the in-memory queue ``g.events``."""
    g.events.extend(_read_events_from_state(g))


def _read_events_from_state(g: Global) -> list[Event]:
    """Deserialize the event dicts stored under STATE_KEY_EVENTS (empty list if absent)."""
    stored = g.state.read_state(STATE_KEY_EVENTS, default=[])
    restored: list[Event] = []
    for raw in stored:
        restored.append(from_dict(data_class=Event, data=raw))
    return restored


def flush_events_for_next_run_to_state(g: Global) -> None:
    """Serialize ``g.events_for_next_run`` and save it under STATE_KEY_EVENTS.

    The saved list replaces whatever was stored under that key before.
    """
    serialized = list(map(asdict, g.events_for_next_run))
    g.state.save_state(STATE_KEY_EVENTS, serialized)
54
+
55
+
56
def _pop_event(
    g: Global,
) -> Event | None:
    """Remove and return the oldest queued event, or None when the queue is empty."""
    if not g.events:
        return None
    return g.events.pop(0)


def iterate_events(
    g: Global,
) -> Generator[Event, None, None]:
    """Yield queued events in FIFO order until ``g.events`` is empty.

    The queue is re-checked after every yield, so events enqueued while an
    event is being handled are produced as well.
    """
    current = _pop_event(g)
    while current is not None:
        yield current
        current = _pop_event(g)
@@ -0,0 +1,44 @@
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from ..state_repository.buffer_cache_proxy import BufferCacheStateRepositoryProxy
6
+ from .trigger import Trigger
7
+
8
+
9
@dataclass
class Event:
    """A queued event: a name plus an arbitrary payload."""

    event_name: str  # compared against EventTrigger.event_name when dispatching
    args: Any  # event payload
13
+
14
+
15
class Workflow:
    """A registered workflow: the function to run and the trigger that starts it."""

    def __init__(self, trigger: Trigger, f: Callable):
        self.trigger = trigger
        self.f = f

    @property
    def name(self) -> str:
        """Workflow name, taken from the wrapped function's ``__name__``."""
        func = self.f
        return func.__name__

    @property
    def description(self) -> str | None:
        """Workflow description, taken from the wrapped function's docstring."""
        func = self.f
        return func.__doc__
27
+
28
+
29
class Global:
    """Environment shared by every workflow within a run.

    Holds the state repository proxy, the registered workflows, and two event
    queues: one drained during the current run, one deferred to the next run.
    """

    def __init__(self, state: BufferCacheStateRepositoryProxy):
        self.state = state
        # Workflows registered through the workflow decorator.
        self.workflows: list[Workflow] = []
        # In-memory event queue drained during the current run.
        self.events: list[Event] = []
        # Events enqueued during execution that are meant for the next run.
        self.events_for_next_run: list[Event] = []

    def purge_events(self) -> None:
        """Discard both queues by replacing them with fresh empty lists."""
        self.events, self.events_for_next_run = [], []
ttflow/core/pause.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..system_states.event_log import _get_event_logs
2
+ from ..system_states.run_state import _is_already_executed, _mark_as_executed
3
+ from .context import Context
4
+ from .global_env import Global
5
+
6
+ # ワークフローの中断
7
+ # ワークフローは中断時にPauseExceptionを投げる。
8
+ # その場合、stateに中断情報が保存される。
9
+ # 中断ワークフローが再実行される場合、run_idが同じになるので
10
+ # run_idに対して処理が冪等になっていれば、何度中断しても問題ない。
11
+
12
+
13
class PauseException(Exception):
    """Raised from inside a workflow to suspend it.

    Attributes:
        pause_id: identifier of the point at which the workflow paused.
    """

    def __init__(self, pause_id: str):
        super().__init__()
        self.pause_id = pause_id
17
+
18
+
19
def _wait_event(g: Global, c: Context, event_name: str) -> None:
    """Suspend the workflow until ``event_name`` has been logged after the pause.

    Each call advances the context's use counter; the resulting run-state
    token (``"<run_id>:<used_count>"``) is used as the pause id.

    Raises:
        PauseException: on a run that is not a resumption (``paused_info`` is
            None), or on a resumed run when the event log has no
            ``event_name`` entry newer than ``paused_info["timestamp"]``.
    """
    c._use()
    pause_id = c.get_run_state_token()

    if c.paused_info is None:
        # First time through: pause so that the pause information is recorded.
        raise PauseException(pause_id)

    # Resumed run: continue only when a matching event arrived after the pause.
    # NOTE(review): this check is not memoized through run_state and always
    # compares against the current paused_info; if paused_info holds the most
    # recent pause, a wait point already passed before a later pause may pause
    # again on replay. Confirm how the runner populates paused_info.
    newer_matches = [
        entry
        for entry in _get_event_logs(g)
        if entry["event_name"] == event_name
        and entry["timestamp"] > c.paused_info["timestamp"]
    ]
    if not newer_matches:
        raise PauseException(pause_id)
36
+
37
+
38
def _pause_once(g: Global, c: Context) -> None:
    """Pause exactly once at this point; the next replay continues unconditionally.

    On the first pass the current run-state token is marked as executed, with
    the value None. A PauseException carrying that same token is then raised.
    When the run is replayed, ``_is_already_executed`` finds the recorded entry
    and its stored value is returned instead of pausing again.

    Raises:
        PauseException: on the first pass through this point.
    """
    cache = _is_already_executed(g, c)
    if cache is not None:
        return cache.value
    # Capture the token once and use it both as the run-state key and as the
    # pause id, so the two cannot diverge if the context's counter moves.
    token = c.get_run_state_token()
    _mark_as_executed(g, c, token, None)
    raise PauseException(token)
@@ -0,0 +1,48 @@
1
+ from typing import Any
2
+
3
+ from ..system_states.logs import log
4
+ from .context import Context
5
+ from .event import _enque_event
6
+ from .global_env import Global
7
+ from .pause import _pause_once, _wait_event
8
+ from .state import add_list_state, get_state, set_state
9
+
10
+
11
class RunContext:
    """Object handed to a workflow function as its first argument.

    Bundles the shared Global environment with this run's Context. It exposes
    the state, logging, pause and event operations available to workflow code.
    """

    def __init__(self, _global: Global, _context: Context):
        self._global = _global
        self._context = _context

    def get_context_data(self) -> Context:
        """Return the Context of the current run."""
        return self._context

    def get_state(self, state_name: str, default: Any = None) -> Any:
        """Read state ``state_name`` (``default`` if missing); cached on re-execution."""
        return get_state(self._global, self._context, state_name, default)

    def set_state(self, state_name: str, value: Any) -> None:
        """Write state ``state_name``; does nothing on re-execution."""
        return set_state(self._global, self._context, state_name, value)

    def add_list_state(
        self,
        state_name: str,
        value: Any,
        max_length: int | None = None,
    ) -> None:
        """Append ``value`` to list state ``state_name``, keeping at most ``max_length`` items."""
        return add_list_state(
            self._global, self._context, state_name, value, max_length=max_length
        )

    def log(self, message: str) -> None:
        """Add ``message`` to this run's logs (see system_states.logs.log)."""
        return log(self._global, self._context, message)

    def wait_event(self, event_name: str) -> None:
        """Pause the workflow until an ``event_name`` event is issued."""
        _wait_event(self._global, self._context, event_name)

    def pause_once(self) -> None:
        """Pause once at this point; the next run resumes past it unconditionally."""
        _pause_once(self._global, self._context)

    def event(self, name: str, args: Any, process_immediately: bool = True) -> None:
        """Enqueue event ``name`` with payload ``args``.

        By default (``process_immediately=True``) the event is put on the
        in-memory queue and handled later in the *current* run. Pass
        ``process_immediately=False`` to defer it to the next run() instead.

        NOTE(review): unlike set_state, this call is not wrapped in
        _execute_once, so replaying a resumed run enqueues the event again —
        confirm this is intended.
        """
        _enque_event(
            self._global, name, args, process_immediately=process_immediately
        )
ttflow/core/state.py ADDED
@@ -0,0 +1,66 @@
1
+ from typing import Any
2
+
3
+ from ..errors import InvalidStateError
4
+ from ..system_states.run_state import _execute_once
5
+ from .context import Context
6
+ from .event import _enque_event
7
+ from .global_env import Global
8
+
9
+
10
# Write a state value; does nothing when the run is re-executed.
def set_state(g: Global, c: Context, state_name: str, value: Any) -> None:
    """Save ``value`` under ``state_name`` once per run position.

    The write is wrapped in ``_execute_once``. When the stored value actually
    changes, a ``state_changed_<state_name>`` event is enqueued with payload
    ``{"old": <previous>, "new": <value>}``.
    """

    @_execute_once(g, c)
    def a():
        previous = g.state.read_state(state_name)
        g.state.save_state(state_name, value)
        if previous != value:
            change = {"old": previous, "new": value}
            _enque_event(g, f"state_changed_{state_name}", change)

    return a()
25
+
26
+
27
def get_state(g: Global, c: Context, state_name: str, default: Any = None) -> Any:
    """Return state ``state_name`` (``default`` if missing).

    The read is wrapped in ``_execute_once`` so that on re-execution the
    cached value is returned.
    """

    @_execute_once(g, c)
    def a():
        stored = g.state.read_state(state_name, default=default)
        return stored

    return a()
35
+
36
+
37
def add_list_state(
    g: Global,
    c: Context,
    state_name: str,
    value: Any,
    max_length: int | None = None,
) -> None:
    """Append ``value`` to the list stored in state ``state_name``.

    Reads and writes go through get_state/set_state, which are both wrapped
    in ``_execute_once``.

    Args:
        max_length: when given, keep only the newest ``max_length`` items;
            0 (or a negative value) leaves an empty list. None keeps all items.

    Raises:
        InvalidStateError: if the stored value is not a list.
    """
    values = get_state(g, c, state_name, default=[])
    if not isinstance(values, list):
        raise InvalidStateError(f"state {state_name} is not list")
    # Build a new list so the object returned by get_state is never mutated.
    values = [*values, value]
    if max_length is not None:
        # values[-0:] is the whole list, so a zero limit needs its own branch.
        values = values[-max_length:] if max_length > 0 else []
    set_state(g, c, state_name, values)
52
+
53
+
54
def _add_list_state_raw(
    g: Global,
    state_name: str,
    value: Any,
    max_length: int | None = None,
) -> None:
    """Append ``value`` to list state ``state_name`` directly on ``g.state``.

    Unlike add_list_state, this bypasses ``_execute_once`` and enqueues no
    state_changed event.

    Args:
        max_length: when given, keep only the newest ``max_length`` items;
            0 (or a negative value) leaves an empty list. None keeps all items.

    Raises:
        InvalidStateError: if the stored value is not a list.
    """
    values = g.state.read_state(state_name, [])
    if not isinstance(values, list):
        raise InvalidStateError(f"state {state_name} is not list")
    # Build a new list rather than appending in place: read_state may return
    # the repository's cached object, and mutating it would change state
    # before save_state is called.
    values = [*values, value]
    if max_length is not None:
        # values[-0:] is the whole list, so a zero limit needs its own branch.
        values = values[-max_length:] if max_length > 0 else []
    g.state.save_state(state_name, values)
@@ -0,0 +1,29 @@
1
+ from dataclasses import dataclass
2
+
3
+ from ..event import _enque_event
4
+ from ..global_env import Global
5
+
6
# "_every" is a system event that is raised automatically once on each run.
# Use it for workflows that should execute every time.
SYSTEM_EVENT__EVERY = "_every"


@dataclass
class EveryEvent:
    """Parsed form of the ``_every`` system event; it carries no data."""
15
+
16
+
17
def _enque_every_event(g: Global) -> None:
    """Queue the payload-less ``_every`` system event for the current run."""
    _enque_event(g, SYSTEM_EVENT__EVERY, None, process_immediately=True)
24
+
25
+
26
def try_parse_event__every(event_raw) -> EveryEvent | None:
    """Return an EveryEvent if ``event_raw`` is the ``_every`` event, otherwise None."""
    is_every = event_raw.event_name == SYSTEM_EVENT__EVERY
    return EveryEvent() if is_every else None
@@ -0,0 +1,51 @@
1
+ import time
2
+ from dataclasses import asdict, dataclass
3
+ from typing import Any
4
+
5
+ from dacite import from_dict
6
+
7
+ from ..event import Event, _enque_event
8
+ from ..global_env import Global
9
+
10
# "_pause" is the system event recording that a workflow was suspended.
# A suspended workflow does not resume immediately; it waits for the next run.
SYSTEM_EVENT_PAUSE = "_pause"


@dataclass
class PauseEvent:
    """Payload of the ``_pause`` system event."""

    workflow_name: str  # name of the workflow that paused
    run_id: str  # run_id of the paused run
    pause_id: str  # identifier of the pause point
    args: Any  # arguments the workflow was invoked with
    timestamp: float  # time at which the workflow paused
23
+
24
+
25
def _enque_pause_event(
    g: Global,
    workflow_name: str,
    run_id: str,
    pause_id: str,
    args: Any,
) -> None:
    """Record that a workflow paused by queuing a ``_pause`` event for the next run.

    The payload is a PauseEvent converted to a dict and stamped with the
    current ``time.time()``.
    """
    payload = PauseEvent(
        workflow_name=workflow_name,
        run_id=run_id,
        pause_id=pause_id,
        args=args,
        timestamp=time.time(),
    )
    _enque_event(g, SYSTEM_EVENT_PAUSE, asdict(payload), process_immediately=False)
46
+
47
+
48
def try_parse_pause_event(event_raw: Event) -> PauseEvent | None:
    """Return the PauseEvent payload of a ``_pause`` event, or None for any other event."""
    if event_raw.event_name == SYSTEM_EVENT_PAUSE:
        return from_dict(data_class=PauseEvent, data=event_raw.args)
    return None
ttflow/core/trigger.py ADDED
@@ -0,0 +1,9 @@
1
class Trigger:
    """Base description of what starts a workflow; ``trigger_type`` names the kind."""

    def __init__(self, trigger_type: str):
        self.trigger_type = trigger_type


class EventTrigger(Trigger):
    """Trigger matching events whose name equals ``event_name``."""

    def __init__(self, event_name: str):
        super().__init__(trigger_type="event")
        self.event_name = event_name
@@ -0,0 +1,183 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ import logging
5
+ from collections.abc import Callable, Generator
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ from ..errors import SideeffectUsageError, WorkflowDirectCallError
10
+ from ..system_states.completed import add_completed_runs_log, add_failed_runs_log
11
+ from ..system_states.logs import _get_logs
12
+ from ..system_states.run_state import (
13
+ _delete_run_state,
14
+ _execute_once,
15
+ _execute_once_async,
16
+ )
17
+ from .context import Context
18
+ from .global_env import Global, Workflow
19
+ from .pause import PauseException
20
+ from .run_context import RunContext
21
+ from .system_events.pause import _enque_pause_event
22
+ from .trigger import EventTrigger, Trigger
23
+
24
# Module-level logger for workflow execution messages.
logger = logging.getLogger(__name__)
25
+
26
+
27
def _find_workflow(g: Global, workflow_name: str) -> Workflow | None:
    """Return the first registered workflow whose function is named ``workflow_name``, else None."""
    matches = (wf for wf in g.workflows if wf.f.__name__ == workflow_name)
    return next(matches, None)
32
+
33
+
34
def _find_event_triggered_workflows(
    g: Global, event_name: str
) -> Generator[Workflow, None, None]:
    """Yield, in registration order, the workflows with an EventTrigger on ``event_name``."""
    for wf in g.workflows:
        trig = wf.trigger
        if not isinstance(trig, EventTrigger):
            continue
        if trig.event_name == event_name:
            yield wf
40
+
41
+
42
+ @dataclass
43
+ class WorkflowRunResult:
44
+ workflow_name: str
45
+ run_id: str
46
+ status: str # succeeded, failed, paused
47
+ error_message: str | None
48
+ logs: list[str]
49
+
50
+
51
def _call_workflow_func(wf: Workflow, g: Global, c: Context, args: Any):
    """Invoke the workflow function ``wf.f``.

    The function always receives a new RunContext. It also receives ``args``
    when it declares at least two parameters. Coroutine functions are driven
    to completion with ``asyncio.run()``.
    """
    wants_args = len(inspect.signature(wf.f).parameters) >= 2
    run_ctx = RunContext(g, c)
    call_args = (run_ctx, args) if wants_args else (run_ctx,)
    outcome = wf.f(*call_args)
    if inspect.iscoroutinefunction(wf.f):
        # async workflow: wf.f(...) returned a coroutine that must be run
        asyncio.run(outcome)
66
+
67
+
68
def exec_workflow(g: Global, c: Context, wf: Workflow, args: Any) -> WorkflowRunResult:
    """Run one workflow and report how it ended.

    - PauseException: a ``_pause`` event is enqueued via _enque_pause_event
      (deferred to the next run) and the status is "paused".
    - Any other exception: the run is recorded with add_failed_runs_log and
      the status is "failed"; ``error_message`` holds ``repr(exception)``.
    - Normal return: the run is recorded with add_completed_runs_log, its run
      state is deleted, and the status is "succeeded".

    In every case the returned result carries the logs collected for
    ``c.run_id``.
    """
    name = wf.f.__name__
    error_message: str | None = None
    try:
        _call_workflow_func(wf, g, c, args)
    except PauseException as e:
        logger.info("ワークフローを中断します")
        _enque_pause_event(
            g,
            workflow_name=name,
            run_id=c.run_id,
            pause_id=e.pause_id,
            args=args,
        )
        print(f"{name}: 中断します")
        status = "paused"
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback, which
        # is needed to debug a failed workflow.
        logger.exception("ワークフローが失敗しました: %s", e)
        print(f"{name}: error: {e}")
        add_failed_runs_log(g, c)
        status = "failed"
        error_message = repr(e)
    else:
        # Runs only when the workflow returned normally; errors raised here
        # propagate to the caller.
        add_completed_runs_log(g, c)
        _delete_run_state(g, c.run_id)
        print(f"{name}: 正常終了しました")
        status = "succeeded"
    return WorkflowRunResult(
        workflow_name=name,
        run_id=c.run_id,
        status=status,
        error_message=error_message,
        logs=_get_logs(g, c.run_id),
    )
114
+
115
+
116
def workflow(g: Global, trigger: Trigger | str | None = None) -> Callable:
    """Decorator factory that registers the decorated function as a workflow in ``g``.

    ``trigger`` selects what starts the workflow:
      - None: an EventTrigger on ``_trigger_<function name>``
      - str: an EventTrigger on ``_trigger_<trigger>``
      - anything else (a Trigger): used as given

    The returned wrapper raises WorkflowDirectCallError if it is called:
    workflows are started through their triggers, not invoked directly.
    """

    def _resolve_trigger(f: Callable) -> Trigger:
        if trigger is None:
            return EventTrigger(f"_trigger_{f.__name__}")
        if isinstance(trigger, str):
            return EventTrigger(f"_trigger_{trigger}")
        return trigger

    def _decorator(f: Callable) -> Callable:
        g.workflows.append(Workflow(_resolve_trigger(f), f))

        @functools.wraps(f)
        def _blocked(*args, **kwargs):
            raise WorkflowDirectCallError("workflow can not be called directly")

        return _blocked

    return _decorator
134
+
135
+
136
def _require_run_context(args: tuple) -> RunContext:
    """Return ``args[0]``, which every sideeffect must receive as a RunContext.

    Raises:
        SideeffectUsageError: if no positional argument was given or the first
            one is not a RunContext.
    """
    if len(args) == 0 or not isinstance(args[0], RunContext):
        raise SideeffectUsageError(
            "sideeffectはRunContextを第1引数に取る必要があります"
        )
    return args[0]


def sideeffect(g: Global) -> Callable:
    """Decorator factory wrapping a function so it executes once per run position.

    The decorated function must take a RunContext as its first positional
    argument. The call goes through ``_execute_once`` (sync) or
    ``_execute_once_async`` (async), keyed on that RunContext's Context.

    An async sideeffect may only be called while an event loop is running,
    that is, from an async workflow. Otherwise SideeffectUsageError is raised
    at call time.
    """

    def _decorator(f: Callable) -> Callable:
        if inspect.iscoroutinefunction(f):
            # async sideeffect: wrap in a plain function so a call from a sync
            # workflow is detected by the absence of a running event loop.
            @functools.wraps(f)
            def _async_guard(*args, **kwargs):
                try:
                    asyncio.get_running_loop()
                except RuntimeError:
                    # from None: the RuntimeError is only the "no running loop"
                    # probe and would clutter the traceback.
                    raise SideeffectUsageError(
                        "async sideeffectはasyncワークフローからのみ呼び出せます"
                    ) from None
                return _async_impl(*args, **kwargs)

            async def _async_impl(*args, **kwargs):
                c = _require_run_context(args)

                @_execute_once_async(g, c.get_context_data())
                async def a():
                    return await f(*args, **kwargs)

                return await a()

            return _async_guard

        # sync sideeffect
        @functools.wraps(f)
        def _wrapper(*args, **kwargs):
            c = _require_run_context(args)

            @_execute_once(g, c.get_context_data())
            def a():
                return f(*args, **kwargs)

            return a()

        return _wrapper

    return _decorator
ttflow/errors.py ADDED
@@ -0,0 +1,22 @@
1
class TtflowError(Exception):
    """Base class of every exception raised by ttflow."""


class StateLockedError(TtflowError):
    """The state is locked."""


class UnknownRepositoryError(TtflowError):
    """An unknown StateRepository was specified."""


class InvalidStateError(TtflowError):
    """A state value is invalid."""


class WorkflowDirectCallError(TtflowError):
    """A workflow function was called directly."""


class SideeffectUsageError(TtflowError):
    """A sideeffect function was used incorrectly."""
File without changes
@@ -0,0 +1,43 @@
1
+ from typing import Any, Optional
2
+
3
+ import fire
4
+
5
+ from ttflow import Client, WorkflowRunResult
6
+
7
+
8
def run(client: Client):
    """Build the ``run`` CLI command bound to ``client``.

    The returned callable executes ``client.run(trigger_name, args)`` and
    prints a summary of the results.
    """

    def _internal(trigger_name: Optional[str] = None, args: Any = None):
        _print_workflow_results(client.run(trigger_name, args))

    return _internal
14
+
15
+
16
def _print_workflow_results(results: list[WorkflowRunResult]):
    """Print a human-readable summary of the workflow results of one run.

    For each result it prints the workflow name, run_id and status. It also
    prints the error message when one was recorded (failed runs), followed by
    the collected log lines.
    """
    print()
    print("---------RUN SUMMARY---------")
    print(f"{len(results)}件のワークフローが実行されました")
    for i, result in enumerate(results, start=1):
        print(f"\t{i}件目")
        print(f"\t ワークフロー名: {result.workflow_name}")
        print(f"\t run_id: {result.run_id}")
        print(f"\t 状態: {result.status}")
        # Show why a run failed; without this the summary only says "failed".
        if result.error_message is not None:
            print(f"\t エラー: {result.error_message}")
        print("\t ログ:")
        for line in result.logs:
            print(f"\t - {line}")
28
+
29
+
30
def clear_state(client: Client):
    """Build the ``clear_state`` CLI command, which calls ``clear_state()`` on the client's state."""

    def _internal():
        repository = client._global.state
        repository.clear_state()

    return _internal
35
+
36
+
37
def run_by_cli(client: Client, *, enabled_dangerous_clear_state_command=False):
    """Expose ttflow commands through python-fire.

    ``run`` is always available. ``clear_state`` is registered only when
    ``enabled_dangerous_clear_state_command`` is true, because it calls
    ``clear_state()`` on the state repository.
    """
    commands: dict[str, Any] = {"run": run(client)}
    if enabled_dangerous_clear_state_command:
        commands["clear_state"] = clear_state(client)
    fire.Fire(commands)
+ fire.Fire(opts)