pyfilament 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyfilament-0.0.0/PKG-INFO +26 -0
- pyfilament-0.0.0/pyproject.toml +64 -0
- pyfilament-0.0.0/src/filament/.gitignore +2 -0
- pyfilament-0.0.0/src/filament/api/app.py +3 -0
- pyfilament-0.0.0/src/filament/api/graphql.py +63 -0
- pyfilament-0.0.0/src/filament/api/logic/task_run_dict.py +27 -0
- pyfilament-0.0.0/src/filament/api/main.py +11 -0
- pyfilament-0.0.0/src/filament/api/middleware.py +11 -0
- pyfilament-0.0.0/src/filament/api/resolvers/task.py +168 -0
- pyfilament-0.0.0/src/filament/api/routes.py +56 -0
- pyfilament-0.0.0/src/filament/api/setup_logging.py +25 -0
- pyfilament-0.0.0/src/filament/api/types/task.py +121 -0
- pyfilament-0.0.0/src/filament/cli/cli_cleanup.py +86 -0
- pyfilament-0.0.0/src/filament/db/models.py +82 -0
- pyfilament-0.0.0/src/filament/db/session.py +40 -0
- pyfilament-0.0.0/src/filament/filament.py +41 -0
- pyfilament-0.0.0/src/filament/hooks.py +31 -0
- pyfilament-0.0.0/src/filament/logic/cache_keys.py +47 -0
- pyfilament-0.0.0/src/filament/logic/cache_utils.py +32 -0
- pyfilament-0.0.0/src/filament/logic/call_stack.py +43 -0
- pyfilament-0.0.0/src/filament/logic/events.py +32 -0
- pyfilament-0.0.0/src/filament/logic/func_registry.py +62 -0
- pyfilament-0.0.0/src/filament/logic/module_type_registry.py +24 -0
- pyfilament-0.0.0/src/filament/logic/type_checking.py +22 -0
- pyfilament-0.0.0/src/filament/logic/utils.py +168 -0
- pyfilament-0.0.0/src/filament/queue/task_queue.py +158 -0
- pyfilament-0.0.0/src/filament/queue/types/__init__.py +16 -0
- pyfilament-0.0.0/src/filament/queue/types/remote_exception.py +14 -0
- pyfilament-0.0.0/src/filament/queue/types/remote_task_result.py +49 -0
- pyfilament-0.0.0/src/filament/queue/types/remote_task_run.py +53 -0
- pyfilament-0.0.0/src/filament/queue/types/remote_task_type.py +102 -0
- pyfilament-0.0.0/src/filament/redis/client.py +17 -0
- pyfilament-0.0.0/src/filament/redis/logging_handler.py +35 -0
- pyfilament-0.0.0/src/filament/redis/semaphore.py +271 -0
- pyfilament-0.0.0/src/filament/redis/token_bucket.py +67 -0
- pyfilament-0.0.0/src/filament/state/common.py +28 -0
- pyfilament-0.0.0/src/filament/state/register.py +65 -0
- pyfilament-0.0.0/src/filament/state/task_run_state.py +103 -0
- pyfilament-0.0.0/src/filament/state/task_type_state.py +113 -0
- pyfilament-0.0.0/src/filament/task/constants.py +18 -0
- pyfilament-0.0.0/src/filament/task/registry/task_type_registry.py +37 -0
- pyfilament-0.0.0/src/filament/task/types/__init__.py +23 -0
- pyfilament-0.0.0/src/filament/task/types/base.py +6 -0
- pyfilament-0.0.0/src/filament/task/types/cache_key.py +17 -0
- pyfilament-0.0.0/src/filament/task/types/exception_type.py +17 -0
- pyfilament-0.0.0/src/filament/task/types/task_config.py +54 -0
- pyfilament-0.0.0/src/filament/task/types/task_run.py +348 -0
- pyfilament-0.0.0/src/filament/task/types/task_type.py +65 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pyfilament
|
|
3
|
+
Version: 0.0.0
|
|
4
|
+
Summary:
|
|
5
|
+
Author: James Wu
|
|
6
|
+
Author-email: james@centauri-ai.tech
|
|
7
|
+
Requires-Python: >=3.14,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: aiosqlite (>=0.22.1,<0.23.0) ; extra == "dev"
|
|
12
|
+
Requires-Dist: alembic (>=1.18.4,<2.0.0)
|
|
13
|
+
Requires-Dist: anyio (>=4.13.0,<5.0.0)
|
|
14
|
+
Requires-Dist: beartype (>=0.22.9,<0.23.0)
|
|
15
|
+
Requires-Dist: dotenv (>=0.9.9,<0.10.0)
|
|
16
|
+
Requires-Dist: fastapi (>=0.136.3,<0.137.0)
|
|
17
|
+
Requires-Dist: inflection (>=0.5.1,<0.6.0)
|
|
18
|
+
Requires-Dist: pandas (>=3.0.3,<4.0.0)
|
|
19
|
+
Requires-Dist: polars (>=1.41.2,<2.0.0)
|
|
20
|
+
Requires-Dist: pydantic (>=2.13.4,<3.0.0)
|
|
21
|
+
Requires-Dist: redis (>=8.0.0,<9.0.0)
|
|
22
|
+
Requires-Dist: sqlalchemy[asyncio] (>=2.0.50,<3.0.0)
|
|
23
|
+
Requires-Dist: starlette (>=1.2.1,<2.0.0)
|
|
24
|
+
Requires-Dist: strawberry-graphql (>=0.316.0,<0.317.0)
|
|
25
|
+
Requires-Dist: uvicorn (>=0.49.0,<0.50.0)
|
|
26
|
+
Requires-Dist: werkzeug (>=3.1.8,<4.0.0)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
[tool.ruff]
|
|
2
|
+
line-length = 120
|
|
3
|
+
target-version = 'py314'
|
|
4
|
+
|
|
5
|
+
[tool.ruff.format]
|
|
6
|
+
quote-style = 'single'
|
|
7
|
+
|
|
8
|
+
[tool.mypy]
|
|
9
|
+
plugins = ['pydantic.mypy']
|
|
10
|
+
disable_error_code = ['import-untyped', 'empty-body', 'prop-decorator']
|
|
11
|
+
|
|
12
|
+
[tool.poetry]
|
|
13
|
+
name = "pyfilament"
|
|
14
|
+
version = "0.0.0"
|
|
15
|
+
description = ""
|
|
16
|
+
authors = ["James Wu <james@centauri-ai.tech>"]
|
|
17
|
+
packages = [{ include = "filament", from = "src" }]
|
|
18
|
+
|
|
19
|
+
[tool.poetry.dependencies]
|
|
20
|
+
python = "^3.14"
|
|
21
|
+
sqlalchemy = {extras = ["asyncio"], version = "^2.0.50"}
|
|
22
|
+
alembic = "^1.18.4"
|
|
23
|
+
uvicorn = "^0.49.0"
|
|
24
|
+
anyio = "^4.13.0"
|
|
25
|
+
inflection = "^0.5.1"
|
|
26
|
+
pydantic = "^2.13.4"
|
|
27
|
+
beartype = "^0.22.9"
|
|
28
|
+
pandas = "^3.0.3"
|
|
29
|
+
polars = "^1.41.2"
|
|
30
|
+
dotenv = "^0.9.9"
|
|
31
|
+
fastapi = "^0.136.3"
|
|
32
|
+
strawberry-graphql = "^0.316.0"
|
|
33
|
+
starlette = "^1.2.1"
|
|
34
|
+
werkzeug = "^3.1.8"
|
|
35
|
+
redis = "^8.0.0"
|
|
36
|
+
aiosqlite = { version = "^0.22.1", optional = true }
|
|
37
|
+
|
|
38
|
+
[tool.poetry.extras]
|
|
39
|
+
dev = ["aiosqlite"]
|
|
40
|
+
|
|
41
|
+
[tool.poetry.group.dev.dependencies]
|
|
42
|
+
pytest = "^9.0.3"
|
|
43
|
+
hupper = "^1.12.1"
|
|
44
|
+
watchdog = "^6.0.0"
|
|
45
|
+
ruff = "^0.15.16"
|
|
46
|
+
mypy = "^1.20.2"
|
|
47
|
+
pytest-asyncio = "^1.4.0"
|
|
48
|
+
pytest-cov = "^7.1.0"
|
|
49
|
+
diff-cover = "^10.3.0"
|
|
50
|
+
coverage = "^7.14.1"
|
|
51
|
+
mock = "^5.2.0"
|
|
52
|
+
|
|
53
|
+
[build-system]
|
|
54
|
+
requires = ["poetry-core"]
|
|
55
|
+
build-backend = "poetry.core.masonry.api"
|
|
56
|
+
|
|
57
|
+
[tool.pytest.ini_options]
|
|
58
|
+
asyncio_mode = "auto"
|
|
59
|
+
asyncio_default_test_loop_scope = "session"
|
|
60
|
+
asyncio_default_fixture_loop_scope = "session"
|
|
61
|
+
log_cli = true
|
|
62
|
+
log_cli_level = "INFO"
|
|
63
|
+
log_cli_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
64
|
+
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import strawberry
|
|
2
|
+
from fastapi import Depends, Request
|
|
3
|
+
from strawberry.extensions import SchemaExtension
|
|
4
|
+
from strawberry.fastapi import GraphQLRouter
|
|
5
|
+
|
|
6
|
+
import filament.api.resolvers.task as task_resolver
|
|
7
|
+
from filament.api.types.task import TaskRun, TaskType
|
|
8
|
+
from filament.db.models import Base
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_session_from_request(request: Request):
    """FastAPI dependency returning the database session stored on ``request.state``.

    The session is attached to the request by the session middleware before the
    endpoint runs.
    """
    request_state = request.state
    return request_state.session
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def get_context(
    session=Depends(get_session_from_request),
):
    """Build the Strawberry per-request context.

    FastAPI injects ``session`` via ``get_session_from_request``; resolvers read it
    back as ``info.context['session']``.
    """
    context = {'session': session}
    return context
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SessionFlusher(SchemaExtension):
    """Strawberry extension that flushes the request session before resolving unset id fields.

    A resolver may return an ORM object (a ``Base`` instance) that was added to the
    session but not flushed yet, so its ``id`` / ``*_id`` columns are still ``None``.
    Flushing before such a field is resolved lets the database populate the value
    instead of the API returning null.
    """

    def resolve(self, _next, root, info, *args, **kwargs):
        key = info.path.key if info.path is not None else None
        needs_flush = (
            isinstance(key, str)
            and (key == 'id' or key.endswith('_id'))
            and isinstance(root, Base)
            and getattr(root, key, None) is None
            and info.context is not None
            and 'session' in info.context
        )
        if not needs_flush:
            return _next(root, info, *args, **kwargs)

        import inspect

        # The request session is opened with `async with async_session_scope()`, i.e. it is an
        # AsyncSession whose flush() returns a coroutine that must be awaited or nothing is flushed.
        # A synchronous Session flushes right here and returns None; keep the plain sync path then.
        pending_flush = info.context['session'].flush()
        if not inspect.isawaitable(pending_flush):
            return _next(root, info, *args, **kwargs)

        # A closure (instead of a helper taking named parameters) avoids clashing with GraphQL
        # field arguments passed through **kwargs.
        async def flush_then_resolve():
            await pending_flush
            result = _next(root, info, *args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result

        return flush_then_resolve()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@strawberry.type
class Query:
    """Root GraphQL query type; every field is backed by a resolver in ``filament.api.resolvers.task``."""

    # One task run, looked up by id or task_uuid.
    get_task_run: TaskRun = strawberry.field(resolver=task_resolver.get_task_run)
    # One task type, looked up by id or func_address.
    get_task_type: TaskType = strawberry.field(resolver=task_resolver.get_task_type)
    # Task types with at least one run created within the last `days` days.
    get_task_types: list[TaskType] = strawberry.field(resolver=task_resolver.get_task_types)
    # Recent runs of one task type, optionally filtered by state.
    get_task_runs: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_runs)
    # Task runs by id, returned in request order.
    get_task_runs_by_ids: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_runs_by_ids)
    # Task types by id or uuid, returned in request order (unknown keys are skipped).
    get_task_types_by_ids: list[TaskType] = strawberry.field(resolver=task_resolver.get_task_types_by_ids)
    # Leaf runs whose ancestor chain matches a root-to-leaf list of task type ids.
    get_task_type_stack_runs: list[TaskRun] = strawberry.field(resolver=task_resolver.get_task_type_stack_runs)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@strawberry.type
class Mutation:
    """Root GraphQL mutation type; fields are backed by resolvers in ``filament.api.resolvers.task``."""

    # Cancel a task run (looked up by id or task_uuid) and return it.
    cancel_task_run: TaskRun = strawberry.field(resolver=task_resolver.cancel_task_run)
    # Start a run of a task type with JSON-encoded keyword parameters and return the new run.
    run_task: TaskRun = strawberry.field(resolver=task_resolver.run_task)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Executable GraphQL schema. SessionFlusher's `resolve` hook wraps every field resolution.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    extensions=[SessionFlusher],
)


# FastAPI router serving the schema; get_context supplies the per-request DB session.
# main.py mounts it under the '/graphql' prefix.
graphql_app = GraphQLRouter(
    schema,
    context_getter=get_context,
)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from beartype import beartype
|
|
2
|
+
|
|
3
|
+
from filament.db.models import TaskRun as TaskRunModel
|
|
4
|
+
from filament.logic.utils import avoid_nans, get_json_dict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@beartype
async def deep_get_task_run_dict(task_run: TaskRunModel, max_child_tasks: int = 100, child_depth: int = 0) -> dict:
    """Serialize ``task_run`` into a JSON-ready dict, recursing into its children.

    The result is ``get_json_dict(task_run)`` extended with:

    * ``task_type``: the serialized related task type;
    * ``child_tasks``: up to ``max_child_tasks`` children ordered by id, each serialized
      recursively with ``child_depth - 1``; an empty list when ``child_depth`` is 0 or less;
    * ``parameters_json`` / ``result_json``: passed through ``avoid_nans`` when not ``None``;
    * ``state_transitions``: the serialized state transitions ordered by id.
    """
    serialized = get_json_dict(task_run)
    serialized['task_type'] = get_json_dict(await task_run.awaitable_attrs.task_type)

    serialized_children = []
    if child_depth > 0:
        children_by_id = sorted(await task_run.awaitable_attrs.child_tasks, key=lambda child: child.id)
        for child in children_by_id[:max_child_tasks]:
            serialized_children.append(await deep_get_task_run_dict(child, max_child_tasks, child_depth - 1))
    serialized['child_tasks'] = serialized_children

    transitions = await task_run.awaitable_attrs.state_transitions
    transitions_by_id = sorted(transitions, key=lambda transition: transition.id)

    for json_field in ('parameters_json', 'result_json'):
        raw_value = getattr(task_run, json_field)
        if raw_value is not None:
            serialized[json_field] = avoid_nans(raw_value)

    serialized['state_transitions'] = [get_json_dict(transition) for transition in transitions_by_id]
    return serialized
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from filament.api.setup_logging import setup_logging
|
|
2
|
+
|
|
3
|
+
# Configure root logging before the API modules below are imported.
# setup_logging() is guarded by a module-level flag, so this is a no-op if importing
# filament.api.setup_logging has already run it.
setup_logging()
|
|
4
|
+
|
|
5
|
+
import filament.api.routes # noqa: F401
|
|
6
|
+
from filament.api.app import app
|
|
7
|
+
from filament.api.graphql import graphql_app
|
|
8
|
+
from filament.api.middleware import SessionMiddleware
|
|
9
|
+
|
|
10
|
+
# Open one DB session per HTTP request (exposed as request.state.session) ...
app.add_middleware(SessionMiddleware)
# ... and serve the GraphQL API under /graphql.
app.include_router(graphql_app, prefix='/graphql')
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from fastapi import Request
|
|
2
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
3
|
+
|
|
4
|
+
from filament.db.session import async_session_scope
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SessionMiddleware(BaseHTTPMiddleware):
    """Give every HTTP request its own database session.

    The session from ``async_session_scope()`` is stored on ``request.state.session``
    for the downstream handlers and stays open until ``call_next`` has returned.
    """

    async def dispatch(self, request: Request, call_next):
        async with async_session_scope() as db_session:
            request.state.session = db_session
            response = await call_next(request)
            return response
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import json
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from sqlalchemy.orm import aliased
|
|
6
|
+
from sqlalchemy.sql import func
|
|
7
|
+
from strawberry import ID, Info
|
|
8
|
+
from werkzeug.exceptions import BadRequest, NotFound
|
|
9
|
+
|
|
10
|
+
from filament.api.types.task import TaskRun, TaskType
|
|
11
|
+
from filament.db.models import TaskRun as TaskRunModel
|
|
12
|
+
from filament.db.models import TaskType as TaskTypeModel
|
|
13
|
+
from filament.state.task_run_state import cancel_task_run as logic_cancel_task_run
|
|
14
|
+
from filament.task.registry.task_type_registry import lookup
|
|
15
|
+
|
|
16
|
+
# Default look-back window, in days, for the time-bounded resolvers (get_task_types, get_task_runs).
DEFAULT_MAX_DAYS = 3
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def get_task_run(self, info: Info, id: ID | None = None, task_uuid: str | None = None) -> TaskRun:
|
|
20
|
+
session = info.context['session']
|
|
21
|
+
if task_uuid is not None:
|
|
22
|
+
statement = select(TaskRunModel).where(TaskRunModel.task_uuid == task_uuid)
|
|
23
|
+
elif id is not None:
|
|
24
|
+
statement = select(TaskRunModel).where(TaskRunModel.id == int(id))
|
|
25
|
+
else:
|
|
26
|
+
raise BadRequest('Either id or task_uuid must be provided')
|
|
27
|
+
task_run = (await session.execute(statement)).scalars().one_or_none()
|
|
28
|
+
if not task_run:
|
|
29
|
+
raise NotFound(f'TaskRun with ID {id} or UUID {task_uuid} not found')
|
|
30
|
+
return task_run
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def get_task_type(self, info: Info, id: ID | None = None, func_address: str | None = None) -> TaskType:
    """Fetch one task type by ``func_address`` (checked first) or by ``id``.

    Raises:
        BadRequest: if neither argument is given.
        NotFound: if no task type matches.
    """
    session = info.context['session']
    if func_address is None and id is None:
        raise BadRequest('Either id or func_address must be provided')
    if func_address is not None:
        criterion = TaskTypeModel.func_address == func_address
    else:
        criterion = TaskTypeModel.id == int(id)
    result = await session.execute(select(TaskTypeModel).where(criterion))
    task_type = result.scalars().one_or_none()
    if not task_type:
        raise NotFound(f'TaskType with id {id} or func_address {func_address} not found')
    return task_type
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
async def get_task_types_by_ids(
    self, info: Info, ids: list[int] | None = None, uuids: list[str] | None = None
) -> list[TaskType]:
    """Fetch task types by ``ids`` (preferred when non-empty) or by ``uuids``.

    The result follows the order of the requested keys; keys with no matching row
    are silently skipped.

    Raises:
        BadRequest: if both ``ids`` and ``uuids`` are missing or empty.
    """
    session = info.context['session']
    if ids:
        key_name, wanted_keys = 'id', ids
    elif uuids:
        key_name, wanted_keys = 'uuid', uuids
    else:
        raise BadRequest('Either ids or uuids must be provided')
    key_column = getattr(TaskTypeModel, key_name)
    rows = (await session.execute(select(TaskTypeModel).where(key_column.in_(wanted_keys)))).scalars().all()
    rows_by_key = {getattr(row, key_name): row for row in rows}
    return [rows_by_key[key] for key in wanted_keys if key in rows_by_key]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def get_task_type_stack_runs(
    self, info: Info, task_type_ids: list[int], states: list[str] | None = None
) -> list[TaskRun]:
    """Find runs whose ancestry matches a chain of task types.

    ``task_type_ids`` lists task types from the root (first) to the leaf (last).
    The returned runs are of the leaf type (optionally restricted to ``states``).
    Each run's parent, grandparent, ... must be of the preceding types in turn, and
    the top-most ancestor must be a root run (no ``parent_task_uuid``). At most 100
    runs are returned, newest ``created_at`` first.

    Raises:
        BadRequest: if ``task_type_ids`` is empty.
    """
    session = info.context['session']
    if len(task_type_ids) == 0:
        raise BadRequest('task_type_ids must be provided')
    max_results = 100
    *ancestor_type_ids, leaf_type_id = task_type_ids

    statement = select(TaskRunModel).where(TaskRunModel.task_type_id == leaf_type_id)
    if states:
        statement = statement.where(TaskRunModel.state.in_(states))

    # Walk upwards from the leaf, joining one aliased parent row per ancestor type.
    child = TaskRunModel
    for ancestor_type_id in ancestor_type_ids[::-1]:
        parent = aliased(TaskRunModel)
        statement = statement.join(parent, child.parent_task_uuid == parent.task_uuid)
        statement = statement.where(parent.task_type_id == ancestor_type_id)
        child = parent

    # `child` is now the top-most row of the chain; it must be a root run.
    statement = statement.where(child.parent_task_uuid.is_(None))
    statement = statement.order_by(TaskRunModel.created_at.desc()).limit(max_results)
    return (await session.execute(statement)).scalars().all()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def get_task_runs_by_ids(self, info: Info, ids: list[int]) -> list[TaskRun]:
    """Fetch task runs by primary key, preserving the order of ``ids``.

    Raises:
        BadRequest: if ``ids`` is empty.
        NotFound: for the first id (in request order) with no matching row.
    """
    session = info.context['session']
    if len(ids) == 0:
        raise BadRequest('ids must be provided')
    result = await session.execute(select(TaskRunModel).where(TaskRunModel.id.in_(ids)))
    rows_by_id = {row.id: row for row in result.scalars().all()}
    missing_ids = [task_run_id for task_run_id in ids if task_run_id not in rows_by_id]
    if missing_ids:
        raise NotFound(f'TaskRun with ID {missing_ids[0]} not found')
    return [rows_by_id[task_run_id] for task_run_id in ids]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
async def get_task_types(self, info: Info, days: int = DEFAULT_MAX_DAYS) -> list[TaskType]:
    """Return the task types that have at least one run created within the last ``days`` days.

    The cutoff is computed as a timezone-aware UTC datetime. The cleanup CLI compares
    the same TaskRun timestamp columns against ``datetime.now(timezone.utc)``. A naive
    local ``now()`` would shift the window by the server's UTC offset, or fail
    outright, depending on the driver.
    """
    session = info.context['session']
    cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days)
    # One row per task type with a recent run (max id kept only to make the group explicit).
    recent_runs_per_type = (
        select(TaskRunModel.task_type_id, func.max(TaskRunModel.id).label('task_run_id'))
        .where(TaskRunModel.created_at > cutoff)
        .group_by(TaskRunModel.task_type_id)
        .subquery()
    )
    statement = select(TaskTypeModel).join(
        recent_runs_per_type,
        TaskTypeModel.id == recent_runs_per_type.c.task_type_id,
    )
    return (await session.execute(statement)).scalars().all()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def cancel_task_run(self, info: Info, id: ID | None = None, task_uuid: str | None = None) -> TaskRun:
    """Cancel a task run, looked up by ``task_uuid`` (checked first) or ``id``, and return it.

    Raises:
        BadRequest: if neither ``id`` nor ``task_uuid`` is given.
        NotFound: if no task run matches.
    """
    session = info.context['session']
    if task_uuid is not None:
        statement = select(TaskRunModel).where(TaskRunModel.task_uuid == task_uuid)
    elif id is not None:
        statement = select(TaskRunModel).where(TaskRunModel.id == int(id))
    else:
        raise BadRequest('Either id or task_uuid must be provided')
    task_run = (await session.execute(statement)).scalars().one_or_none()
    if task_run is None:
        # Report both lookup keys (as get_task_run does): the run may have been requested by id.
        raise NotFound(f'TaskRun with ID {id} or UUID {task_uuid} not found')
    await logic_cancel_task_run(session, task_run)
    return task_run
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def get_task_runs(
    self, info: Info, task_type_id: ID, states: list[str] | None = None, days: int = DEFAULT_MAX_DAYS
) -> list[TaskRun]:
    """Return up to 99 runs of one task type created within the last ``days`` days, newest first.

    ``states``, when non-empty, restricts the result to runs in those states. The
    cutoff is timezone-aware UTC, matching how the cleanup CLI compares TaskRun
    timestamps.
    """
    session = info.context['session']
    cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days)
    statement = (
        select(TaskRunModel)
        .where(TaskRunModel.task_type_id == int(task_type_id))
        .where(TaskRunModel.created_at > cutoff)
    )
    if states:
        statement = statement.where(TaskRunModel.state.in_(states))
    statement = statement.order_by(TaskRunModel.created_at.desc()).limit(99)
    return (await session.execute(statement)).scalars().all()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
async def run_task(self, info: Info, task_type_id: ID, parameters_json: str) -> TaskRun:
    """Start a new run of a task type and return its database row.

    Args:
        task_type_id: primary key of the ``TaskTypeModel`` to run.
        parameters_json: JSON object whose entries are passed as keyword arguments to the
            task. ``start_immediately=True`` is always added, overriding any caller value.

    Raises:
        NotFound: if no task type has ``task_type_id``.
        BadRequest: if ``parameters_json`` is not valid JSON or does not encode a JSON object.
    """
    session = info.context['session']
    statement = select(TaskTypeModel).where(TaskTypeModel.id == int(task_type_id))
    task_type = (await session.execute(statement)).scalars().one_or_none()
    if task_type is None:
        raise NotFound(f'TaskType with id {task_type_id} not found')
    filament_task_type = lookup(task_type.func_address)
    try:
        parameters = json.loads(parameters_json)
    except json.JSONDecodeError as exc:
        raise BadRequest(f'parameters_json is not valid JSON: {exc}') from exc
    if not isinstance(parameters, dict):
        raise BadRequest('parameters_json must encode a JSON object')
    parameters['start_immediately'] = True
    filament_task_run = await filament_task_type._request(task_args=[], task_kwargs=parameters)
    # Re-read the run as an ORM row so the GraphQL TaskRun fields resolve against the model.
    # Assumes _request has persisted a row with this uuid; .one() raises otherwise.
    statement = select(TaskRunModel).where(TaskRunModel.task_uuid == filament_task_run.uuid)
    return (await session.execute(statement)).scalars().one()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from fastapi import HTTPException, Request, Response
|
|
4
|
+
from werkzeug.exceptions import NotFound
|
|
5
|
+
|
|
6
|
+
from filament.api.app import app
|
|
7
|
+
from filament.api.logic.task_run_dict import deep_get_task_run_dict
|
|
8
|
+
from filament.db.models import TaskRun as TaskRunModel
|
|
9
|
+
from filament.logic.utils import rename_keys_to_camel_case, safe_json_dumps
|
|
10
|
+
|
|
11
|
+
# Module-level logger (not used by the routes below at present).
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@app.get('/tasks')
async def root():
    """Placeholder endpoint at ``/tasks`` that returns a static greeting."""
    greeting = {'message': 'Hello World'}
    return greeting
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@app.get('/api/task-run/{task_run_id}')
async def get_task_run(request: Request, task_run_id: int, max_child_tasks: int = 100, child_depth: int = 3):
    """Return one task run as a camelCase JSON tree (see ``deep_get_task_run_dict``).

    Raises:
        HTTPException: 404 if no task run has ``task_run_id``.
    """
    session = request.state.session
    task_run = await session.get(TaskRunModel, task_run_id)
    if task_run is None:
        # FastAPI maps only its own HTTPException to a status code. Unless a custom handler is
        # registered on `app`, a werkzeug NotFound would surface as a 500.
        raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
    task_run_dict = await deep_get_task_run_dict(task_run, max_child_tasks, child_depth)
    return rename_keys_to_camel_case(task_run_dict)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@app.get('/api/task-runs/{task_run_ids_str}')
async def get_task_runs(request: Request, task_run_ids_str: str, max_child_tasks: int = 100, child_depth: int = 3):
    """Return several task runs, serialized like ``/api/task-run/{id}``, in the requested order.

    ``task_run_ids_str`` is a comma-separated list of integer ids, e.g. ``/api/task-runs/3,1,2``.

    Raises:
        HTTPException: 400 if any id is not an integer; 404 for the first id with no task run.
    """
    try:
        task_run_ids = [int(part) for part in task_run_ids_str.split(',')]
    except ValueError:
        # A malformed id list is a client error, not a 500.
        raise HTTPException(status_code=400, detail=f'Invalid task run id list: {task_run_ids_str!r}') from None
    session = request.state.session
    task_runs = []
    for task_run_id in task_run_ids:
        task_run = await session.get(TaskRunModel, task_run_id)
        if task_run is None:
            # Use FastAPI's HTTPException: a werkzeug NotFound is not mapped to 404 by FastAPI.
            raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
        task_runs.append(await deep_get_task_run_dict(task_run, max_child_tasks, child_depth))
    return rename_keys_to_camel_case(task_runs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@app.get('/api/task-run/{task_run_id}/download')
async def download_task_run(request: Request, task_run_id: int, max_child_tasks: int = 100, child_depth: int = 3):
    """Serve one task run's JSON tree as an attachment named ``task_run_<id>.json``.

    The payload matches ``/api/task-run/{id}``, pretty-printed with sorted keys.

    Raises:
        HTTPException: 404 if no task run has ``task_run_id``.
    """
    session = request.state.session
    task_run = await session.get(TaskRunModel, task_run_id)
    if task_run is None:
        # Use FastAPI's HTTPException: a werkzeug NotFound is not mapped to 404 by FastAPI.
        raise HTTPException(status_code=404, detail=f'TaskRun with ID {task_run_id} not found')
    task_run_dict = await deep_get_task_run_dict(task_run, max_child_tasks, child_depth)
    file_content = safe_json_dumps(rename_keys_to_camel_case(task_run_dict), indent=2, sort_keys=True).encode('utf-8')
    filename = f'task_run_{task_run_id}.json'
    headers = {'Content-Disposition': f'attachment; filename="{filename}"'}
    return Response(
        content=file_content,
        media_type='application/json',
        headers=headers,
    )
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
# Set by setup_logging() so the root logger is configured at most once per process.
_is_setup = False
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class UnixTimeFormatter(logging.Formatter):
    """Formatter that renders ``%(asctime)s`` as Unix epoch seconds with one decimal place."""

    def formatTime(self, record, datefmt=None):
        # `datefmt` is accepted for API compatibility with logging.Formatter but ignored.
        return format(record.created, '.1f')
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def setup_logging():
    """Configure the root logger once: DEBUG level, one stderr handler with Unix-time stamps.

    Any handlers already attached to the root logger are replaced. Calls after the
    first are no-ops.
    """
    global _is_setup
    if not _is_setup:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(UnixTimeFormatter(fmt='%(asctime)s [%(levelname)s] %(name)s: %(message)s'))
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)
        root_logger.handlers = [stream_handler]  # Replace any existing handlers
        _is_setup = True
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Configure logging as a side effect of importing this module.
setup_logging()
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from strawberry import Info
|
|
7
|
+
|
|
8
|
+
from filament.db.models import TaskRun as TaskRunModel
|
|
9
|
+
from filament.redis.client import r
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
async def get_logs(
    task_run: TaskRunModel, with_children: bool = True, max_depth: int = 3, max_num_children: int = 100
) -> list[dict]:
    """Collect the Redis-stored log records of ``task_run``, oldest ``timestamp`` first.

    Records are read from the Redis list ``filament_log:<func_address>:<task_uuid>`` and
    decoded as JSON. When ``with_children`` is true, the logs of up to
    ``max_num_children`` children per run (lowest ids first) are included, recursing at
    most ``max_depth`` levels.
    """
    collected: list[dict] = []
    await _collect_logs(task_run, with_children, max_depth, max_num_children, collected)
    # Sort once at the top level instead of re-sorting at every recursion level.
    return sorted(collected, key=lambda log: log['timestamp'])


async def _collect_logs(
    task_run: TaskRunModel, with_children: bool, max_depth: int, max_num_children: int, out: list[dict]
) -> None:
    """Append the decoded log records of ``task_run`` (and, recursively, its children) to ``out``."""
    task_type = await task_run.awaitable_attrs.task_type
    redis_key = f'filament_log:{task_type.func_address}:{task_run.task_uuid}'
    out.extend(json.loads(entry) for entry in await r.lrange(redis_key, 0, -1))
    if with_children and max_depth > 0:
        # Order children by id before truncating (as deep_get_task_run_dict does), so the
        # selected subset does not depend on relationship load order.
        children = sorted(await task_run.awaitable_attrs.child_tasks, key=lambda child: child.id)
        for child in children[:max_num_children]:
            await _collect_logs(child, with_children, max_depth - 1, max_num_children, out)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@strawberry.type
class TaskRun:
    """GraphQL view of a task run.

    The resolvers return ``TaskRunModel`` ORM rows, so ``self`` in the field methods
    below is presumably such a row; related objects are loaded lazily through
    SQLAlchemy's ``awaitable_attrs``.
    """

    id: int
    task_uuid: str
    name: str | None
    created_at: datetime
    state: str
    state_since: datetime
    heartbeat: datetime
    run_count: int
    parent_task_uuid: str | None
    parameters_json: str | None
    result_json: str | None

    @strawberry.field
    async def task_type(self) -> 'TaskType':
        # The task type this run belongs to.
        related_type = await self.awaitable_attrs.task_type
        return related_type

    @strawberry.field
    async def state_transitions(self) -> list['TaskRunStateTransition']:
        # Recorded state transitions, in relationship load order.
        transitions = await self.awaitable_attrs.state_transitions
        return transitions

    @strawberry.field
    async def child_tasks(self) -> list['TaskRun']:
        # Direct children of this run.
        children = await self.awaitable_attrs.child_tasks
        return children

    @strawberry.field
    async def logs(
        self, with_children: bool = True, max_depth: int = 3, max_num_children: int = 100
    ) -> list['TaskRunLog']:
        # Redis-backed log records (optionally including descendants), oldest first.
        raw_entries = await get_logs(self, with_children, max_depth, max_num_children)
        return [TaskRunLog(**entry) for entry in raw_entries]

    @strawberry.field
    async def task_runs_stack(self) -> list['TaskRun']:
        # Ancestor chain ordered from the root run down to (and including) this run.
        chain = [self]
        node = self
        while parent := await node.awaitable_attrs.parent_task:
            chain.append(parent)
            node = parent
        chain.reverse()
        return chain
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@strawberry.type
class TaskRunStateTransition:
    """GraphQL view of one recorded state change of a task run."""

    id: int
    # uuid of the task run this transition belongs to (presumably matches TaskRun.task_uuid)
    task_uuid: str
    # state before and after the transition
    from_state: str
    to_state: str
    # timestamp of the transition; presumably when `to_state` was entered (TODO confirm)
    state_since: datetime

    @strawberry.field
    async def task_run(self) -> TaskRun:
        # `self` is the ORM row handed over by the parent resolver; lazily load its related task run.
        return await self.awaitable_attrs.task_run
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@strawberry.type
class TaskRunLog:
    """One log record of a task run, built from a JSON entry stored in Redis (see ``get_logs``)."""

    # Record time; get_logs sorts records on this value. Presumably Unix epoch seconds (TODO confirm).
    timestamp: float
    level: str
    # logger name
    name: str
    message: str
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@strawberry.type
class TaskType:
    """GraphQL view of a task type; ``func_address`` is the key run_task looks up in the task registry."""

    id: int
    func_address: str
    name: str | None
    parameters_spec: str | None
    result_spec: str | None

    @strawberry.field
    async def task_runs(self, info: Info) -> list[TaskRun]:
        # Up to 99 runs of this task type, newest created_at first.
        db_session = info.context['session']
        newest_first = (
            select(TaskRunModel)
            .where(TaskRunModel.task_type_id == self.id)
            .order_by(TaskRunModel.created_at.desc())
        )
        result = await db_session.execute(newest_first.limit(99))
        return result.scalars().all()

    @strawberry.field
    async def latest_task_run(self, info: Info) -> TaskRun | None:
        # The run with the most recent state_since, or None when this type has no runs.
        db_session = info.context['session']
        most_recent_change_first = (
            select(TaskRunModel)
            .where(TaskRunModel.task_type_id == self.id)
            .order_by(TaskRunModel.state_since.desc())
        )
        result = await db_session.execute(most_recent_change_first.limit(1))
        return result.scalars().one_or_none()
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
|
+
|
|
5
|
+
import anyio
|
|
6
|
+
from plasma import Plasma
|
|
7
|
+
from sqlalchemy import select
|
|
8
|
+
|
|
9
|
+
from filament.db.models import TaskRun, TaskState
|
|
10
|
+
from filament.db.session import async_session_scope
|
|
11
|
+
from filament.state.task_run_state import cancel_task_run, delete_task_run
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def setup_logging():
    """Attach a stdout handler to the root logger and set the root level to INFO.

    Each call adds another handler; the module calls this once at import time.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    root_logger = logging.getLogger()
    root_logger.addHandler(stdout_handler)
    root_logger.setLevel(logging.INFO)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Configure stdout logging as soon as the CLI module is imported.
setup_logging()

logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def main(
    stonith_max_heartbeat_seconds: int = 60 * 60, delete_old_task_runs_days: int = 30, skip_root_tasks: bool = True
):
    """Run the cleanup passes in order.

    1. ``stonith``: cancel non-terminal runs whose heartbeat is older than
       ``stonith_max_heartbeat_seconds`` (default: one hour).
    2. ``delete_old_task_runs``: delete runs created more than
       ``delete_old_task_runs_days`` days ago; root runs are kept when
       ``skip_root_tasks`` is true.
    """
    await stonith(max_heartbeat_seconds=stonith_max_heartbeat_seconds)
    await delete_old_task_runs(
        days=delete_old_task_runs_days,
        skip_root_tasks=skip_root_tasks,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
async def stonith(max_heartbeat_seconds: int, batch_size: int = 100):
    """Cancel non-terminal task runs whose heartbeat is older than ``max_heartbeat_seconds``.

    Works in batches of at most ``batch_size`` rows. Each batch runs in its own
    session scope with an explicit commit and is followed by a one-second sleep.
    The loop stops after the first batch that returns fewer than ``batch_size`` rows.

    NOTE(review): the listing this file was recovered from lost its indentation. The
    nesting below (flush/commit after the loop, sleep/break outside the session scope)
    is a reconstruction; confirm it against the original file.
    """
    while True:
        async with async_session_scope() as session:
            now = datetime.now(timezone.utc)
            heartbeat_threshold = now - timedelta(seconds=max_heartbeat_seconds)
            # Non-terminal runs with a stale heartbeat, most recent heartbeat first.
            task_runs_statement = (
                select(TaskRun)
                .where(~TaskRun.state.in_(TaskState.TERMINAL))
                .where(TaskRun.heartbeat < heartbeat_threshold)
                .order_by(TaskRun.heartbeat.desc())
                .limit(batch_size)
            )
            task_runs = (await session.execute(task_runs_statement)).scalars().all()
            logger.info(f'Found {len(task_runs)} task runs to STONITH')
            # Heartbeat age of the first run in the batch (the least stale one, given the desc
            # ordering); used only in the summary log line.
            any_age = None
            for task_run in task_runs:
                heartbeat_age = now - task_run.heartbeat
                logger.debug(f'Cancelling task run {task_run.id} heartbeat_age={heartbeat_age}')
                if any_age is None:
                    any_age = heartbeat_age
                await cancel_task_run(session, task_run)
            await session.flush()
            await session.commit()
            logger.info(f'STONITHed {len(task_runs)} task runs, any heartbeat_age={any_age}')
        await anyio.sleep(1)
        # Termination relies on cancelled runs dropping out of the query (presumably
        # cancel_task_run moves them into a TaskState.TERMINAL state).
        if len(task_runs) < batch_size:
            break
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def delete_old_task_runs(days: int = 30, batch_size: int = 100, skip_root_tasks: bool = True):
    """Delete task runs created more than ``days`` days ago, in batches of ``batch_size``.

    With ``skip_root_tasks`` (the default), only runs that have a parent
    (``parent_task_uuid`` set) are selected, so root runs are kept. Each batch runs in
    its own session scope with an explicit commit and is followed by a one-second
    sleep. The loop stops after the first batch smaller than ``batch_size``.

    NOTE(review): the listing this file was recovered from lost its indentation. The
    nesting below (sleep/break outside the session scope) is a reconstruction; confirm
    it against the original file.
    """
    while True:
        async with async_session_scope() as session:
            statement = select(TaskRun).where(TaskRun.created_at < datetime.now(timezone.utc) - timedelta(days=days))
            if skip_root_tasks:
                statement = statement.where(TaskRun.parent_task_uuid.is_not(None))
            # No ORDER BY: which old runs land in each batch is up to the database.
            statement = statement.limit(batch_size)
            task_runs = (await session.execute(statement)).scalars().all()
            logger.info(f'Found {len(task_runs)} task runs to delete')
            # Despite the name, this holds the created_at timestamp of the first run in the batch.
            any_age = None
            for task_run in task_runs:
                await delete_task_run(session, task_run)
                if any_age is None:
                    any_age = task_run.created_at
            await session.commit()
            logger.info(f'Deleted {len(task_runs)} task runs, any_age={any_age}')
        await anyio.sleep(1)
        if len(task_runs) < batch_size:
            break
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
if __name__ == '__main__':
    # Plasma(main) presumably builds a command-line interface from main's signature and runs it (TODO confirm).
    # NOTE(review): `plasma` is imported by this module but is not listed in the pyproject.toml
    # dependencies or in PKG-INFO Requires-Dist.
    Plasma(main)
|