pyfilament 0.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. pyfilament-0.0.0/PKG-INFO +26 -0
  2. pyfilament-0.0.0/pyproject.toml +64 -0
  3. pyfilament-0.0.0/src/filament/.gitignore +2 -0
  4. pyfilament-0.0.0/src/filament/api/app.py +3 -0
  5. pyfilament-0.0.0/src/filament/api/graphql.py +63 -0
  6. pyfilament-0.0.0/src/filament/api/logic/task_run_dict.py +27 -0
  7. pyfilament-0.0.0/src/filament/api/main.py +11 -0
  8. pyfilament-0.0.0/src/filament/api/middleware.py +11 -0
  9. pyfilament-0.0.0/src/filament/api/resolvers/task.py +168 -0
  10. pyfilament-0.0.0/src/filament/api/routes.py +56 -0
  11. pyfilament-0.0.0/src/filament/api/setup_logging.py +25 -0
  12. pyfilament-0.0.0/src/filament/api/types/task.py +121 -0
  13. pyfilament-0.0.0/src/filament/cli/cli_cleanup.py +86 -0
  14. pyfilament-0.0.0/src/filament/db/models.py +82 -0
  15. pyfilament-0.0.0/src/filament/db/session.py +40 -0
  16. pyfilament-0.0.0/src/filament/filament.py +41 -0
  17. pyfilament-0.0.0/src/filament/hooks.py +31 -0
  18. pyfilament-0.0.0/src/filament/logic/cache_keys.py +47 -0
  19. pyfilament-0.0.0/src/filament/logic/cache_utils.py +32 -0
  20. pyfilament-0.0.0/src/filament/logic/call_stack.py +43 -0
  21. pyfilament-0.0.0/src/filament/logic/events.py +32 -0
  22. pyfilament-0.0.0/src/filament/logic/func_registry.py +62 -0
  23. pyfilament-0.0.0/src/filament/logic/module_type_registry.py +24 -0
  24. pyfilament-0.0.0/src/filament/logic/type_checking.py +22 -0
  25. pyfilament-0.0.0/src/filament/logic/utils.py +168 -0
  26. pyfilament-0.0.0/src/filament/queue/task_queue.py +158 -0
  27. pyfilament-0.0.0/src/filament/queue/types/__init__.py +16 -0
  28. pyfilament-0.0.0/src/filament/queue/types/remote_exception.py +14 -0
  29. pyfilament-0.0.0/src/filament/queue/types/remote_task_result.py +49 -0
  30. pyfilament-0.0.0/src/filament/queue/types/remote_task_run.py +53 -0
  31. pyfilament-0.0.0/src/filament/queue/types/remote_task_type.py +102 -0
  32. pyfilament-0.0.0/src/filament/redis/client.py +17 -0
  33. pyfilament-0.0.0/src/filament/redis/logging_handler.py +35 -0
  34. pyfilament-0.0.0/src/filament/redis/semaphore.py +271 -0
  35. pyfilament-0.0.0/src/filament/redis/token_bucket.py +67 -0
  36. pyfilament-0.0.0/src/filament/state/common.py +28 -0
  37. pyfilament-0.0.0/src/filament/state/register.py +65 -0
  38. pyfilament-0.0.0/src/filament/state/task_run_state.py +103 -0
  39. pyfilament-0.0.0/src/filament/state/task_type_state.py +113 -0
  40. pyfilament-0.0.0/src/filament/task/constants.py +18 -0
  41. pyfilament-0.0.0/src/filament/task/registry/task_type_registry.py +37 -0
  42. pyfilament-0.0.0/src/filament/task/types/__init__.py +23 -0
  43. pyfilament-0.0.0/src/filament/task/types/base.py +6 -0
  44. pyfilament-0.0.0/src/filament/task/types/cache_key.py +17 -0
  45. pyfilament-0.0.0/src/filament/task/types/exception_type.py +17 -0
  46. pyfilament-0.0.0/src/filament/task/types/task_config.py +54 -0
  47. pyfilament-0.0.0/src/filament/task/types/task_run.py +348 -0
  48. pyfilament-0.0.0/src/filament/task/types/task_type.py +65 -0
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.4
2
+ Name: pyfilament
3
+ Version: 0.0.0
4
+ Summary:
5
+ Author: James Wu
6
+ Author-email: james@centauri-ai.tech
7
+ Requires-Python: >=3.14,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.14
10
+ Provides-Extra: dev
11
+ Requires-Dist: aiosqlite (>=0.22.1,<0.23.0) ; extra == "dev"
12
+ Requires-Dist: alembic (>=1.18.4,<2.0.0)
13
+ Requires-Dist: anyio (>=4.13.0,<5.0.0)
14
+ Requires-Dist: beartype (>=0.22.9,<0.23.0)
15
+ Requires-Dist: dotenv (>=0.9.9,<0.10.0)
16
+ Requires-Dist: fastapi (>=0.136.3,<0.137.0)
17
+ Requires-Dist: inflection (>=0.5.1,<0.6.0)
18
+ Requires-Dist: pandas (>=3.0.3,<4.0.0)
19
+ Requires-Dist: polars (>=1.41.2,<2.0.0)
20
+ Requires-Dist: pydantic (>=2.13.4,<3.0.0)
21
+ Requires-Dist: redis (>=8.0.0,<9.0.0)
22
+ Requires-Dist: sqlalchemy[asyncio] (>=2.0.50,<3.0.0)
23
+ Requires-Dist: starlette (>=1.2.1,<2.0.0)
24
+ Requires-Dist: strawberry-graphql (>=0.316.0,<0.317.0)
25
+ Requires-Dist: uvicorn (>=0.49.0,<0.50.0)
26
+ Requires-Dist: werkzeug (>=3.1.8,<4.0.0)
@@ -0,0 +1,64 @@
1
+ [tool.ruff]
2
+ line-length = 120
3
+ target-version = 'py314'
4
+
5
+ [tool.ruff.format]
6
+ quote-style = 'single'
7
+
8
+ [tool.mypy]
9
+ plugins = ['pydantic.mypy']
10
+ disable_error_code = ['import-untyped', 'empty-body', 'prop-decorator']
11
+
12
+ [tool.poetry]
13
+ name = "pyfilament"
14
+ version = "0.0.0"
15
+ description = ""
16
+ authors = ["James Wu <james@centauri-ai.tech>"]
17
+ packages = [{ include = "filament", from = "src" }]
18
+
19
+ [tool.poetry.dependencies]
20
+ python = "^3.14"
21
+ sqlalchemy = {extras = ["asyncio"], version = "^2.0.50"}
22
+ alembic = "^1.18.4"
23
+ uvicorn = "^0.49.0"
24
+ anyio = "^4.13.0"
25
+ inflection = "^0.5.1"
26
+ pydantic = "^2.13.4"
27
+ beartype = "^0.22.9"
28
+ pandas = "^3.0.3"
29
+ polars = "^1.41.2"
30
+ dotenv = "^0.9.9"
31
+ fastapi = "^0.136.3"
32
+ strawberry-graphql = "^0.316.0"
33
+ starlette = "^1.2.1"
34
+ werkzeug = "^3.1.8"
35
+ redis = "^8.0.0"
36
+ aiosqlite = { version = "^0.22.1", optional = true }
37
+
38
+ [tool.poetry.extras]
39
+ dev = ["aiosqlite"]
40
+
41
+ [tool.poetry.group.dev.dependencies]
42
+ pytest = "^9.0.3"
43
+ hupper = "^1.12.1"
44
+ watchdog = "^6.0.0"
45
+ ruff = "^0.15.16"
46
+ mypy = "^1.20.2"
47
+ pytest-asyncio = "^1.4.0"
48
+ pytest-cov = "^7.1.0"
49
+ diff-cover = "^10.3.0"
50
+ coverage = "^7.14.1"
51
+ mock = "^5.2.0"
52
+
53
+ [build-system]
54
+ requires = ["poetry-core"]
55
+ build-backend = "poetry.core.masonry.api"
56
+
57
+ [tool.pytest.ini_options]
58
+ asyncio_mode = "auto"
59
+ asyncio_default_test_loop_scope = "session"
60
+ asyncio_default_fixture_loop_scope = "session"
61
+ log_cli = true
62
+ log_cli_level = "INFO"
63
+ log_cli_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
64
+ log_cli_date_format = "%Y-%m-%d %H:%M:%S"
@@ -0,0 +1,2 @@
1
+ __pycache__
2
+ .venv/
@@ -0,0 +1,3 @@
1
+ from fastapi import FastAPI
2
+
3
# Shared FastAPI application instance; routes (filament.api.routes), middleware and the
# GraphQL router (filament.api.main) are attached to it by other modules.
app = FastAPI()
@@ -0,0 +1,63 @@
1
+ import strawberry
2
+ from fastapi import Depends, Request
3
+ from strawberry.extensions import SchemaExtension
4
+ from strawberry.fastapi import GraphQLRouter
5
+
6
+ import filament.api.resolvers.task as task_resolver
7
+ from filament.api.types.task import TaskRun, TaskType
8
+ from filament.db.models import Base
9
+
10
+
11
def get_session_from_request(request: Request):
    """FastAPI dependency returning the DB session that SessionMiddleware put on ``request.state``."""
    state = request.state
    return state.session
13
+
14
+
15
async def get_context(
    session=Depends(get_session_from_request),
):
    """Build the per-request Strawberry context.

    The context holds a single ``'session'`` entry, the request-scoped database session, which
    resolvers read as ``info.context['session']``.
    """
    context = dict(session=session)
    return context
21
+
22
+
23
class SessionFlusher(SchemaExtension):
    """Flush the request session before resolving a still-unset ``id`` / ``*_id`` field.

    An ORM object added to the session but not yet flushed has no database-assigned key, so
    such a field would resolve to ``None``. When the parent object is an ORM model (``Base``)
    and the requested key attribute is ``None``, the session from the GraphQL context is
    flushed first. All other fields are passed straight through without extra overhead.
    """

    def resolve(self, _next, root, info, *args, **kwargs):
        # ``info.field_name`` is the schema field name. ``info.path.key`` is the response key,
        # i.e. the client's alias when the query aliases the field, so it is not suitable for
        # looking the attribute up on the ORM object.
        field_name = info.field_name
        # NOTE(review): Strawberry camel-cases field names by default (e.g. ``taskTypeId``), so
        # the ``_id`` suffix branch may never match — confirm whether it is still needed.
        if (
            field_name is not None
            and (field_name == 'id' or field_name.endswith('_id'))
            and isinstance(root, Base)
            and getattr(root, field_name, None) is None
            and info.context is not None
            and 'session' in info.context
        ):
            return self._flush_then_resolve(info.context['session'], _next, root, info, *args, **kwargs)
        return _next(root, info, *args, **kwargs)

    async def _flush_then_resolve(self, session, _next, root, info, *args, **kwargs):
        """Await ``session.flush()``, then run the wrapped resolver (awaiting it if it is async)."""
        import inspect

        # The request session is an AsyncSession (the resolvers ``await session.execute``), so
        # ``flush()`` returns a coroutine that must be awaited to actually flush anything.
        await session.flush()
        result = _next(root, info, *args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result
34
+
35
+
36
@strawberry.type
class Query:
    """Root GraphQL query type; every field delegates to a resolver in ``filament.api.resolvers.task``."""

    # One task run, looked up by ``task_uuid`` or ``id``.
    get_task_run: TaskRun = strawberry.field(resolver=task_resolver.get_task_run)
    # One task type, looked up by ``func_address`` or ``id``.
    get_task_type: TaskType = strawberry.field(resolver=task_resolver.get_task_type)
    # Task types that have at least one run created within the last ``days`` days.
    get_task_types: list[TaskType] = strawberry.field(resolver=task_resolver.get_task_types)
    # Recent runs of one task type, optionally filtered by state.
    get_task_runs: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_runs)
    # Task runs for a list of ids, returned in the requested order.
    get_task_runs_by_ids: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_runs_by_ids)
    # Task types for a list of ids or uuids, returned in the requested order.
    get_task_types_by_ids: list[TaskType] = strawberry.field(resolver=task_resolver.get_task_types_by_ids)
    # Runs whose chain of ancestor task types matches the given task type ids (root first).
    get_task_type_stack_runs: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_type_stack_runs)
45
+
46
+
47
@strawberry.type
class Mutation:
    """Root GraphQL mutation type; fields delegate to resolvers in ``filament.api.resolvers.task``."""

    # Cancel a task run identified by ``task_uuid`` or ``id`` and return it.
    cancel_task_run: TaskRun = strawberry.field(resolver=task_resolver.cancel_task_run)
    # Request a new run of a task type with JSON-encoded parameters and return the created run.
    run_task: TaskRun = strawberry.field(resolver=task_resolver.run_task)
51
+
52
+
53
# Executable GraphQL schema, with SessionFlusher installed as a schema extension.
schema = strawberry.Schema(query=Query, mutation=Mutation, extensions=[SessionFlusher])

# FastAPI router serving ``schema``; the resolver context is built per request by get_context.
graphql_app = GraphQLRouter(schema, context_getter=get_context)
@@ -0,0 +1,27 @@
1
+ from beartype import beartype
2
+
3
+ from filament.db.models import TaskRun as TaskRunModel
4
+ from filament.logic.utils import avoid_nans, get_json_dict
5
+
6
+
7
@beartype
async def deep_get_task_run_dict(task_run: TaskRunModel, max_child_tasks: int = 100, child_depth: int = 0) -> dict:
    """Serialize a task run together with its task type, state transitions and descendants.

    Args:
        task_run: the ORM task run to serialize.
        max_child_tasks: at each level, at most this many children (lowest ids first) are included.
        child_depth: number of child levels to include; ``0`` yields an empty ``child_tasks`` list.

    Returns:
        The ``get_json_dict`` form of ``task_run``, extended with ``task_type``, ``child_tasks``
        and ``state_transitions`` (sorted by id). ``parameters_json`` / ``result_json`` are
        passed through ``avoid_nans`` when they are not ``None``.
    """
    result = get_json_dict(task_run)
    result['task_type'] = get_json_dict(await task_run.awaitable_attrs.task_type)

    nested_children = []
    if child_depth > 0:
        ordered_children = sorted(await task_run.awaitable_attrs.child_tasks, key=lambda child: child.id)
        for child in ordered_children[:max_child_tasks]:
            nested_children.append(await deep_get_task_run_dict(child, max_child_tasks, child_depth - 1))
    result['child_tasks'] = nested_children

    transitions = sorted(await task_run.awaitable_attrs.state_transitions, key=lambda transition: transition.id)
    for column in ('parameters_json', 'result_json'):
        raw_value = getattr(task_run, column)
        if raw_value is not None:
            result[column] = avoid_nans(raw_value)
    result['state_transitions'] = [get_json_dict(transition) for transition in transitions]
    return result
@@ -0,0 +1,11 @@
1
from filament.api.setup_logging import setup_logging

# Configure logging before the remaining API modules are imported below.
# (Importing filament.api.setup_logging already runs setup_logging(); its _is_setup flag makes
# this explicit call a no-op, kept so the ordering intent is visible here.)
setup_logging()

# Imported only for its side effect: routes.py registers its @app.get handlers on ``app``.
import filament.api.routes  # noqa: F401
from filament.api.app import app
from filament.api.graphql import graphql_app
from filament.api.middleware import SessionMiddleware

# Give every request a DB session on request.state.session (read by the routes and by the
# GraphQL context getter).
app.add_middleware(SessionMiddleware)
# Serve the Strawberry GraphQL schema under /graphql.
app.include_router(graphql_app, prefix='/graphql')
@@ -0,0 +1,11 @@
1
+ from fastapi import Request
2
+ from starlette.middleware.base import BaseHTTPMiddleware
3
+
4
+ from filament.db.session import async_session_scope
5
+
6
+
7
class SessionMiddleware(BaseHTTPMiddleware):
    """Open a database session per request and expose it as ``request.state.session``.

    The session comes from ``async_session_scope`` and stays open while the downstream handler
    runs; the scope is exited once ``call_next`` has produced the response.
    """

    async def dispatch(self, request: Request, call_next):
        async with async_session_scope() as db_session:
            request.state.session = db_session
            response = await call_next(request)
        return response
@@ -0,0 +1,168 @@
1
+ import datetime
2
+ import json
3
+
4
+ from sqlalchemy import select
5
+ from sqlalchemy.orm import aliased
6
+ from sqlalchemy.sql import func
7
+ from strawberry import ID, Info
8
+ from werkzeug.exceptions import BadRequest, NotFound
9
+
10
+ from filament.api.types.task import TaskRun, TaskType
11
+ from filament.db.models import TaskRun as TaskRunModel
12
+ from filament.db.models import TaskType as TaskTypeModel
13
+ from filament.state.task_run_state import cancel_task_run as logic_cancel_task_run
14
+ from filament.task.registry.task_type_registry import lookup
15
+
16
# Default look-back window, in days, used by the get_task_types and get_task_runs resolvers.
DEFAULT_MAX_DAYS = 3
17
+
18
+
19
async def get_task_run(self, info: Info, id: ID | None = None, task_uuid: str | None = None) -> TaskRun:
    """Resolve one task run by ``task_uuid`` (used when both are given) or by ``id``.

    Raises:
        BadRequest: neither ``id`` nor ``task_uuid`` was supplied.
        NotFound: no task run matches.
    """
    session = info.context['session']
    if task_uuid is None and id is None:
        raise BadRequest('Either id or task_uuid must be provided')
    if task_uuid is not None:
        criterion = TaskRunModel.task_uuid == task_uuid
    else:
        criterion = TaskRunModel.id == int(id)
    rows = await session.execute(select(TaskRunModel).where(criterion))
    task_run = rows.scalars().one_or_none()
    if not task_run:
        raise NotFound(f'TaskRun with ID {id} or UUID {task_uuid} not found')
    return task_run
31
+
32
+
33
async def get_task_type(self, info: Info, id: ID | None = None, func_address: str | None = None) -> TaskType:
    """Resolve one task type by ``func_address`` (used when both are given) or by ``id``.

    Raises:
        BadRequest: neither ``id`` nor ``func_address`` was supplied.
        NotFound: no task type matches.
    """
    session = info.context['session']
    if func_address is None and id is None:
        raise BadRequest('Either id or func_address must be provided')
    if func_address is not None:
        criterion = TaskTypeModel.func_address == func_address
    else:
        criterion = TaskTypeModel.id == int(id)
    rows = await session.execute(select(TaskTypeModel).where(criterion))
    task_type = rows.scalars().one_or_none()
    if not task_type:
        raise NotFound(f'TaskType with id {id} or func_address {func_address} not found')
    return task_type
45
+
46
+
47
async def get_task_types_by_ids(
    self, info: Info, ids: list[int] | None = None, uuids: list[str] | None = None
) -> list[TaskType]:
    """Fetch task types by ``ids`` (checked first) or by ``uuids``, in the requested order.

    Identifiers without a matching row are silently omitted from the result.

    Raises:
        BadRequest: both ``ids`` and ``uuids`` are missing or empty.
    """
    session = info.context['session']
    if ids:
        key_name, requested = 'id', ids
    elif uuids:
        # NOTE(review): relies on TaskTypeModel having a ``uuid`` attribute (not visible here) — confirm.
        key_name, requested = 'uuid', uuids
    else:
        raise BadRequest('Either ids or uuids must be provided')
    column = getattr(TaskTypeModel, key_name)
    rows = (await session.execute(select(TaskTypeModel).where(column.in_(requested)))).scalars().all()
    found = {getattr(row, key_name): row for row in rows}
    return [found[key] for key in requested if key in found]
64
+
65
+
66
async def get_task_type_stack_runs(
    self, info: Info, task_type_ids: list[int], states: list[str] | None = None
) -> list[TaskRun]:
    """Return recent runs whose ancestor chain has exactly the given task types.

    ``task_type_ids`` is ordered root first. The last id is the task type of the returned runs,
    the id before it is the type of their parent, and so on. The run matched by the first id
    must have no parent. If ``states`` is non-empty, only returned runs in those states are kept.
    At most 100 runs are returned, newest ``created_at`` first.

    Raises:
        BadRequest: ``task_type_ids`` is empty.
    """
    session = info.context['session']
    if not task_type_ids:
        raise BadRequest('task_type_ids must be provided')
    max_results = 100
    *ancestor_type_ids, leaf_type_id = task_type_ids

    statement = select(TaskRunModel).filter(TaskRunModel.task_type_id == leaf_type_id)
    if states:
        statement = statement.where(TaskRunModel.state.in_(states))

    # Walk upwards from the leaf, joining one aliased parent row per ancestor task type.
    child = TaskRunModel
    for ancestor_type_id in ancestor_type_ids[::-1]:
        parent = aliased(TaskRunModel)
        statement = statement.join(parent, child.parent_task_uuid == parent.task_uuid).where(
            parent.task_type_id == ancestor_type_id
        )
        child = parent

    # The topmost joined row must be a root run.
    statement = statement.where(child.parent_task_uuid.is_(None))
    statement = statement.order_by(TaskRunModel.created_at.desc()).limit(max_results)
    return (await session.execute(statement)).scalars().all()
90
+
91
+
92
async def get_task_runs_by_ids(self, info: Info, ids: list[int]) -> list[TaskRun]:
    """Fetch task runs for ``ids``, returned in the same order as ``ids``.

    Raises:
        BadRequest: ``ids`` is empty.
        NotFound: any requested id has no task run (reported for the first missing id).
    """
    session = info.context['session']
    if not ids:
        raise BadRequest('ids must be provided')
    rows = (await session.execute(select(TaskRunModel).where(TaskRunModel.id.in_(ids)))).scalars().all()
    by_id = {row.id: row for row in rows}
    for task_run_id in ids:
        if task_run_id not in by_id:
            raise NotFound(f'TaskRun with ID {task_run_id} not found')
    # Every id is present at this point, so no membership filter is needed.
    return [by_id[task_run_id] for task_run_id in ids]
104
+
105
+
106
async def get_task_types(self, info: Info, days: int = DEFAULT_MAX_DAYS) -> list[TaskType]:
    """Return task types that have at least one run created within the last ``days`` days.

    Results are ordered by their most recent run (highest task run id) first.

    Args:
        days: look-back window in days, measured from the current UTC time.
    """
    session = info.context['session']
    # Timezone-aware UTC, the same way cli_cleanup compares against these task run columns; a
    # naive local ``now()`` would shift the window by the server's UTC offset.
    cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days)
    latest_runs = (
        select(TaskRunModel.task_type_id, func.max(TaskRunModel.id).label('task_run_id'))
        .filter(TaskRunModel.created_at > cutoff)
        .group_by(TaskRunModel.task_type_id)
        .subquery()
    )
    statement = (
        select(TaskTypeModel)
        .join(latest_runs, TaskTypeModel.id == latest_runs.c.task_type_id)
        # Deterministic order: most recently run task types first.
        .order_by(latest_runs.c.task_run_id.desc())
    )
    return (await session.execute(statement)).scalars().all()
122
+
123
+
124
async def cancel_task_run(self, info: Info, id: ID | None = None, task_uuid: str | None = None) -> TaskRun:
    """Cancel the task run identified by ``task_uuid`` (used when both are given) or ``id``.

    Delegates the cancellation to ``filament.state.task_run_state.cancel_task_run`` with the
    request session, then returns the task run object.

    Raises:
        BadRequest: neither ``id`` nor ``task_uuid`` was supplied.
        NotFound: no task run matches.
    """
    session = info.context['session']
    if task_uuid is not None:
        statement = select(TaskRunModel).where(TaskRunModel.task_uuid == task_uuid)
    elif id is not None:
        statement = select(TaskRunModel).where(TaskRunModel.id == int(id))
    else:
        raise BadRequest('Either id or task_uuid must be provided')
    task_run = (await session.execute(statement)).scalars().one_or_none()
    if task_run is None:
        # Mention both identifiers (as get_task_run does). The run may have been requested by
        # id, in which case a UUID-only message would print "None".
        raise NotFound(f'TaskRun with ID {id} or UUID {task_uuid} not found')
    await logic_cancel_task_run(session, task_run)
    return task_run
137
+
138
+
139
async def get_task_runs(
    self, info: Info, task_type_id: ID, states: list[str] | None = None, days: int = DEFAULT_MAX_DAYS
) -> list[TaskRun]:
    """Return recent runs of one task type, newest ``created_at`` first.

    Args:
        task_type_id: id of the task type whose runs are returned.
        states: if non-empty, only runs whose ``state`` is in this list are returned.
        days: look-back window in days, measured from the current UTC time.

    Returns:
        At most 99 task runs created within the window.
    """
    max_results = 99
    session = info.context['session']
    # Timezone-aware UTC, the same way cli_cleanup compares against created_at; a naive local
    # ``now()`` would shift the window by the server's UTC offset.
    cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days)
    statement = (
        select(TaskRunModel)
        .where(TaskRunModel.task_type_id == int(task_type_id))
        .where(TaskRunModel.created_at > cutoff)
    )
    if states:
        statement = statement.where(TaskRunModel.state.in_(states))
    statement = statement.order_by(TaskRunModel.created_at.desc()).limit(max_results)
    return (await session.execute(statement)).scalars().all()
155
+
156
+
157
async def run_task(self, info: Info, task_type_id: ID, parameters_json: str) -> TaskRun:
    """Request a run of a task type and return the resulting TaskRun row.

    Steps:
    - Load the task type by id.
    - Resolve its registered implementation with ``lookup(func_address)``.
    - Call its ``_request`` with ``task_kwargs`` decoded from ``parameters_json``, plus
      ``start_immediately=True``.

    Raises:
        NotFound: no task type has ``task_type_id``.
        BadRequest: ``parameters_json`` is not valid JSON or does not decode to a JSON object.
    """
    session = info.context['session']
    statement = select(TaskTypeModel).where(TaskTypeModel.id == int(task_type_id))
    task_type = (await session.execute(statement)).scalars().one_or_none()
    if task_type is None:
        raise NotFound(f'TaskType with id {task_type_id} not found')
    filament_task_type = lookup(task_type.func_address)
    try:
        parameters = json.loads(parameters_json)
    except json.JSONDecodeError as err:
        raise BadRequest(f'parameters_json is not valid JSON: {err}') from err
    # The decoded value becomes ``task_kwargs`` and is extended below, so it must be a mapping.
    if not isinstance(parameters, dict):
        raise BadRequest('parameters_json must be a JSON object')
    parameters['start_immediately'] = True
    filament_task_run = await filament_task_type._request(task_args=[], task_kwargs=parameters)
    statement = select(TaskRunModel).where(TaskRunModel.task_uuid == filament_task_run.uuid)
    return (await session.execute(statement)).scalars().one()
@@ -0,0 +1,56 @@
1
+ import logging
2
+
3
+ from fastapi import Request, Response
4
+ from werkzeug.exceptions import NotFound
5
+
6
+ from filament.api.app import app
7
+ from filament.api.logic.task_run_dict import deep_get_task_run_dict
8
+ from filament.db.models import TaskRun as TaskRunModel
9
+ from filament.logic.utils import rename_keys_to_camel_case, safe_json_dumps
10
+
11
# Module-level logger (not currently used by the route handlers in this module).
logger = logging.getLogger(__name__)
12
+
13
+
14
@app.get('/tasks')
async def root():
    """Placeholder endpoint at ``/tasks`` that returns a static greeting payload."""
    payload = {'message': 'Hello World'}
    return payload
17
+
18
+
19
@app.get('/api/task-run/{task_run_id}')
async def get_task_run(request: Request, task_run_id: int, max_child_tasks: int = 100, child_depth: int = 3):
    """Return one task run, with up to ``child_depth`` levels of children, as JSON.

    The dict built by ``deep_get_task_run_dict`` is passed through ``rename_keys_to_camel_case``.

    Raises:
        HTTPException: 404 when no task run has ``task_run_id``.
    """
    # werkzeug's NotFound is not a Starlette/FastAPI HTTPException, so FastAPI would answer it
    # with a 500; FastAPI's own HTTPException yields the intended 404.
    from fastapi import HTTPException

    session = request.state.session
    task_run = await session.get(TaskRunModel, task_run_id)
    if task_run is None:
        raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
    task_run_dict = await deep_get_task_run_dict(task_run, max_child_tasks, child_depth)
    return rename_keys_to_camel_case(task_run_dict)
27
+
28
+
29
@app.get('/api/task-runs/{task_run_ids_str}')
async def get_task_runs(request: Request, task_run_ids_str: str, max_child_tasks: int = 100, child_depth: int = 3):
    """Return several task runs, given as comma-separated ids, as a JSON list in request order.

    Each entry is built by ``deep_get_task_run_dict``, and the list is passed through
    ``rename_keys_to_camel_case``.

    Raises:
        HTTPException: 400 when an id is not an integer; 404 when any task run does not exist.
    """
    # werkzeug exceptions are not handled by FastAPI (they surface as 500), so use FastAPI's own.
    from fastapi import HTTPException

    try:
        task_run_ids = [int(part) for part in task_run_ids_str.split(',')]
    except ValueError:
        raise HTTPException(status_code=400, detail=f'Invalid task run id list: {task_run_ids_str!r}') from None
    session = request.state.session
    task_runs = []
    for task_run_id in task_run_ids:
        task_run = await session.get(TaskRunModel, task_run_id)
        if task_run is None:
            raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
        task_runs.append(await deep_get_task_run_dict(task_run, max_child_tasks, child_depth))
    return rename_keys_to_camel_case(task_runs)
40
+
41
+
42
@app.get('/api/task-run/{task_run_id}/download')
async def download_task_run(request: Request, task_run_id: int, max_child_tasks: int = 100, child_depth: int = 3):
    """Serve one task run (with up to ``child_depth`` levels of children) as a JSON file attachment.

    The content is ``rename_keys_to_camel_case`` applied to the ``deep_get_task_run_dict``
    output. It is serialized by ``safe_json_dumps`` with ``indent=2`` and ``sort_keys=True``,
    UTF-8 encoded, and named ``task_run_<id>.json``.

    Raises:
        HTTPException: 404 when no task run has ``task_run_id``.
    """
    # werkzeug's NotFound is not handled by FastAPI (it would become a 500); raise a real 404.
    from fastapi import HTTPException

    session = request.state.session
    task_run = await session.get(TaskRunModel, task_run_id)
    if task_run is None:
        raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
    task_run_dict = await deep_get_task_run_dict(task_run, max_child_tasks, child_depth)
    file_content = safe_json_dumps(rename_keys_to_camel_case(task_run_dict), indent=2, sort_keys=True).encode('utf-8')
    filename = f'task_run_{task_run_id}.json'
    return Response(
        content=file_content,
        media_type='application/json',
        headers={'Content-Disposition': f'attachment; filename="{filename}"'},
    )
@@ -0,0 +1,25 @@
1
+ import logging
2
+
3
# Set to True by setup_logging() after the first configuration so repeated calls are no-ops.
_is_setup = False
4
+
5
+
6
class UnixTimeFormatter(logging.Formatter):
    """Formatter that renders ``%(asctime)s`` as Unix epoch seconds with one decimal place."""

    def formatTime(self, record, datefmt=None):
        # ``datefmt`` is accepted for API compatibility with logging.Formatter but ignored.
        return '%.1f' % round(record.created, 1)
10
+
11
+
12
def setup_logging():
    """Route all logging through one stderr handler using UnixTimeFormatter, at DEBUG level.

    Any handlers already attached to the root logger are replaced. After the first call,
    ``_is_setup`` is True and later calls return immediately.
    """
    global _is_setup
    if not _is_setup:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(UnixTimeFormatter(fmt='%(asctime)s [%(levelname)s] %(name)s: %(message)s'))
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)
        root_logger.handlers = [stream_handler]  # replace, don't append
        _is_setup = True
23
+
24
+
25
+ setup_logging()
@@ -0,0 +1,121 @@
1
+ import json
2
+ from datetime import datetime
3
+
4
+ import strawberry
5
+ from sqlalchemy import select
6
+ from strawberry import Info
7
+
8
+ from filament.db.models import TaskRun as TaskRunModel
9
+ from filament.redis.client import r
10
+
11
+
12
async def get_logs(
    task_run: TaskRunModel, with_children: bool = True, max_depth: int = 3, max_num_children: int = 100
) -> list[dict]:
    """Collect the JSON log records stored in Redis for a task run, optionally with its descendants'.

    Records are read from the Redis list ``filament_log:<func_address>:<task_uuid>`` and decoded
    with ``json.loads``. When ``with_children`` is true and ``max_depth > 0``, up to
    ``max_num_children`` children are visited recursively with ``max_depth - 1``. The combined
    list is sorted by each record's ``timestamp``.
    """
    task_type = await task_run.awaitable_attrs.task_type
    raw_entries = await r.lrange(f'filament_log:{task_type.func_address}:{task_run.task_uuid}', 0, -1)
    collected = [json.loads(entry) for entry in raw_entries]
    if with_children and max_depth > 0:
        children = await task_run.awaitable_attrs.child_tasks
        for child in children[:max_num_children]:
            collected += await get_logs(child, with_children, max_depth - 1, max_num_children)
    collected.sort(key=lambda record: record['timestamp'])
    return collected
24
+
25
+
26
@strawberry.type
class TaskRun:
    """GraphQL view of a task run.

    The resolvers return ``filament.db.models.TaskRun`` ORM objects for this type, so inside the
    field methods ``self`` is such an object and relationships are loaded via ``awaitable_attrs``.
    """

    id: int
    task_uuid: str
    name: str | None
    created_at: datetime
    state: str
    state_since: datetime
    heartbeat: datetime
    run_count: int
    parent_task_uuid: str | None
    parameters_json: str | None
    result_json: str | None

    @strawberry.field
    async def task_type(self) -> 'TaskType':
        """The task type this run belongs to."""
        related = await self.awaitable_attrs.task_type
        return related

    @strawberry.field
    async def state_transitions(self) -> list['TaskRunStateTransition']:
        """The recorded state transitions of this run."""
        transitions = await self.awaitable_attrs.state_transitions
        return transitions

    @strawberry.field
    async def child_tasks(self) -> list['TaskRun']:
        """Direct child runs of this run."""
        children = await self.awaitable_attrs.child_tasks
        return children

    @strawberry.field
    async def logs(
        self, with_children: bool = True, max_depth: int = 3, max_num_children: int = 100
    ) -> list['TaskRunLog']:
        """Redis log records of this run (and optionally its descendants), sorted by timestamp."""
        records = await get_logs(self, with_children, max_depth, max_num_children)
        return [TaskRunLog(**record) for record in records]

    @strawberry.field
    async def task_runs_stack(self) -> list['TaskRun']:
        """This run and all of its ancestors, ordered from the root run down to this one."""
        chain = [self]
        node = self
        # Load each parent once and stop at the first run without a parent.
        while parent := await node.awaitable_attrs.parent_task:
            chain.append(parent)
            node = parent
        chain.reverse()
        return chain
68
+
69
+
70
@strawberry.type
class TaskRunStateTransition:
    """GraphQL view of one recorded state change (``from_state`` -> ``to_state``) of a task run."""

    id: int
    task_uuid: str
    from_state: str
    to_state: str
    state_since: datetime

    @strawberry.field
    async def task_run(self) -> TaskRun:
        """The task run this transition belongs to, loaded through the ORM relationship."""
        # ``self`` is the ORM transition object returned by TaskRun.state_transitions.
        return await self.awaitable_attrs.task_run
81
+
82
+
83
@strawberry.type
class TaskRunLog:
    """One log record of a task run, built by TaskRun.logs from the dicts returned by get_logs."""

    timestamp: float  # sort key used by get_logs
    level: str
    name: str
    message: str
89
+
90
+
91
@strawberry.type
class TaskType:
    """GraphQL view of a task type.

    The resolvers return ``filament.db.models.TaskType`` ORM objects for this type. The field
    methods query task runs through the request session found in ``info.context``.
    """

    id: int
    func_address: str
    name: str | None
    parameters_spec: str | None
    result_spec: str | None

    @strawberry.field
    async def task_runs(self, info: Info) -> list[TaskRun]:
        """Up to 99 runs of this task type, newest ``created_at`` first."""
        session = info.context['session']
        newest_first = TaskRunModel.created_at.desc()
        query = select(TaskRunModel).where(TaskRunModel.task_type_id == self.id).order_by(newest_first).limit(99)
        result = await session.execute(query)
        return result.scalars().all()

    @strawberry.field
    async def latest_task_run(self, info: Info) -> TaskRun | None:
        """The run of this task type with the most recent ``state_since``, or ``None`` if there are none."""
        session = info.context['session']
        query = (
            select(TaskRunModel)
            .where(TaskRunModel.task_type_id == self.id)
            .order_by(TaskRunModel.state_since.desc())
            .limit(1)
        )
        result = await session.execute(query)
        return result.scalars().first()
@@ -0,0 +1,86 @@
1
+ import logging
2
+ import sys
3
+ from datetime import datetime, timedelta, timezone
4
+
5
+ import anyio
6
+ from plasma import Plasma
7
+ from sqlalchemy import select
8
+
9
+ from filament.db.models import TaskRun, TaskState
10
+ from filament.db.session import async_session_scope
11
+ from filament.state.task_run_state import cancel_task_run, delete_task_run
12
+
13
+
14
def setup_logging():
    """Attach a stdout handler with a timestamped format to the root logger and set it to INFO.

    Each call appends another handler; the module calls this once at import time.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    root = logging.getLogger()
    root.addHandler(stdout_handler)
    root.setLevel(logging.INFO)
21
+
22
+
23
# Configure root logging (stdout handler, INFO level) as soon as this module is imported.
setup_logging()

logger = logging.getLogger(__name__)
26
+
27
+
28
async def main(
    stonith_max_heartbeat_seconds: int = 60 * 60, delete_old_task_runs_days: int = 30, skip_root_tasks: bool = True
):
    """Run both cleanup passes in order: cancel stale runs first, then delete old runs.

    Args:
        stonith_max_heartbeat_seconds: non-terminal runs with an older heartbeat are cancelled.
        delete_old_task_runs_days: runs created more than this many days ago are deleted.
        skip_root_tasks: when true, runs without a parent task are not deleted.
    """
    cancel_stale = stonith(max_heartbeat_seconds=stonith_max_heartbeat_seconds)
    await cancel_stale
    delete_old = delete_old_task_runs(days=delete_old_task_runs_days, skip_root_tasks=skip_root_tasks)
    await delete_old
33
+
34
+
35
async def stonith(max_heartbeat_seconds: int, batch_size: int = 100):
    """Cancel non-terminal task runs whose heartbeat is older than ``max_heartbeat_seconds``.

    Stale runs are selected newest heartbeat first, in batches of ``batch_size``. Each batch is
    cancelled and committed in its own session scope. The loop stops when a batch comes back
    smaller than ``batch_size``, and pauses for one second between full batches.

    Args:
        max_heartbeat_seconds: heartbeat age, in seconds, beyond which a non-terminal run is cancelled.
        batch_size: maximum number of runs handled per transaction.
    """
    while True:
        async with async_session_scope() as session:
            now = datetime.now(timezone.utc)
            heartbeat_threshold = now - timedelta(seconds=max_heartbeat_seconds)
            task_runs_statement = (
                select(TaskRun)
                .where(~TaskRun.state.in_(TaskState.TERMINAL))
                .where(TaskRun.heartbeat < heartbeat_threshold)
                .order_by(TaskRun.heartbeat.desc())
                .limit(batch_size)
            )
            task_runs = (await session.execute(task_runs_statement)).scalars().all()
            logger.info('Found %d task runs to STONITH', len(task_runs))
            sample_age = None
            for task_run in task_runs:
                heartbeat_age = now - task_run.heartbeat
                logger.debug('Cancelling task run %s heartbeat_age=%s', task_run.id, heartbeat_age)
                if sample_age is None:
                    sample_age = heartbeat_age
                await cancel_task_run(session, task_run)
            # commit() flushes pending changes itself, so no separate flush() is needed.
            await session.commit()
            logger.info('STONITHed %d task runs, any heartbeat_age=%s', len(task_runs), sample_age)
        # A short batch means nothing stale is left: stop without pausing.
        if len(task_runs) < batch_size:
            break
        # Pause between batches after the session scope has been exited.
        await anyio.sleep(1)
62
+
63
+
64
async def delete_old_task_runs(days: int = 30, batch_size: int = 100, skip_root_tasks: bool = True):
    """Delete task runs created more than ``days`` days ago, in batches of ``batch_size``.

    Each batch is deleted and committed in its own session scope. The loop stops when a batch
    comes back smaller than ``batch_size``, and pauses for one second between full batches.
    With ``skip_root_tasks`` (the default), only runs that have a parent task are selected.

    Args:
        days: minimum age, in days, of runs to delete.
        batch_size: maximum number of runs deleted per transaction.
        skip_root_tasks: when true, runs without a parent task are never selected.
    """
    # Fix the cutoff once so every batch works against the same boundary.
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    while True:
        async with async_session_scope() as session:
            statement = select(TaskRun).where(TaskRun.created_at < cutoff)
            if skip_root_tasks:
                statement = statement.where(TaskRun.parent_task_uuid.is_not(None))
            task_runs = (await session.execute(statement.limit(batch_size))).scalars().all()
            logger.info('Found %d task runs to delete', len(task_runs))
            sample_created_at = None
            for task_run in task_runs:
                # Read the attribute before the row is deleted.
                if sample_created_at is None:
                    sample_created_at = task_run.created_at
                await delete_task_run(session, task_run)
            await session.commit()
            logger.info('Deleted %d task runs, any_age=%s', len(task_runs), sample_created_at)
        # A short batch means nothing old is left: stop without pausing.
        if len(task_runs) < batch_size:
            break
        # Pause between batches after the session scope has been exited.
        await anyio.sleep(1)
83
+
84
+
85
if __name__ == '__main__':
    # Script entry point: hand ``main`` to Plasma (presumably it exposes main's keyword
    # arguments as CLI options — TODO confirm).
    # NOTE(review): this module imports ``plasma``, but it is not listed in the package's
    # pyproject dependencies — confirm it is installed wherever this CLI runs.
    Plasma(main)