fractal-server 2.7.0a1__py3-none-any.whl → 2.7.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Response
 from fastapi import status
+from sqlmodel import func
 from sqlmodel import or_
 from sqlmodel import select

@@ -36,6 +37,9 @@ logger = set_logger(__name__)
 async def get_list_task(
     args_schema_parallel: bool = True,
     args_schema_non_parallel: bool = True,
+    category: Optional[str] = None,
+    modality: Optional[str] = None,
+    author: Optional[str] = None,
     user: UserOAuth = Depends(current_active_user),
     db: AsyncSession = Depends(get_async_db),
 ) -> list[TaskReadV2]:
@@ -57,6 +61,13 @@ async def get_list_task(
             )
         )
     )
+    if category is not None:
+        stm = stm.where(func.lower(TaskV2.category) == category.lower())
+    if modality is not None:
+        stm = stm.where(func.lower(TaskV2.modality) == modality.lower())
+    if author is not None:
+        stm = stm.where(TaskV2.authors.icontains(author))
+
     res = await db.execute(stm)
     task_list = res.scalars().all()
     await db.close()
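
A minimal sketch of how a client might exercise the three new query parameters added above (`category` and `modality` are matched case-insensitively, `author` as a case-insensitive substring of `TaskV2.authors`); the base URL, route prefix, and token are assumptions, not taken from this diff:

import httpx

# Assumed deployment URL and route prefix; adjust to the actual server.
BASE_URL = "http://localhost:8000/api/v2/task/"

params = {"category": "Conversion", "modality": "HCS", "author": "doe"}
headers = {"Authorization": "Bearer <token>"}  # placeholder token

response = httpx.get(BASE_URL, params=params, headers=headers)
response.raise_for_status()
for task in response.json():
    print(task["name"], task.get("category"), task.get("authors"))
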
@@ -216,6 +227,8 @@ async def create_task(
         user_group_id=user_group_id,
         active=True,
         task_list=[db_task],
+        origin="other",
+        pkg_name=task.name,
     )
     db.add(db_task_group)
     await db.commit()
@@ -186,9 +186,15 @@ async def collect_task_custom(
             detail="\n".join(overlapping_tasks_v1_source_and_id),
         )

+    # Prepare task-group attributes
+    task_group_attrs = dict(
+        origin="other",
+        pkg_name=task_collect.source,  # FIXME
+    )
+
     task_group = create_db_task_group_and_tasks(
         task_list=task_list,
-        task_group_obj=TaskGroupCreateV2(),
+        task_group_obj=TaskGroupCreateV2(**task_group_attrs),
         user_id=user.id,
         user_group_id=user_group_id,
         db=db_sync,
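
Because `TaskGroupCreateV2` now declares `origin` and `pkg_name` without defaults (see the task-group schema hunk further below), the old bare `TaskGroupCreateV2()` call would no longer validate. A small sketch, assuming the schema is importable from `fractal_server.app.schemas.v2`:

from pydantic import ValidationError

from fractal_server.app.schemas.v2 import TaskGroupCreateV2  # import path assumed

try:
    TaskGroupCreateV2()  # old call site: now rejected
except ValidationError as err:
    print(err)  # reports missing `origin` and `pkg_name`

# Updated call site: both required fields are provided (values illustrative).
TaskGroupCreateV2(origin="other", pkg_name="some-package")
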
@@ -134,7 +134,7 @@ async def read_workflow(

 @router.patch(
     "/project/{project_id}/workflow/{workflow_id}/",
-    response_model=WorkflowReadV2,
+    response_model=WorkflowReadV2WithWarnings,
 )
 async def update_workflow(
     project_id: int,
@@ -142,7 +142,7 @@ async def update_workflow(
     patch: WorkflowUpdateV2,
     user: UserOAuth = Depends(current_active_user),
     db: AsyncSession = Depends(get_async_db),
-) -> Optional[WorkflowReadV2]:
+) -> Optional[WorkflowReadV2WithWarnings]:
     """
     Edit a workflow
     """
@@ -184,7 +184,26 @@ async def update_workflow(
     await db.refresh(workflow)
     await db.close()

-    return workflow
+    workflow_data = dict(
+        **workflow.model_dump(),
+        project=workflow.project,
+    )
+    task_list_with_warnings = []
+    for wftask in workflow.task_list:
+        wftask_data = dict(wftask.model_dump(), task=wftask.task)
+        try:
+            task_group = await _get_task_group_read_access(
+                task_group_id=wftask.task.taskgroupv2_id,
+                user_id=user.id,
+                db=db,
+            )
+            if not task_group.active:
+                wftask_data["warning"] = "Task is not active."
+        except HTTPException:
+            wftask_data["warning"] = "Current user has no access to this task."
+        task_list_with_warnings.append(wftask_data)
+    workflow_data["task_list"] = task_list_with_warnings
+    return workflow_data


 @router.delete(
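
For orientation, a sketch of the response shape that the new `WorkflowReadV2WithWarnings` model produces; ids and names are illustrative, and only workflow tasks whose task group is inactive or not readable by the current user carry a `warning` entry:

example_response = {
    "id": 1,
    "name": "my-workflow",
    "project": {"id": 1, "name": "my-project"},
    "task_list": [
        {"id": 10, "task": {"name": "task-a"}},  # no warning: task group active and readable
        {"id": 11, "task": {"name": "task-b"}, "warning": "Task is not active."},
        {"id": 12, "task": {"name": "task-c"}, "warning": "Current user has no access to this task."},
    ],
}
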
@@ -7,6 +7,8 @@ from pydantic import HttpUrl
 from pydantic import root_validator
 from pydantic import validator

+from .._validators import valstr
+

 class TaskManifestV2(BaseModel):
     """
@@ -50,6 +52,10 @@ class TaskManifestV2(BaseModel):
     docs_info: Optional[str] = None
     docs_link: Optional[HttpUrl] = None

+    category: Optional[str] = None
+    modality: Optional[str] = None
+    tags: list[str] = Field(default_factory=list)
+
     @root_validator
     def validate_executable_args_meta(cls, values):

@@ -128,7 +134,8 @@ class ManifestV2(BaseModel):
     manifest_version: str
     task_list: list[TaskManifestV2]
     has_args_schemas: bool = False
-    args_schema_version: Optional[str]
+    args_schema_version: Optional[str] = None
+    authors: Optional[str] = None

     @root_validator()
     def _check_args_schemas_are_present(cls, values):
@@ -157,3 +164,7 @@ class ManifestV2(BaseModel):
         if value != "2":
             raise ValueError(f"Wrong manifest version (given {value})")
         return value
+
+    _authors = validator("authors", allow_reuse=True)(
+        valstr("authors", accept_none=True)
+    )
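
A minimal sketch of a manifest that exercises the new optional fields (`category`, `modality`, `tags` on tasks, `authors` at the package level); field names outside this diff (e.g. `executable_parallel`) and the import path are assumptions:

from fractal_server.app.schemas.v2 import ManifestV2, TaskManifestV2  # import path assumed

manifest = ManifestV2(
    manifest_version="2",
    has_args_schemas=False,
    authors="Jane Doe, John Smith",  # new optional package-level field
    task_list=[
        TaskManifestV2(
            name="Convert something",
            executable_parallel="convert_task.py",  # assumed field name, not shown in this diff
            category="Conversion",                  # new optional field
            modality="HCS",                         # new optional field
            tags=["2D", "3D"],                      # new list field, defaults to []
        )
    ],
)
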
@@ -9,8 +9,9 @@ from pydantic import HttpUrl
 from pydantic import root_validator
 from pydantic import validator

-from .._validators import valdictkeys
-from .._validators import valstr
+from fractal_server.app.schemas._validators import val_unique_list
+from fractal_server.app.schemas._validators import valdictkeys
+from fractal_server.app.schemas._validators import valstr
 from fractal_server.string_tools import validate_cmd


@@ -18,22 +19,27 @@ class TaskCreateV2(BaseModel, extra=Extra.forbid):

     name: str

-    command_non_parallel: Optional[str]
-    command_parallel: Optional[str]
+    command_non_parallel: Optional[str] = None
+    command_parallel: Optional[str] = None
     source: str

-    meta_non_parallel: Optional[dict[str, Any]]
-    meta_parallel: Optional[dict[str, Any]]
-    version: Optional[str]
-    args_schema_non_parallel: Optional[dict[str, Any]]
-    args_schema_parallel: Optional[dict[str, Any]]
-    args_schema_version: Optional[str]
-    docs_info: Optional[str]
-    docs_link: Optional[HttpUrl]
+    meta_non_parallel: Optional[dict[str, Any]] = None
+    meta_parallel: Optional[dict[str, Any]] = None
+    version: Optional[str] = None
+    args_schema_non_parallel: Optional[dict[str, Any]] = None
+    args_schema_parallel: Optional[dict[str, Any]] = None
+    args_schema_version: Optional[str] = None
+    docs_info: Optional[str] = None
+    docs_link: Optional[HttpUrl] = None

     input_types: dict[str, bool] = Field(default={})
     output_types: dict[str, bool] = Field(default={})

+    category: Optional[str] = None
+    modality: Optional[str] = None
+    tags: list[str] = Field(default_factory=list)
+    authors: Optional[str] = None
+
     # Validators
     @root_validator
     def validate_commands(cls, values):
@@ -83,6 +89,22 @@ class TaskCreateV2(BaseModel, extra=Extra.forbid):
         valdictkeys("output_types")
     )

+    _category = validator("category", allow_reuse=True)(
+        valstr("category", accept_none=True)
+    )
+    _modality = validator("modality", allow_reuse=True)(
+        valstr("modality", accept_none=True)
+    )
+    _authors = validator("authors", allow_reuse=True)(
+        valstr("authors", accept_none=True)
+    )
+
+    @validator("tags")
+    def validate_list_of_strings(cls, value):
+        for i, tag in enumerate(value):
+            value[i] = valstr(f"tags[{i}]")(tag)
+        return val_unique_list("tags")(value)
+

 class TaskReadV2(BaseModel):

@@ -90,31 +112,41 @@ class TaskReadV2(BaseModel):
     name: str
     type: Literal["parallel", "non_parallel", "compound"]
     source: str
-    version: Optional[str]
+    version: Optional[str] = None

-    command_non_parallel: Optional[str]
-    command_parallel: Optional[str]
+    command_non_parallel: Optional[str] = None
+    command_parallel: Optional[str] = None
     meta_parallel: dict[str, Any]
     meta_non_parallel: dict[str, Any]
     args_schema_non_parallel: Optional[dict[str, Any]] = None
     args_schema_parallel: Optional[dict[str, Any]] = None
-    args_schema_version: Optional[str]
-    docs_info: Optional[str]
-    docs_link: Optional[HttpUrl]
+    args_schema_version: Optional[str] = None
+    docs_info: Optional[str] = None
+    docs_link: Optional[HttpUrl] = None
     input_types: dict[str, bool]
     output_types: dict[str, bool]

-    taskgroupv2_id: Optional[int]
+    taskgroupv2_id: Optional[int] = None
+
+    category: Optional[str] = None
+    modality: Optional[str] = None
+    authors: Optional[str] = None
+    tags: list[str]


 class TaskUpdateV2(BaseModel):

-    name: Optional[str]
-    version: Optional[str]
-    command_parallel: Optional[str]
-    command_non_parallel: Optional[str]
-    input_types: Optional[dict[str, bool]]
-    output_types: Optional[dict[str, bool]]
+    name: Optional[str] = None
+    version: Optional[str] = None
+    command_parallel: Optional[str] = None
+    command_non_parallel: Optional[str] = None
+    input_types: Optional[dict[str, bool]] = None
+    output_types: Optional[dict[str, bool]] = None
+
+    category: Optional[str] = None
+    modality: Optional[str] = None
+    authors: Optional[str] = None
+    tags: Optional[list[str]] = None

     # Validators
     @validator("input_types", "output_types")
@@ -140,6 +172,22 @@ class TaskUpdateV2(BaseModel):
         valdictkeys("output_types")
     )

+    _category = validator("category", allow_reuse=True)(
+        valstr("category", accept_none=True)
+    )
+    _modality = validator("modality", allow_reuse=True)(
+        valstr("modality", accept_none=True)
+    )
+    _authors = validator("authors", allow_reuse=True)(
+        valstr("authors", accept_none=True)
+    )
+
+    @validator("tags")
+    def validate_tags(cls, value):
+        for i, tag in enumerate(value):
+            value[i] = valstr(f"tags[{i}]")(tag)
+        return val_unique_list("tags")(value)
+

 class TaskImportV2(BaseModel):

@@ -1,23 +1,50 @@
+from datetime import datetime
+from typing import Literal
 from typing import Optional

 from pydantic import BaseModel
+from pydantic import validator

 from .task import TaskReadV2


 class TaskGroupCreateV2(BaseModel):
     active: bool = True
+    origin: Literal["pypi", "wheel-file", "other"]
+    pkg_name: str
+    version: Optional[str] = None
+    python_version: Optional[str] = None
+    path: Optional[str] = None
+    venv_path: Optional[str] = None
+    pip_extras: Optional[str] = None


 class TaskGroupReadV2(BaseModel):

     id: int
+    task_list: list[TaskReadV2]
+
     user_id: int
     user_group_id: Optional[int] = None
+
+    origin: Literal["pypi", "wheel-file", "other"]
+    pkg_name: str
+    version: Optional[str] = None
+    python_version: Optional[str] = None
+    path: Optional[str] = None
+    venv_path: Optional[str] = None
+    pip_extras: Optional[str] = None
+
     active: bool
-    task_list: list[TaskReadV2]
+    timestamp_created: datetime


 class TaskGroupUpdateV2(BaseModel):
     user_group_id: Optional[int] = None
     active: Optional[bool] = None
+
+    @validator("active")
+    def active_cannot_be_None(cls, value):
+        if value is None:
+            raise ValueError("`active` cannot be set to None")
+        return value
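
The new validator on `TaskGroupUpdateV2` distinguishes "field not provided" from an explicit `active=None`. Sketch (import path assumed):

from pydantic import ValidationError

from fractal_server.app.schemas.v2 import TaskGroupUpdateV2  # import path assumed

TaskGroupUpdateV2()              # OK: `active` is simply left out of the update
TaskGroupUpdateV2(active=False)  # OK: deactivate the task group

try:
    TaskGroupUpdateV2(active=None)  # explicit None is now rejected
except ValidationError as err:
    print(err)  # "`active` cannot be set to None"
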
@@ -0,0 +1,274 @@
+import logging
+import os
+from pathlib import Path
+from typing import Any
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from fractal_server.app.db import get_sync_db
+from fractal_server.app.models import TaskGroupV2
+from fractal_server.app.models import TaskV2
+from fractal_server.app.models import UserGroup
+from fractal_server.app.models import UserOAuth
+from fractal_server.app.models import UserSettings
+from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME
+from fractal_server.data_migrations.tools import _check_current_version
+from fractal_server.utils import get_timestamp
+
+logger = logging.getLogger("fix_db")
+
+
+def get_unique_value(list_of_objects: list[dict[str, Any]], key: str):
+    """
+    Loop over `list_of_objects` and extract (unique) value for `key`.
+    """
+    unique_values = set()
+    for this_obj in list_of_objects:
+        this_value = this_obj.get(key, None)
+        unique_values.add(this_value)
+    if len(unique_values) != 1:
+        raise RuntimeError(
+            f"There must be a single taskgroup `{key}`, "
+            f"but {unique_values=}"
+        )
+    return unique_values.pop()
+
+
+def get_users_mapping(db) -> dict[str, int]:
+    logger.warning("START _check_users")
+    print()
+
+    stm_users = select(UserOAuth).order_by(UserOAuth.id)
+    users = db.execute(stm_users).scalars().unique().all()
+    name_to_user_id = {}
+    for user in users:
+        logger.warning(f"START handling user {user.id}: '{user.email}'")
+        # Compute "name" attribute
+        user_settings = db.get(UserSettings, user.user_settings_id)
+        name = user.username or user_settings.slurm_user
+        logger.warning(f"{name=}")
+        # Fail for missing values
+        if name is None:
+            raise ValueError(
+                f"User with {user.id=} and {user.email=} has no "
+                "`username` or `slurm_user` set."
+                "Please fix this issue manually."
+            )
+        # Fail for non-unique values
+        existing_user = name_to_user_id.get(name, None)
+        if existing_user is not None:
+            raise ValueError(
+                f"User with {user.id=} and {user.email=} has same "
+                f"`(username or slurm_user)={name}` as another user. "
+                "Please fix this issue manually."
+            )
+        # Update dictionary
+        name_to_user_id[name] = user.id
+        logger.warning(f"END handling user {user.id}: '{user.email}'")
+        print()
+    logger.warning("END _check_users")
+    print()
+    return name_to_user_id
+
+
+def get_default_user_group_id(db):
+    stm = select(UserGroup.id).where(
+        UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
+    )
+    res = db.execute(stm)
+    default_group_id = res.scalars().one_or_none()
+    if default_group_id is None:
+        raise RuntimeError("Default user group is missing.")
+    else:
+        return default_group_id
+
+
+def get_default_user_id(db):
+
+    DEFAULT_USER_EMAIL = os.getenv("FRACTAL_V27_DEFAULT_USER_EMAIL")
+    if DEFAULT_USER_EMAIL is None:
+        raise ValueError(
+            "FRACTAL_V27_DEFAULT_USER_EMAIL env variable is not set. "
+            "Please set it to be the email of the user who will own "
+            "all previously-global tasks."
+        )
+
+    stm = select(UserOAuth.id).where(UserOAuth.email == DEFAULT_USER_EMAIL)
+    res = db.execute(stm)
+    default_user_id = res.scalars().one_or_none()
+    if default_user_id is None:
+        raise RuntimeError(
+            f"Default user with email {DEFAULT_USER_EMAIL} is missing."
+        )
+    else:
+        return default_user_id
+
+
+def prepare_task_groups(
+    *,
+    user_mapping: dict[str, int],
+    default_user_group_id: int,
+    default_user_id: int,
+    dry_run: bool,
+    db: Session,
+):
+    stm_tasks = select(TaskV2).order_by(TaskV2.id)
+    res = db.execute(stm_tasks).scalars().all()
+    task_groups = {}
+    for task in res:
+        if (
+            task.source.startswith(("pip_remote", "pip_local"))
+            and task.source.count(":") == 5
+        ):
+            source_fields = task.source.split(":")
+            (
+                collection_mode,
+                pkg_name,
+                version,
+                extras,
+                python_version,
+                name,
+            ) = source_fields
+            task_group_key = ":".join(
+                [pkg_name, version, extras, python_version]
+            )
+            if collection_mode == "pip_remote":
+                origin = "pypi"
+            elif collection_mode == "pip_local":
+                origin = "wheel-file"
+            else:
+                raise RuntimeError(
+                    f"Invalid {collection_mode=} for {task.source=}."
+                )
+            new_obj = dict(
+                task=task,
+                user_id=default_user_id,
+                origin=origin,
+                pkg_name=pkg_name,
+                version=version,
+                pip_extras=extras,
+                python_version=python_version,
+            )
+
+            if task_group_key in task_groups:
+                task_groups[task_group_key].append(new_obj)
+            else:
+                task_groups[task_group_key] = [new_obj]
+        else:
+            owner = task.owner
+            if owner is None:
+                raise RuntimeError(
+                    "Error: `owner` is `None` for "
+                    f"{task.id=}, {task.source=}, {task.owner=}."
+                )
+            user_id = user_mapping.get(owner, None)
+            if user_id is None:
+                raise RuntimeError(
+                    "Error: `user_id` is `None` for "
+                    f"{task.id=}, {task.source=}, {task.owner=}"
+                )
+            task_group_key = "-".join(
+                [
+                    "NOT_PIP",
+                    str(task.id),
+                    str(task.version),
+                    task.source,
+                    str(task.owner),
+                ]
+            )
+            if task_group_key in task_groups:
+                raise RuntimeError(
+                    f"ERROR: Duplicated {task_group_key=} for "
+                    f"{task.id=}, {task.source=}, {task.owner=}"
+                )
+            else:
+                task_groups[task_group_key] = [
+                    dict(
+                        task=task,
+                        user_id=user_id,
+                        origin="other",
+                        pkg_name=task.source,
+                        version=task.version,
+                    )
+                ]
+
+    for task_group_key, task_group_objects in task_groups.items():
+        print("-" * 80)
+        print(f"Start handling task group with key '{task_group_key}")
+        task_group_task_list = [item["task"] for item in task_group_objects]
+        print("List of tasks to be included")
+        for task in task_group_task_list:
+            print(f" {task.id=}, {task.source=}")
+
+        task_group_attributes = dict(
+            pkg_name=get_unique_value(task_group_objects, "pkg_name"),
+            version=get_unique_value(task_group_objects, "version"),
+            origin=get_unique_value(task_group_objects, "origin"),
+            user_id=get_unique_value(task_group_objects, "user_id"),
+            user_group_id=default_user_group_id,
+            python_version=get_unique_value(
+                task_group_objects, "python_version"
+            ),
+            pip_extras=get_unique_value(task_group_objects, "pip_extras"),
+            task_list=task_group_task_list,
+            active=True,
+            timestamp_created=get_timestamp(),
+        )
+
+        if not task_group_key.startswith("NOT_PIP"):
+            cmd = next(
+                getattr(task_group_task_list[0], attr_name)
+                for attr_name in ["command_non_parallel", "command_parallel"]
+                if getattr(task_group_task_list[0], attr_name) is not None
+            )
+            python_bin = cmd.split()[0]
+            venv_path = Path(python_bin).parents[1].as_posix()
+            path = Path(python_bin).parents[2].as_posix()
+            task_group_attributes["venv_path"] = venv_path
+            task_group_attributes["path"] = path
+
+        print()
+        print("List of task-group attributes")
+        for key, value in task_group_attributes.items():
+            if key != "task_list":
+                print(f" {key}: {value}")
+
+        print()
+
+        if dry_run:
+            print(
+                "End dry-run of handling task group with key "
+                f"'{task_group_key}"
+            )
+            print("-" * 80)
+            continue
+
+        task_group = TaskGroupV2(**task_group_attributes)
+        db.add(task_group)
+        db.commit()
+        db.refresh(task_group)
+        logger.warning(f"Created task group {task_group.id=}")
+        print()
+
+    return
+
+
+def fix_db(dry_run: bool = False):
+    logger.warning("START execution of fix_db function")
+    _check_current_version("2.7.0")
+
+    with next(get_sync_db()) as db:
+        user_mapping = get_users_mapping(db)
+        default_user_id = get_default_user_id(db)
+        default_user_group_id = get_default_user_group_id(db)
+
+        prepare_task_groups(
+            user_mapping=user_mapping,
+            default_user_id=default_user_id,
+            default_user_group_id=default_user_group_id,
+            db=db,
+            dry_run=dry_run,
+        )
+
+    logger.warning("END of execution of fix_db function")