fractal-server 1.3.0a2__py3-none-any.whl → 1.3.0a3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only, and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/api/v1/dataset.py +21 -0
  3. fractal_server/app/api/v1/task.py +78 -16
  4. fractal_server/app/api/v1/workflow.py +36 -25
  5. fractal_server/app/models/job.py +3 -2
  6. fractal_server/app/models/project.py +4 -3
  7. fractal_server/app/models/security.py +2 -0
  8. fractal_server/app/models/state.py +2 -1
  9. fractal_server/app/models/task.py +20 -9
  10. fractal_server/app/models/workflow.py +34 -29
  11. fractal_server/app/runner/_common.py +13 -12
  12. fractal_server/app/runner/_slurm/executor.py +2 -1
  13. fractal_server/common/requirements.txt +1 -1
  14. fractal_server/common/schemas/__init__.py +2 -0
  15. fractal_server/common/schemas/applyworkflow.py +6 -7
  16. fractal_server/common/schemas/manifest.py +32 -15
  17. fractal_server/common/schemas/project.py +8 -10
  18. fractal_server/common/schemas/state.py +3 -4
  19. fractal_server/common/schemas/task.py +28 -97
  20. fractal_server/common/schemas/task_collection.py +101 -0
  21. fractal_server/common/schemas/user.py +5 -0
  22. fractal_server/common/schemas/workflow.py +9 -11
  23. fractal_server/common/tests/test_manifest.py +36 -4
  24. fractal_server/common/tests/test_task.py +16 -0
  25. fractal_server/common/tests/test_task_collection.py +24 -0
  26. fractal_server/common/tests/test_user.py +12 -0
  27. fractal_server/main.py +3 -0
  28. fractal_server/migrations/versions/4c308bcaea2b_add_task_args_schema_and_task_args_.py +38 -0
  29. fractal_server/migrations/versions/{e8f4051440be_new_initial_schema.py → 50a13d6138fd_initial_schema.py} +18 -10
  30. fractal_server/migrations/versions/{fda995215ae9_drop_applyworkflow_overwrite_input.py → f384e1c0cf5d_drop_task_default_args_columns.py} +9 -10
  31. fractal_server/tasks/collection.py +180 -115
  32. {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/METADATA +2 -1
  33. {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/RECORD +36 -34
  34. {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/WHEEL +1 -1
  35. fractal_server/migrations/versions/bb1cca2acc40_add_applyworkflow_end_timestamp.py +0 -31
  36. {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/LICENSE +0 -0
  37. {fractal_server-1.3.0a2.dist-info → fractal_server-1.3.0a3.dist-info}/entry_points.txt +0 -0

fractal_server/__init__.py
@@ -1 +1 @@
-__VERSION__ = "1.3.0a2"
+__VERSION__ = "1.3.0a3"

fractal_server/app/api/v1/dataset.py
@@ -5,10 +5,12 @@ from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Response
 from fastapi import status
+from sqlmodel import or_
 from sqlmodel import select

 from ...db import AsyncSession
 from ...db import get_db
+from ...models import ApplyWorkflow
 from ...models import Dataset
 from ...models import DatasetCreate
 from ...models import DatasetRead
@@ -124,6 +126,25 @@ async def delete_dataset(
         db=db,
     )
     dataset = output["dataset"]
+
+    # Check that no ApplyWorkflow is linked to the current Dataset
+    stm = select(ApplyWorkflow).filter(
+        or_(
+            ApplyWorkflow.input_dataset_id == dataset_id,
+            ApplyWorkflow.output_dataset_id == dataset_id,
+        )
+    )
+    res = await db.execute(stm)
+    job = res.scalars().first()
+    if job:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                f"Cannot remove dataset {dataset_id}: "
+                f"it's still linked to job {job.id}."
+            ),
+        )
+
     await db.delete(dataset)
     await db.commit()
     await db.close()
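
The new guard relies on SQLModel's `or_` to match jobs that reference the dataset on either side. A minimal, self-contained sketch of the same pattern, with a simplified stand-in `Job` model in place of fractal-server's `ApplyWorkflow`:

    from typing import Optional

    from sqlmodel import Field, Session, SQLModel, create_engine, or_, select

    class Job(SQLModel, table=True):
        # Stand-in for ApplyWorkflow, keeping only the two foreign keys
        id: Optional[int] = Field(default=None, primary_key=True)
        input_dataset_id: int
        output_dataset_id: int

    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Job(input_dataset_id=1, output_dataset_id=2))
        session.commit()
        dataset_id = 2
        stm = select(Job).where(
            or_(
                Job.input_dataset_id == dataset_id,
                Job.output_dataset_id == dataset_id,
            )
        )
        job = session.exec(stm).first()
        assert job is not None  # dataset 2 is still referenced: deletion blocked
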

fractal_server/app/api/v1/task.py
@@ -41,12 +41,13 @@ from ...db import get_db
 from ...db import get_sync_db
 from ...models import State
 from ...models import Task
-from ...security import current_active_superuser
 from ...security import current_active_user
 from ...security import User

 router = APIRouter()

+logger = set_logger(__name__)
+

 async def _background_collect_pip(
     state_id: int,
@@ -178,7 +179,6 @@ async def collect_tasks_pip(
     response: Response,
     user: User = Depends(current_active_user),
     db: AsyncSession = Depends(get_db),
-    public: bool = True,
 ) -> StateRead:  # State[TaskCollectStatus]
     """
     Task collection endpoint
@@ -192,7 +192,7 @@ async def collect_tasks_pip(
     # Validate payload as _TaskCollectPip, which has more strict checks than
     # TaskCollectPip
     try:
-        task_pkg = _TaskCollectPip(**task_collect.dict())
+        task_pkg = _TaskCollectPip(**task_collect.dict(exclude_unset=True))
     except ValidationError as e:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
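
The switch to `dict(exclude_unset=True)` means that fields the client never sent no longer reach `_TaskCollectPip` as explicit `None` values. A small illustration with a simplified pydantic (v1-style) model; the field names here are hypothetical:

    from typing import Optional

    from pydantic import BaseModel

    class TaskCollectPip(BaseModel):
        package: str
        package_version: Optional[str] = None

    payload = TaskCollectPip(package="fractal-tasks-core")
    print(payload.dict())
    # {'package': 'fractal-tasks-core', 'package_version': None}
    print(payload.dict(exclude_unset=True))
    # {'package': 'fractal-tasks-core'}
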
@@ -201,6 +201,7 @@ async def collect_tasks_pip(

     with TemporaryDirectory() as tmpdir:
         try:
+            # Copy or download the package wheel file to tmpdir
             if task_pkg.is_local_package:
                 shell_copy(task_pkg.package_path.as_posix(), tmpdir)
                 pkg_path = Path(tmpdir) / task_pkg.package_path.name
@@ -208,10 +209,12 @@ async def collect_tasks_pip(
                 pkg_path = await download_package(
                     task_pkg=task_pkg, dest=tmpdir
                 )
-
-            version_manifest = inspect_package(pkg_path)
-
-            task_pkg.version = version_manifest["version"]
+            # Read package info from wheel file, and override the ones coming
+            # from the request body
+            pkg_info = inspect_package(pkg_path)
+            task_pkg.package_name = pkg_info["pkg_name"]
+            task_pkg.package_version = pkg_info["pkg_version"]
+            task_pkg.package_manifest = pkg_info["pkg_manifest"]
             task_pkg.check()
         except Exception as e:
             raise HTTPException(
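
`inspect_package` is fractal-server's own helper (see `fractal_server/tasks/collection.py` in the file list above), so the sketch below is only a plausible, hypothetical reading of what such an inspection involves: a wheel is a zip archive whose core metadata lives in `*.dist-info/METADATA`.

    import zipfile
    from email.parser import Parser

    def read_wheel_metadata(wheel_path: str) -> dict:
        # Locate and parse <pkg>.dist-info/METADATA inside the wheel archive
        with zipfile.ZipFile(wheel_path) as wheel:
            metadata_name = next(
                name for name in wheel.namelist()
                if name.endswith(".dist-info/METADATA")
            )
            metadata = Parser().parsestr(wheel.read(metadata_name).decode())
        return {"pkg_name": metadata["Name"], "pkg_version": metadata["Version"]}
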
@@ -220,14 +223,28 @@ async def collect_tasks_pip(
             )

     try:
-        pkg_user = None if public else user.slurm_user
-        venv_path = create_package_dir_pip(task_pkg=task_pkg, user=pkg_user)
+        venv_path = create_package_dir_pip(task_pkg=task_pkg)
     except FileExistsError:
-        venv_path = create_package_dir_pip(
-            task_pkg=task_pkg, user=pkg_user, create=False
-        )
+        venv_path = create_package_dir_pip(task_pkg=task_pkg, create=False)
     try:
         task_collect_status = get_collection_data(venv_path)
+        for task in task_collect_status.task_list:
+            db_task = await db.get(Task, task.id)
+            if (
+                (not db_task)
+                or db_task.source != task.source
+                or db_task.name != task.name
+            ):
+                await db.close()
+                raise HTTPException(
+                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                    detail=(
+                        "Cannot collect package. Folder already exists, "
+                        f"but task {task.id} does not exist or does not "
+                        f"have the expected source ({task.source}) or "
+                        f"name ({task.name})."
+                    ),
+                )
     except FileNotFoundError as e:
         await db.close()
         raise HTTPException(
@@ -344,19 +361,42 @@ async def get_task(
 async def patch_task(
     task_id: int,
     task_update: TaskUpdate,
-    user: User = Depends(current_active_superuser),
+    user: User = Depends(current_active_user),
     db: AsyncSession = Depends(get_db),
 ) -> Optional[TaskRead]:
     """
-    Edit a specific task
+    Edit a specific task (restricted to superusers and task owner)
     """
+
     if task_update.source:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
             detail="patch_task endpoint cannot set `source`",
         )

+    # Retrieve task from database
     db_task = await db.get(Task, task_id)
+
+    # This check constitutes a preliminary version of access control:
+    # if the current user is not a superuser and differs from the task owner
+    # (including when `owner is None`), we raise a 403 HTTP Exception.
+    if not user.is_superuser:
+        if db_task.owner is None:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=("Only a superuser can edit a task with `owner=None`."),
+            )
+        else:
+            owner = user.username or user.slurm_user
+            if owner != db_task.owner:
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail=(
+                        f"Current user ({owner}) cannot modify task "
+                        f"({task_id}) with different owner ({db_task.owner})."
+                    ),
+                )
+
     update = task_update.dict(exclude_unset=True)
     for key, value in update.items():
         if isinstance(value, str):
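
The ownership rule above reduces to comparing `user.username or user.slurm_user` against `Task.owner`. A compact sketch of the decision logic, stripped of FastAPI and the database:

    def can_edit_task(is_superuser, username, slurm_user, task_owner) -> bool:
        if is_superuser:
            return True  # superusers can edit any task
        if task_owner is None:
            return False  # only superusers may edit ownerless tasks
        return (username or slurm_user) == task_owner

    assert can_edit_task(True, None, None, None)
    assert can_edit_task(False, "alice", "asmith", "alice")  # username wins
    assert can_edit_task(False, None, "asmith", "asmith")    # slurm_user fallback
    assert not can_edit_task(False, "alice", "asmith", "bob")
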
@@ -386,14 +426,36 @@ async def create_task(
     """
     Create a new task
     """
+    # Set task.owner attribute
+    if user.username:
+        owner = user.username
+    elif user.slurm_user:
+        owner = user.slurm_user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                "Cannot add a new task because current user does not "
+                "have `username` or `slurm_user` attributes."
+            ),
+        )
+
+    # Prepend owner to task.source
+    task.source = f"{owner}:{task.source}"
+
+    # Verify that source is not already in use (note: this check is only useful
+    # to provide a user-friendly error message, but `task.source` uniqueness is
+    # already guaranteed by a constraint in the table definition).
     stm = select(Task).where(Task.source == task.source)
     res = await db.execute(stm)
     if res.scalars().all():
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=f"Task with source={task.source} already in use",
+            detail=f'Task source "{task.source}" already in use',
         )
-    db_task = Task.from_orm(task)
+
+    # Add task
+    db_task = Task(**task.dict(), owner=owner)
     db.add(db_task)
     await db.commit()
     await db.refresh(db_task)
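
Prepending the owner to `task.source` namespaces user-created tasks, so two users can register tasks whose original source strings collide. For example (all values hypothetical):

    owner = "alice"                  # from user.username or user.slurm_user
    source = "my-scripts:threshold"  # source provided in the request body
    task_source = f"{owner}:{source}"
    print(task_source)               # alice:my-scripts:threshold
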

fractal_server/app/api/v1/workflow.py
@@ -20,8 +20,11 @@ from fastapi import Response
 from fastapi import status
 from sqlmodel import select

+from ....logger import close_logger
+from ....logger import set_logger
 from ...db import AsyncSession
 from ...db import get_db
+from ...models import ApplyWorkflow
 from ...models import Task
 from ...models import Workflow
 from ...models import WorkflowCreate
@@ -181,6 +184,19 @@ async def delete_workflow(
         project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
     )

+    # Check that no ApplyWorkflow is linked to the current Workflow
+    stm = select(ApplyWorkflow).where(ApplyWorkflow.workflow_id == workflow_id)
+    res = await db.execute(stm)
+    job = res.scalars().first()
+    if job:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=(
+                f"Cannot remove workflow {workflow_id}: "
+                f"it's still linked to job {job.id}."
+            ),
+        )
+
     await db.delete(workflow)
     await db.commit()

@@ -203,6 +219,18 @@ async def export_worfklow(
     workflow = await _get_workflow_check_owner(
         project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
     )
+    # Emit a warning when exporting a workflow with custom tasks
+    logger = set_logger(None)
+    for wftask in workflow.task_list:
+        if wftask.task.owner is not None:
+            logger.warning(
+                f"Custom tasks (like the one with id={wftask.task.id} and "
+                f'source="{wftask.task.source}") are not meant to be '
+                "portable; re-importing this workflow may not work as "
+                "expected."
+            )
+    close_logger(logger)
+
     await db.close()
     return workflow

@@ -237,37 +265,21 @@ async def import_workflow(
     )

     # Check that all required tasks are available
-    # NOTE: by now we go through the pair (source, name), but later on we may
-    # combine them into source -- see issue #293.
     tasks = [wf_task.task for wf_task in workflow.task_list]
-    sourcename_to_id = {}
+    source_to_id = {}
     for task in tasks:
         source = task.source
-        name = task.name
-        if not (source, name) in sourcename_to_id.keys():
+        if source not in source_to_id.keys():
             stm = select(Task).where(Task.source == source)
             tasks_by_source = (await db.execute(stm)).scalars().all()
-            if not tasks_by_source:
+            if len(tasks_by_source) != 1:
                 raise HTTPException(
                     status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                    detail=(f"Found 0 tasks with {source=}."),
-                )
-            else:
-                stm = (
-                    select(Task)
-                    .where(Task.source == source)
-                    .where(Task.name == name)
+                    detail=(
+                        f"Found {len(tasks_by_source)} tasks with {source=}."
+                    ),
                 )
-                current_task = (await db.execute(stm)).scalars().all()
-                if len(current_task) != 1:
-                    raise HTTPException(
-                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                        detail=(
-                            f"Found {len(current_task)} tasks with "
-                            f"{name =} and {source=}."
-                        ),
-                    )
-                sourcename_to_id[(source, name)] = current_task[0].id
+            source_to_id[source] = tasks_by_source[0].id

     # Create new Workflow (with empty task_list)
     db_workflow = Workflow(
@@ -283,8 +295,7 @@ async def import_workflow(
     for _, wf_task in enumerate(workflow.task_list):
         # Identify task_id
         source = wf_task.task.source
-        name = wf_task.task.name
-        task_id = sourcename_to_id[(source, name)]
+        task_id = source_to_id[source]
         # Prepare new_wf_task
         new_wf_task = WorkflowTaskCreate(
             **wf_task.dict(exclude_none=True),
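
With `source` now unique per task (see the `Task` model change further down), import resolution collapses from the old `(source, name)` pair to a plain `source -> id` lookup. A stripped-down sketch of the loop above, with dicts standing in for database rows and hypothetical source strings:

    available_tasks = [
        {"id": 1, "source": "pip:fractal-tasks-core:create_zarr"},
        {"id": 2, "source": "pip:fractal-tasks-core:yokogawa_to_zarr"},
    ]

    def resolve_sources(sources: list[str]) -> dict[str, int]:
        source_to_id: dict[str, int] = {}
        for source in sources:
            if source not in source_to_id:
                matches = [t for t in available_tasks if t["source"] == source]
                if len(matches) != 1:
                    raise ValueError(f"Found {len(matches)} tasks with {source=}.")
                source_to_id[source] = matches[0]["id"]
        return source_to_id

    print(resolve_sources(["pip:fractal-tasks-core:create_zarr"]))
    # {'pip:fractal-tasks-core:create_zarr': 1}
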

fractal_server/app/models/job.py
@@ -6,8 +6,9 @@ from sqlalchemy import Column
 from sqlalchemy.types import DateTime
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel

-from ...common.schemas import ApplyWorkflowBase
+from ...common.schemas import _ApplyWorkflowBase
 from ...utils import get_timestamp
 from .project import Dataset
 from .workflow import Workflow
@@ -38,7 +39,7 @@ class JobStatusType(str, Enum):
     FAILED = "failed"


-class ApplyWorkflow(ApplyWorkflowBase, table=True):
+class ApplyWorkflow(_ApplyWorkflowBase, SQLModel, table=True):
     """
     Represent a workflow run


fractal_server/app/models/project.py
@@ -5,6 +5,7 @@ from sqlalchemy import Column
 from sqlalchemy.types import JSON
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel

 from ...common.schemas.project import _DatasetBase
 from ...common.schemas.project import _ProjectBase
@@ -14,7 +15,7 @@ from .security import UserOAuth as User
 from .workflow import Workflow


-class Dataset(_DatasetBase, table=True):
+class Dataset(_DatasetBase, SQLModel, table=True):
     """
     Represent a dataset

@@ -49,7 +50,7 @@ class Dataset(_DatasetBase, table=True):
         return [r.path for r in self.resource_list]


-class Project(_ProjectBase, table=True):
+class Project(_ProjectBase, SQLModel, table=True):
     id: Optional[int] = Field(default=None, primary_key=True)

     user_list: list[User] = Relationship(
@@ -82,6 +83,6 @@ class Project(_ProjectBase, table=True):
     )


-class Resource(_ResourceBase, table=True):
+class Resource(_ResourceBase, SQLModel, table=True):
     id: Optional[int] = Field(default=None, primary_key=True)
     dataset_id: int = Field(foreign_key="dataset.id")
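
The `(..., SQLModel, table=True)` signature recurs through all the model files in this release: the underscore-prefixed schema bases apparently no longer inherit from `SQLModel`, so each table model now mixes `SQLModel` in explicitly. A hedged sketch of the pattern, with an illustrative `Thing` model:

    from typing import Optional

    from pydantic import BaseModel
    from sqlmodel import Field, SQLModel

    class _ThingBase(BaseModel):
        # Shared, table-less schema base: pure validation, reusable by API schemas
        name: str

    class Thing(_ThingBase, SQLModel, table=True):
        # DB model: mixes the schema base with SQLModel and adds table columns
        id: Optional[int] = Field(default=None, primary_key=True)
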

fractal_server/app/models/security.py
@@ -62,6 +62,8 @@ class UserOAuth(SQLModel, table=True):

     slurm_user: Optional[str]
     cache_dir: Optional[str]
+    username: Optional[str]
+
     oauth_accounts: list["OAuthAccount"] = Relationship(
         back_populates="user",
         sa_relationship_kwargs={"lazy": "selectin", "cascade": "all, delete"},

fractal_server/app/models/state.py
@@ -6,12 +6,13 @@ from sqlalchemy import Column
 from sqlalchemy.types import DateTime
 from sqlalchemy.types import JSON
 from sqlmodel import Field
+from sqlmodel import SQLModel

 from ...common.schemas import _StateBase
 from ...utils import get_timestamp


-class State(_StateBase, table=True):
+class State(_StateBase, SQLModel, table=True):
     """
     Store arbitrary data in the database


fractal_server/app/models/task.py
@@ -4,33 +4,44 @@ from typing import Optional
 from sqlalchemy import Column
 from sqlalchemy.types import JSON
 from sqlmodel import Field
+from sqlmodel import SQLModel

 from ...common.schemas.task import _TaskBase


-class Task(_TaskBase, table=True):
+class Task(_TaskBase, SQLModel, table=True):
     """
     Task model

     Attributes:
         id: Primary key
-        command: TBD
-        input_type: TBD
-        output_type: TBD
-        default_args: TBD
-        meta: TBD
+        command: Executable command
+        input_type: Expected type of input `Dataset`
+        output_type: Expected type of output `Dataset`
+        meta:
+            Additional metadata related to execution (e.g. computational
+            resources)
         source: inherited from `_TaskBase`
         name: inherited from `_TaskBase`
+        args_schema: JSON schema of task arguments
+        args_schema_version:
+            label pointing at how the JSON schema of task arguments was
+            generated
     """

     id: Optional[int] = Field(default=None, primary_key=True)
+    name: str
     command: str
+    source: str = Field(unique=True)
     input_type: str
     output_type: str
-    default_args: Optional[dict[str, Any]] = Field(
-        sa_column=Column(JSON), default={}
-    )
     meta: Optional[dict[str, Any]] = Field(sa_column=Column(JSON), default={})
+    owner: Optional[str] = None
+    version: Optional[str] = None
+    args_schema: Optional[dict[str, Any]] = Field(
+        sa_column=Column(JSON), default=None
+    )
+    args_schema_version: Optional[str]

     @property
     def parallelization_level(self) -> Optional[str]:
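
The new `args_schema` column holds a JSON Schema describing the task's arguments, with `args_schema_version` recording how that schema was generated. An illustrative value (the task and its properties are hypothetical), showing the `properties`/`default` layout that `Workflow.insert_task` relies on below:

    args_schema = {
        "title": "ThresholdTaskArguments",
        "type": "object",
        "properties": {
            "level": {"type": "integer", "default": 0},
            "overwrite": {"type": "boolean", "default": False},
            "message": {"type": "string"},  # no default: never auto-filled
        },
        "required": ["message"],
    }
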

fractal_server/app/models/workflow.py
@@ -1,3 +1,5 @@
+import json
+import logging
 from typing import Any
 from typing import Optional
 from typing import Union
@@ -8,6 +10,7 @@ from sqlalchemy.ext.orderinglist import ordering_list
 from sqlalchemy.types import JSON
 from sqlmodel import Field
 from sqlmodel import Relationship
+from sqlmodel import SQLModel

 from ...common.schemas.workflow import _WorkflowBase
 from ...common.schemas.workflow import _WorkflowTaskBase
@@ -15,7 +18,7 @@ from ..db import AsyncSession
 from .task import Task


-class WorkflowTask(_WorkflowTaskBase, table=True):
+class WorkflowTask(_WorkflowTaskBase, SQLModel, table=True):
     """
     A Task as part of a Workflow

@@ -36,8 +39,7 @@ class WorkflowTask(_WorkflowTaskBase, table=True):
         meta:
             Additional parameters useful for execution
         args:
-            Additional task arguments, overriding the ones in
-            `WorkflowTask.task.args`
+            Task arguments
         task:
             `Task` object associated with the current `WorkflowTask`

@@ -81,15 +83,6 @@ class WorkflowTask(_WorkflowTaskBase, table=True):
         )
         return value

-    @property
-    def arguments(self):
-        """
-        Override default arguments
-        """
-        out = self.task.default_args.copy()
-        out.update(self.args or {})
-        return out
-
     @property
     def is_parallel(self) -> bool:
         return self.task.is_parallel
@@ -108,23 +101,8 @@
         res.update(self.meta or {})
         return res

-    def assemble_args(self, extra: dict[str, Any] = None) -> dict:
-        """
-        Merge of `extra` arguments and `self.arguments`.

-        Returns
-            full_args:
-                A dictionary consisting of the merge of `extra` and
-                `self.arguments`.
-        """
-        full_args = {}
-        if extra:
-            full_args.update(extra)
-        full_args.update(self.arguments)
-        return full_args
-
-
-class Workflow(_WorkflowBase, table=True):
+class Workflow(_WorkflowBase, SQLModel, table=True):
     """
     Workflow

@@ -172,7 +150,34 @@ class Workflow(_WorkflowBase, table=True):
         """
         if order is None:
             order = len(self.task_list)
-        wf_task = WorkflowTask(task_id=task_id, args=args, meta=meta)
+
+        # Get task from db, extract the JSON Schema for its arguments (if any),
+        # read default values and set them in default_args
+        db_task = await db.get(Task, task_id)
+        default_args = {}
+        if db_task.args_schema is not None:
+            try:
+                properties = db_task.args_schema["properties"]
+                for prop_name, prop_schema in properties.items():
+                    default_value = prop_schema.get("default", None)
+                    if default_value:
+                        default_args[prop_name] = default_value
+            except KeyError as e:
+                logging.warning(
+                    "Cannot set default_args from args_schema="
+                    f"{json.dumps(db_task.args_schema)}\n"
+                    f"Original KeyError: {str(e)}"
+                )
+        # Override default_args with args
+        actual_args = default_args.copy()
+        if args is not None:
+            for k, v in args.items():
+                actual_args[k] = v
+        if not actual_args:
+            actual_args = None
+
+        # Create DB entry
+        wf_task = WorkflowTask(task_id=task_id, args=actual_args, meta=meta)
         db.add(wf_task)
         self.task_list.insert(order, wf_task)
         self.task_list.reorder()  # type: ignore
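
A standalone rerun of the default-extraction loop above. Note that the `if default_value:` truthiness test skips falsy defaults such as `False`, `0`, or `""`, so only truthy schema defaults end up in `default_args`:

    args_schema = {
        "properties": {
            "message": {"type": "string", "default": "hello"},
            "overwrite": {"type": "boolean", "default": False},
        }
    }

    default_args = {}
    for prop_name, prop_schema in args_schema["properties"].items():
        default_value = prop_schema.get("default", None)
        if default_value:
            default_args[prop_name] = default_value

    print(default_args)  # {'message': 'hello'} -- the False default is dropped
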

fractal_server/app/runner/_common.py
@@ -187,10 +187,10 @@ def call_single_task(
     Call a single task

     This assembles the runner arguments (input_paths, output_path, ...) and
-    task arguments (i.e., arguments that are specific to the task, such as
-    message or index in the dummy task), writes them to file, call the task
-    executable command passing the arguments file as an input and assembles
-    the output.
+    wftask arguments (i.e., arguments that are specific to the WorkflowTask,
+    such as message or index in the dummy task), writes them to file, calls
+    the task executable command passing the arguments file as an input, and
+    assembles the output.

     **Note**: This function is directly submitted to a
     `concurrent.futures`-compatible executor, as in
@@ -205,7 +205,7 @@ def call_single_task(
     Args:
         wftask:
             The workflow task to be called. This includes task specific
-            arguments via the task.task.arguments attribute.
+            arguments via the wftask.args attribute.
         task_pars:
             The parameters required to run the task which are not specific to
             the task, e.g., I/O paths.
@@ -238,11 +238,12 @@ def call_single_task(
         task_order=wftask.order,
     )

-    # assemble full args
-    args_dict = wftask.assemble_args(extra=task_pars.dict())
-
-    # write args file
-    write_args_file(args_dict, path=task_files.args)
+    # write args file (by assembling task_pars and wftask.args)
+    write_args_file(
+        task_pars.dict(),
+        wftask.args or {},
+        path=task_files.args,
+    )

     # assemble full command
     cmd = (
@@ -341,10 +342,10 @@ def call_single_parallel_task(
         component=component,
     )

-    # assemble full args
+    # write args file (by assembling task_pars, wftask.args and component)
     write_args_file(
         task_pars.dict(),
-        wftask.arguments,
+        wftask.args or {},
         dict(component=component),
         path=task_files.args,
     )
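
`write_args_file` is fractal-server's own helper; judging from the two call sites above, it accepts several dicts and merges them left to right before writing, with later dicts overriding earlier keys. A hypothetical sketch of that contract:

    import json

    def write_args_file(*dicts: dict, path: str) -> None:
        # Merge left to right: wftask.args (and component) override task_pars
        merged: dict = {}
        for d in dicts:
            merged.update(d)
        with open(path, "w") as f:
            json.dump(merged, f, indent=2)

    write_args_file(
        {"input_paths": ["/tmp/in"], "message": "default"},
        {"message": "from wftask.args"},  # overrides the earlier key
        path="/tmp/args.json",
    )
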

fractal_server/app/runner/_slurm/executor.py
@@ -743,7 +743,8 @@ class FractalSlurmExecutor(SlurmExecutor):
                 "the job started running, the SLURM out/err files "
                 "will be empty.\n"
                 "2. Some error occurred upon writing the file to disk "
-                "(e.g. due to an overloaded NFS filesystem). "
+                "(e.g. because there is not enough space on disk, or "
+                "due to an overloaded NFS filesystem). "
                 "Note that the server configuration has "
                 "FRACTAL_SLURM_OUTPUT_FILE_GRACE_TIME="
                 f"{settings.FRACTAL_SLURM_OUTPUT_FILE_GRACE_TIME} "

fractal_server/common/requirements.txt
@@ -1,5 +1,5 @@
 pydantic
-sqlmodel
 fastapi-users
 devtools
 pytest
+typing-extensions

fractal_server/common/schemas/__init__.py
@@ -3,6 +3,7 @@ from .manifest import *  # noqa: F403
 from .project import *  # noqa: F403
 from .state import *  # noqa: F403
 from .task import *  # noqa: F403
+from .task_collection import *  # noqa: F403
 from .user import *  # noqa: F403
 from .workflow import *  # noqa: F403

@@ -10,6 +11,7 @@ from .workflow import *  # noqa: F403
 __all__ = (
     project.__all__  # noqa: F405
     + task.__all__  # noqa: F405
+    + task_collection.__all__  # noqa: F405
     + workflow.__all__  # noqa: F405
     + applyworkflow.__all__  # noqa: F405
     + manifest.__all__  # noqa: F405