fractal-server 1.3.0a2__py3-none-any.whl → 1.3.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fractal_server/__init__.py +1 -1
- fractal_server/app/api/v1/dataset.py +21 -0
- fractal_server/app/api/v1/task.py +78 -16
- fractal_server/app/api/v1/workflow.py +36 -25
- fractal_server/app/models/job.py +3 -2
- fractal_server/app/models/project.py +4 -3
- fractal_server/app/models/security.py +2 -0
- fractal_server/app/models/state.py +2 -1
- fractal_server/app/models/task.py +20 -9
- fractal_server/app/models/workflow.py +34 -29
- fractal_server/app/runner/_common.py +13 -12
- fractal_server/app/runner/_slurm/executor.py +2 -1
- fractal_server/common/requirements.txt +1 -1
- fractal_server/common/schemas/__init__.py +2 -0
- fractal_server/common/schemas/applyworkflow.py +6 -7
- fractal_server/common/schemas/manifest.py +32 -15
- fractal_server/common/schemas/project.py +8 -10
- fractal_server/common/schemas/state.py +3 -4
- fractal_server/common/schemas/task.py +28 -97
- fractal_server/common/schemas/task_collection.py +101 -0
- fractal_server/common/schemas/user.py +5 -0
- fractal_server/common/schemas/workflow.py +9 -11
- fractal_server/common/tests/test_manifest.py +36 -4
- fractal_server/common/tests/test_task.py +16 -0
- fractal_server/common/tests/test_task_collection.py +24 -0
- fractal_server/common/tests/test_user.py +12 -0
- fractal_server/main.py +3 -0
- fractal_server/migrations/versions/4c308bcaea2b_add_task_args_schema_and_task_args_.py +38 -0
- fractal_server/migrations/versions/{e8f4051440be_new_initial_schema.py → 50a13d6138fd_initial_schema.py} +18 -10
- fractal_server/migrations/versions/{fda995215ae9_drop_applyworkflow_overwrite_input.py → f384e1c0cf5d_drop_task_default_args_columns.py} +9 -10
- fractal_server/tasks/collection.py +180 -115
- {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/METADATA +2 -1
- {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/RECORD +36 -34
- {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/WHEEL +1 -1
- fractal_server/migrations/versions/bb1cca2acc40_add_applyworkflow_end_timestamp.py +0 -31
- {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/LICENSE +0 -0
- {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/entry_points.txt +0 -0
fractal_server/__init__.py
CHANGED
@@ -1 +1 @@
-__VERSION__ = "1.3.0a2"
+__VERSION__ = "1.3.0a3"
fractal_server/app/api/v1/dataset.py
CHANGED
@@ -5,10 +5,12 @@ from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Response
 from fastapi import status
+from sqlmodel import or_
 from sqlmodel import select
 
 from ...db import AsyncSession
 from ...db import get_db
+from ...models import ApplyWorkflow
 from ...models import Dataset
 from ...models import DatasetCreate
 from ...models import DatasetRead
@@ -124,6 +126,25 @@ async def delete_dataset(
         db=db,
     )
     dataset = output["dataset"]
+
+    # Check that no ApplyWorkflow is in relationship with the current Dataset
+    stm = select(ApplyWorkflow).filter(
+        or_(
+            ApplyWorkflow.input_dataset_id == dataset_id,
+            ApplyWorkflow.output_dataset_id == dataset_id,
+        )
+    )
+    res = await db.execute(stm)
+    job = res.scalars().first()
+    if job:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                f"Cannot remove dataset {dataset_id}: "
+                f"it's still linked to job {job.id}."
+            ),
+        )
+
     await db.delete(dataset)
     await db.commit()
     await db.close()
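Note: for reference, a minimal standalone sketch of the guard added to `delete_dataset` above, assuming an async SQLModel session `db` and an integer `dataset_id` (the helper name is hypothetical and only for illustration, not part of the package):

```python
from sqlmodel import or_, select

from fractal_server.app.models import ApplyWorkflow


async def assert_dataset_is_not_linked_to_jobs(db, dataset_id: int) -> None:
    # Mirror the new delete_dataset check: refuse deletion while any job
    # still references the dataset as input or output.
    stm = select(ApplyWorkflow).filter(
        or_(
            ApplyWorkflow.input_dataset_id == dataset_id,
            ApplyWorkflow.output_dataset_id == dataset_id,
        )
    )
    res = await db.execute(stm)
    job = res.scalars().first()
    if job is not None:
        raise ValueError(
            f"Cannot remove dataset {dataset_id}: it's still linked to job {job.id}."
        )
```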
fractal_server/app/api/v1/task.py
CHANGED
@@ -41,12 +41,13 @@ from ...db import get_db
 from ...db import get_sync_db
 from ...models import State
 from ...models import Task
-from ...security import current_active_superuser
 from ...security import current_active_user
 from ...security import User
 
 router = APIRouter()
 
+logger = set_logger(__name__)
+
 
 async def _background_collect_pip(
     state_id: int,
@@ -178,7 +179,6 @@ async def collect_tasks_pip(
     response: Response,
     user: User = Depends(current_active_user),
     db: AsyncSession = Depends(get_db),
-    public: bool = True,
 ) -> StateRead:  # State[TaskCollectStatus]
     """
     Task collection endpoint
@@ -192,7 +192,7 @@ async def collect_tasks_pip(
     # Validate payload as _TaskCollectPip, which has more strict checks than
     # TaskCollectPip
     try:
-        task_pkg = _TaskCollectPip(**task_collect.dict())
+        task_pkg = _TaskCollectPip(**task_collect.dict(exclude_unset=True))
     except ValidationError as e:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -201,6 +201,7 @@ async def collect_tasks_pip(
 
     with TemporaryDirectory() as tmpdir:
         try:
+            # Copy or download the package wheel file to tmpdir
            if task_pkg.is_local_package:
                shell_copy(task_pkg.package_path.as_posix(), tmpdir)
                pkg_path = Path(tmpdir) / task_pkg.package_path.name
@@ -208,10 +209,12 @@ async def collect_tasks_pip(
                 pkg_path = await download_package(
                     task_pkg=task_pkg, dest=tmpdir
                 )
-
-
-
-            task_pkg.
+            # Read package info from wheel file, and override the ones coming
+            # from the request body
+            pkg_info = inspect_package(pkg_path)
+            task_pkg.package_name = pkg_info["pkg_name"]
+            task_pkg.package_version = pkg_info["pkg_version"]
+            task_pkg.package_manifest = pkg_info["pkg_manifest"]
             task_pkg.check()
         except Exception as e:
             raise HTTPException(
@@ -220,14 +223,28 @@ async def collect_tasks_pip(
             )
 
     try:
-
-        venv_path = create_package_dir_pip(task_pkg=task_pkg, user=pkg_user)
+        venv_path = create_package_dir_pip(task_pkg=task_pkg)
     except FileExistsError:
-        venv_path = create_package_dir_pip(
-            task_pkg=task_pkg, user=pkg_user, create=False
-        )
+        venv_path = create_package_dir_pip(task_pkg=task_pkg, create=False)
         try:
             task_collect_status = get_collection_data(venv_path)
+            for task in task_collect_status.task_list:
+                db_task = await db.get(Task, task.id)
+                if (
+                    (not db_task)
+                    or db_task.source != task.source
+                    or db_task.name != task.name
+                ):
+                    await db.close()
+                    raise HTTPException(
+                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                        detail=(
+                            "Cannot collect package. Folder already exists, "
+                            f"but task {task.id} does not exists or it does "
+                            f"not have the expected source ({task.source}) or "
+                            f"name ({task.name})."
+                        ),
+                    )
         except FileNotFoundError as e:
             await db.close()
             raise HTTPException(
@@ -344,19 +361,42 @@ async def get_task(
 async def patch_task(
     task_id: int,
     task_update: TaskUpdate,
-    user: User = Depends(
+    user: User = Depends(current_active_user),
     db: AsyncSession = Depends(get_db),
 ) -> Optional[TaskRead]:
     """
-    Edit a specific task
+    Edit a specific task (restricted to superusers and task owner)
     """
+
     if task_update.source:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
             detail="patch_task endpoint cannot set `source`",
         )
 
+    # Retrieve task from database
     db_task = await db.get(Task, task_id)
+
+    # This check constitutes a preliminary version of access control:
+    # if the current user is not a superuser and differs from the task owner
+    # (including when `owner is None`), we raise an 403 HTTP Exception.
+    if not user.is_superuser:
+        if db_task.owner is None:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=("Only a superuser can edit a task with `owner=None`."),
+            )
+        else:
+            owner = user.username or user.slurm_user
+            if owner != db_task.owner:
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail=(
+                        f"Current user ({owner}) cannot modify task "
+                        f"({task_id}) with different owner ({db_task.owner})."
+                    ),
+                )
+
     update = task_update.dict(exclude_unset=True)
     for key, value in update.items():
         if isinstance(value, str):
@@ -386,14 +426,36 @@ async def create_task(
     """
     Create a new task
     """
+    # Set task.owner attribute
+    if user.username:
+        owner = user.username
+    elif user.slurm_user:
+        owner = user.slurm_user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                "Cannot add a new task because current user does not "
+                "have `username` or `slurm_user` attributes."
+            ),
+        )
+
+    # Prepend owner to task.source
+    task.source = f"{owner}:{task.source}"
+
+    # Verify that source is not already in use (note: this check is only useful
+    # to provide a user-friendly error message, but `task.source` uniqueness is
+    # already guaranteed by a constraint in the table definition).
     stm = select(Task).where(Task.source == task.source)
     res = await db.execute(stm)
     if res.scalars().all():
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=f
+            detail=f'Task source "{task.source}" already in use',
         )
-
+
+    # Add task
+    db_task = Task(**task.dict(), owner=owner)
     db.add(db_task)
     await db.commit()
     await db.refresh(db_task)
fractal_server/app/api/v1/workflow.py
CHANGED
@@ -20,8 +20,11 @@ from fastapi import Response
 from fastapi import status
 from sqlmodel import select
 
+from ....logger import close_logger
+from ....logger import set_logger
 from ...db import AsyncSession
 from ...db import get_db
+from ...models import ApplyWorkflow
 from ...models import Task
 from ...models import Workflow
 from ...models import WorkflowCreate
@@ -181,6 +184,19 @@ async def delete_workflow(
         project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
     )
 
+    # Check that no ApplyWorkflow is in relationship with the current Workflow
+    stm = select(ApplyWorkflow).where(ApplyWorkflow.workflow_id == workflow_id)
+    res = await db.execute(stm)
+    job = res.scalars().first()
+    if job:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                f"Cannot remove workflow {workflow_id}: "
+                f"it's still linked to job {job.id}."
+            ),
+        )
+
     await db.delete(workflow)
     await db.commit()
 
@@ -203,6 +219,18 @@ async def export_worfklow(
     workflow = await _get_workflow_check_owner(
         project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
     )
+    # Emit a warning when exporting a workflow with custom tasks
+    logger = set_logger(None)
+    for wftask in workflow.task_list:
+        if wftask.task.owner is not None:
+            logger.warning(
+                f"Custom tasks (like the one with id={wftask.task.id} and "
+                f'source="{wftask.task.source}") are not meant to be '
+                "portable; re-importing this workflow may not work as "
+                "expected."
+            )
+    close_logger(logger)
+
     await db.close()
     return workflow
 
@@ -237,37 +265,21 @@ async def import_workflow(
     )
 
     # Check that all required tasks are available
-    # NOTE: by now we go through the pair (source, name), but later on we may
-    # combine them into source -- see issue #293.
     tasks = [wf_task.task for wf_task in workflow.task_list]
-
+    source_to_id = {}
     for task in tasks:
         source = task.source
-
-        if not (source, name) in sourcename_to_id.keys():
+        if source not in source_to_id.keys():
             stm = select(Task).where(Task.source == source)
             tasks_by_source = (await db.execute(stm)).scalars().all()
-            if
+            if len(tasks_by_source) != 1:
                 raise HTTPException(
                     status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                    detail=(
-
-
-            stm = (
-                select(Task)
-                .where(Task.source == source)
-                .where(Task.name == name)
+                    detail=(
+                        f"Found {len(tasks_by_source)} tasks with {source=}."
+                    ),
                 )
-
-            if len(current_task) != 1:
-                raise HTTPException(
-                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                    detail=(
-                        f"Found {len(current_task)} tasks with "
-                        f"{name =} and {source=}."
-                    ),
-                )
-            sourcename_to_id[(source, name)] = current_task[0].id
+            source_to_id[source] = tasks_by_source[0].id
 
     # Create new Workflow (with empty task_list)
     db_workflow = Workflow(
@@ -283,8 +295,7 @@ async def import_workflow(
     for _, wf_task in enumerate(workflow.task_list):
         # Identify task_id
         source = wf_task.task.source
-
-        task_id = sourcename_to_id[(source, name)]
+        task_id = source_to_id[source]
         # Prepare new_wf_task
         new_wf_task = WorkflowTaskCreate(
             **wf_task.dict(exclude_none=True),
fractal_server/app/models/job.py
CHANGED
@@ -6,8 +6,9 @@ from sqlalchemy import Column
 from sqlalchemy.types import DateTime
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel
 
-from ...common.schemas import
+from ...common.schemas import _ApplyWorkflowBase
 from ...utils import get_timestamp
 from .project import Dataset
 from .workflow import Workflow
@@ -38,7 +39,7 @@ class JobStatusType(str, Enum):
     FAILED = "failed"
 
 
-class ApplyWorkflow(
+class ApplyWorkflow(_ApplyWorkflowBase, SQLModel, table=True):
     """
     Represent a workflow run
 
fractal_server/app/models/project.py
CHANGED
@@ -5,6 +5,7 @@ from sqlalchemy import Column
 from sqlalchemy.types import JSON
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel
 
 from ...common.schemas.project import _DatasetBase
 from ...common.schemas.project import _ProjectBase
@@ -14,7 +15,7 @@ from .security import UserOAuth as User
 from .workflow import Workflow
 
 
-class Dataset(_DatasetBase, table=True):
+class Dataset(_DatasetBase, SQLModel, table=True):
     """
     Represent a dataset
 
@@ -49,7 +50,7 @@ class Dataset(_DatasetBase, table=True):
         return [r.path for r in self.resource_list]
 
 
-class Project(_ProjectBase, table=True):
+class Project(_ProjectBase, SQLModel, table=True):
     id: Optional[int] = Field(default=None, primary_key=True)
 
     user_list: list[User] = Relationship(
@@ -82,6 +83,6 @@ class Project(_ProjectBase, table=True):
     )
 
 
-class Resource(_ResourceBase, table=True):
+class Resource(_ResourceBase, SQLModel, table=True):
     id: Optional[int] = Field(default=None, primary_key=True)
     dataset_id: int = Field(foreign_key="dataset.id")
fractal_server/app/models/security.py
CHANGED
@@ -62,6 +62,8 @@ class UserOAuth(SQLModel, table=True):
 
     slurm_user: Optional[str]
     cache_dir: Optional[str]
+    username: Optional[str]
+
     oauth_accounts: list["OAuthAccount"] = Relationship(
         back_populates="user",
         sa_relationship_kwargs={"lazy": "selectin", "cascade": "all, delete"},
fractal_server/app/models/state.py
CHANGED
@@ -6,12 +6,13 @@ from sqlalchemy import Column
 from sqlalchemy.types import DateTime
 from sqlalchemy.types import JSON
 from sqlmodel import Field
+from sqlmodel import SQLModel
 
 from ...common.schemas import _StateBase
 from ...utils import get_timestamp
 
 
-class State(_StateBase, table=True):
+class State(_StateBase, SQLModel, table=True):
     """
     Store arbitrary data in the database
 
fractal_server/app/models/task.py
CHANGED
@@ -4,33 +4,44 @@ from typing import Optional
 from sqlalchemy import Column
 from sqlalchemy.types import JSON
 from sqlmodel import Field
+from sqlmodel import SQLModel
 
 from ...common.schemas.task import _TaskBase
 
 
-class Task(_TaskBase, table=True):
+class Task(_TaskBase, SQLModel, table=True):
     """
     Task model
 
     Attributes:
         id: Primary key
-        command:
-        input_type:
-        output_type:
-
-
+        command: Executable command
+        input_type: Expected type of input `Dataset`
+        output_type: Expected type of output `Dataset`
+        meta:
+            Additional metadata related to execution (e.g. computational
+            resources)
         source: inherited from `_TaskBase`
         name: inherited from `_TaskBase`
+        args_schema: JSON schema of task arguments
+        args_schema_version:
+            label pointing at how the JSON schema of task arguments was
+            generated
     """
 
     id: Optional[int] = Field(default=None, primary_key=True)
+    name: str
     command: str
+    source: str = Field(unique=True)
     input_type: str
     output_type: str
-    default_args: Optional[dict[str, Any]] = Field(
-        sa_column=Column(JSON), default={}
-    )
     meta: Optional[dict[str, Any]] = Field(sa_column=Column(JSON), default={})
+    owner: Optional[str] = None
+    version: Optional[str] = None
+    args_schema: Optional[dict[str, Any]] = Field(
+        sa_column=Column(JSON), default=None
+    )
+    args_schema_version: Optional[str]
 
     @property
     def parallelization_level(self) -> Optional[str]:
fractal_server/app/models/workflow.py
CHANGED
@@ -1,3 +1,5 @@
+import json
+import logging
 from typing import Any
 from typing import Optional
 from typing import Union
@@ -8,6 +10,7 @@ from sqlalchemy.ext.orderinglist import ordering_list
 from sqlalchemy.types import JSON
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel
 
 from ...common.schemas.workflow import _WorkflowBase
 from ...common.schemas.workflow import _WorkflowTaskBase
@@ -15,7 +18,7 @@ from ..db import AsyncSession
 from .task import Task
 
 
-class WorkflowTask(_WorkflowTaskBase, table=True):
+class WorkflowTask(_WorkflowTaskBase, SQLModel, table=True):
     """
     A Task as part of a Workflow
 
@@ -36,8 +39,7 @@ class WorkflowTask(_WorkflowTaskBase, table=True):
         meta:
             Additional parameters useful for execution
         args:
-
-            `WorkflowTask.task.args`
+            Task arguments
         task:
             `Task` object associated with the current `WorkflowTask`
 
@@ -81,15 +83,6 @@ class WorkflowTask(_WorkflowTaskBase, table=True):
         )
         return value
 
-    @property
-    def arguments(self):
-        """
-        Override default arguments
-        """
-        out = self.task.default_args.copy()
-        out.update(self.args or {})
-        return out
-
     @property
     def is_parallel(self) -> bool:
         return self.task.is_parallel
@@ -108,23 +101,8 @@ class WorkflowTask(_WorkflowTaskBase, table=True):
         res.update(self.meta or {})
         return res
 
-    def assemble_args(self, extra: dict[str, Any] = None) -> dict:
-        """
-        Merge of `extra` arguments and `self.arguments`.
 
-
-            full_args:
-                A dictionary consisting of the merge of `extra` and
-                `self.arguments`.
-        """
-        full_args = {}
-        if extra:
-            full_args.update(extra)
-        full_args.update(self.arguments)
-        return full_args
-
-
-class Workflow(_WorkflowBase, table=True):
+class Workflow(_WorkflowBase, SQLModel, table=True):
     """
     Workflow
 
@@ -172,7 +150,34 @@ class Workflow(_WorkflowBase, table=True):
         """
         if order is None:
             order = len(self.task_list)
-
+
+        # Get task from db, extract the JSON Schema for its arguments (if any),
+        # read default values and set them in default_args
+        db_task = await db.get(Task, task_id)
+        default_args = {}
+        if db_task.args_schema is not None:
+            try:
+                properties = db_task.args_schema["properties"]
+                for prop_name, prop_schema in properties.items():
+                    default_value = prop_schema.get("default", None)
+                    if default_value:
+                        default_args[prop_name] = default_value
+            except KeyError as e:
+                logging.warning(
+                    "Cannot set default_args from args_schema="
+                    f"{json.dumps(db_task.args_schema)}\n"
+                    f"Original KeyError: {str(e)}"
+                )
+        # Override default_args with args
+        actual_args = default_args.copy()
+        if args is not None:
+            for k, v in args.items():
+                actual_args[k] = v
+        if not actual_args:
+            actual_args = None
+
+        # Create DB entry
+        wf_task = WorkflowTask(task_id=task_id, args=actual_args, meta=meta)
         db.add(wf_task)
         self.task_list.insert(order, wf_task)
         self.task_list.reorder()  # type: ignore
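Note: a standalone sketch of the default-extraction step that `Workflow.insert_task` now performs, assuming a standard JSON Schema `properties` mapping as input (illustrative helper, not part of the package):

```python
import json
import logging


def default_args_from_schema(args_schema: dict) -> dict:
    # Collect per-property defaults from a JSON Schema, skipping falsy defaults
    # (same behaviour as the insert_task changes above).
    default_args = {}
    try:
        for prop_name, prop_schema in args_schema["properties"].items():
            default_value = prop_schema.get("default", None)
            if default_value:
                default_args[prop_name] = default_value
    except KeyError as e:
        logging.warning(
            f"Cannot set default_args from args_schema={json.dumps(args_schema)}\n"
            f"Original KeyError: {str(e)}"
        )
    return default_args
```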
fractal_server/app/runner/_common.py
CHANGED
@@ -187,10 +187,10 @@ def call_single_task(
     Call a single task
 
     This assembles the runner arguments (input_paths, output_path, ...) and
-
-    message or index in the dummy task), writes them to file, call the
-    executable command passing the arguments file as an input and
-    the output.
+    wftask arguments (i.e., arguments that are specific to the WorkflowTask,
+    such as message or index in the dummy task), writes them to file, call the
+    task executable command passing the arguments file as an input and
+    assembles the output.
 
     **Note**: This function is directly submitted to a
     `concurrent.futures`-compatible executor, as in
@@ -205,7 +205,7 @@ def call_single_task(
     Args:
         wftask:
             The workflow task to be called. This includes task specific
-            arguments via the
+            arguments via the wftask.args attribute.
         task_pars:
             The parameters required to run the task which are not specific to
             the task, e.g., I/O paths.
@@ -238,11 +238,12 @@ def call_single_task(
         task_order=wftask.order,
     )
 
-    #
-
-
-
-
+    # write args file (by assembling task_pars and wftask.args)
+    write_args_file(
+        task_pars.dict(),
+        wftask.args or {},
+        path=task_files.args,
+    )
 
     # assemble full command
     cmd = (
@@ -341,10 +342,10 @@ def call_single_parallel_task(
         component=component,
     )
 
-    #
+    # write args file (by assembling task_pars, wftask.args and component)
     write_args_file(
         task_pars.dict(),
-        wftask.
+        wftask.args or {},
         dict(component=component),
         path=task_files.args,
    )
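Note: `write_args_file` is an existing helper in the runner module whose implementation is not shown in this diff. As a rough, assumption-laden sketch, it presumably merges the given dictionaries in order and dumps the result as a JSON args file at `path`, along these lines:

```python
import json
from pathlib import Path
from typing import Any


def write_args_file(*dictionaries: dict[str, Any], path: Path) -> None:
    # Sketch only: merge left-to-right (later dictionaries override earlier
    # keys) and write the result as a JSON args file.
    merged: dict[str, Any] = {}
    for d in dictionaries:
        merged.update(d)
    with path.open("w") as f:
        json.dump(merged, f, indent=2)
```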
fractal_server/app/runner/_slurm/executor.py
CHANGED
@@ -743,7 +743,8 @@ class FractalSlurmExecutor(SlurmExecutor):
                 "the job started running, the SLURM out/err files "
                 "will be empty.\n"
                 "2. Some error occurred upon writing the file to disk "
-                "(e.g.
+                "(e.g. because there is not enough space on disk, or "
+                "due to an overloaded NFS filesystem). "
                 "Note that the server configuration has "
                 "FRACTAL_SLURM_OUTPUT_FILE_GRACE_TIME="
                 f"{settings.FRACTAL_SLURM_OUTPUT_FILE_GRACE_TIME} "
fractal_server/common/schemas/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from .manifest import *  # noqa: F403
 from .project import *  # noqa: F403
 from .state import *  # noqa: F403
 from .task import *  # noqa: F403
+from .task_collection import *  # noqa: F403
 from .user import *  # noqa: F403
 from .workflow import *  # noqa: F403
 
@@ -10,6 +11,7 @@ from .workflow import *  # noqa: F403
 __all__ = (
     project.__all__  # noqa: F405
     + task.__all__  # noqa: F405
+    + task_collection.__all__  # noqa: F405
     + workflow.__all__  # noqa: F405
     + applyworkflow.__all__  # noqa: F405
     + manifest.__all__  # noqa: F405