aind-data-transfer-service 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aind-data-transfer-service might be problematic. Click here for more details.
- aind_data_transfer_service/__init__.py +9 -0
- aind_data_transfer_service/configs/__init__.py +1 -0
- aind_data_transfer_service/configs/csv_handler.py +59 -0
- aind_data_transfer_service/configs/job_configs.py +545 -0
- aind_data_transfer_service/configs/job_upload_template.py +153 -0
- aind_data_transfer_service/hpc/__init__.py +1 -0
- aind_data_transfer_service/hpc/client.py +151 -0
- aind_data_transfer_service/hpc/models.py +492 -0
- aind_data_transfer_service/log_handler.py +58 -0
- aind_data_transfer_service/models/__init__.py +1 -0
- aind_data_transfer_service/models/core.py +300 -0
- aind_data_transfer_service/models/internal.py +277 -0
- aind_data_transfer_service/server.py +1125 -0
- aind_data_transfer_service/templates/index.html +245 -0
- aind_data_transfer_service/templates/job_params.html +194 -0
- aind_data_transfer_service/templates/job_status.html +323 -0
- aind_data_transfer_service/templates/job_tasks_table.html +146 -0
- aind_data_transfer_service/templates/task_logs.html +31 -0
- aind_data_transfer_service-1.12.0.dist-info/METADATA +49 -0
- aind_data_transfer_service-1.12.0.dist-info/RECORD +23 -0
- aind_data_transfer_service-1.12.0.dist-info/WHEEL +5 -0
- aind_data_transfer_service-1.12.0.dist-info/licenses/LICENSE +21 -0
- aind_data_transfer_service-1.12.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Init package"""
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
__version__ = "1.12.0"

# Global constants.
# Each bucket name may be overridden by an environment variable of the same
# name; when that variable is unset, the literal fallback is used instead.
OPEN_DATA_BUCKET_NAME = os.environ.get("OPEN_DATA_BUCKET_NAME", "open")
PRIVATE_BUCKET_NAME = os.environ.get("PRIVATE_BUCKET_NAME", "private")
SCRATCH_BUCKET_NAME = os.environ.get("SCRATCH_BUCKET_NAME", "scratch")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Package to app configurations"""
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Module to handle processing legacy csv files"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from aind_data_transfer_models.core import (
|
|
6
|
+
BasicUploadJobConfigs,
|
|
7
|
+
CodeOceanPipelineMonitorConfigs,
|
|
8
|
+
ModalityConfigs,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def map_csv_row_to_job(row: dict) -> BasicUploadJobConfigs:
    """
    Maps csv row into a BasicUploadJobConfigs model.

    Column keys are stripped of surrounding spaces and have dashes replaced
    with underscores. Cell values are stripped of surrounding spaces; empty
    and missing (None) cells are treated as None.

    Keys starting with "modality" are grouped by the part before the first
    dot: a bare key such as ``modality0`` sets the ``modality`` field and a
    dotted key such as ``modality0.source`` sets ``source`` on the same
    ModalityConfigs entry. A non-empty ``job_type`` column is written into a
    default CodeOceanPipelineMonitorConfigs (serialized to a dict). Every
    other key is passed straight through to BasicUploadJobConfigs.

    Parameters
    ----------
    row : dict
      One csv row mapping column header -> cell value.

    Returns
    -------
    BasicUploadJobConfigs

    """
    modality_configs = dict()
    basic_job_configs = dict()
    for key, value in row.items():
        # Strip white spaces and replace dashes with underscores
        clean_key = str(key).strip(" ").replace("-", "_")
        # Check for None before calling str(): otherwise a missing cell
        # would turn into the literal string "None" instead of None.
        clean_val = None if value is None else str(value).strip(" ")
        # Replace empty strings with None.
        if clean_val == "":
            clean_val = None
        if clean_key.startswith("modality"):
            # "modality0" -> group "modality0", field "modality"
            # "modality0.source" -> group "modality0", field "source"
            modality_parts = clean_key.split(".")
            modality_key = modality_parts[0]
            sub_key = (
                "modality" if len(modality_parts) == 1 else modality_parts[1]
            )
            if clean_val is not None:
                entry = modality_configs.setdefault(modality_key, {})
                entry[sub_key] = clean_val
        elif clean_key == "job_type":
            if clean_val is not None:
                # Start from the default monitor configs and only override
                # the job type.
                codeocean_configs = json.loads(
                    CodeOceanPipelineMonitorConfigs().model_dump_json()
                )
                codeocean_configs["job_type"] = clean_val
                basic_job_configs["codeocean_configs"] = codeocean_configs
        else:
            basic_job_configs[clean_key] = clean_val
    modalities = [
        ModalityConfigs(**modality_value)
        for modality_value in modality_configs.values()
    ]
    return BasicUploadJobConfigs(modalities=modalities, **basic_job_configs)
|
|
@@ -0,0 +1,545 @@
|
|
|
1
|
+
"""This module adds classes to handle resolving common endpoints used in the
|
|
2
|
+
data transfer jobs."""
|
|
3
|
+
import re
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import PurePosixPath
|
|
6
|
+
from typing import Any, ClassVar, Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
from aind_data_schema.core.data_description import build_data_name
|
|
9
|
+
from aind_data_schema_models.modalities import Modality
|
|
10
|
+
from aind_data_schema_models.platforms import Platform
|
|
11
|
+
from aind_data_schema_models.process_names import ProcessName
|
|
12
|
+
from pydantic import (
|
|
13
|
+
ConfigDict,
|
|
14
|
+
Field,
|
|
15
|
+
PrivateAttr,
|
|
16
|
+
SecretStr,
|
|
17
|
+
ValidationInfo,
|
|
18
|
+
field_validator,
|
|
19
|
+
)
|
|
20
|
+
from pydantic_settings import BaseSettings
|
|
21
|
+
|
|
22
|
+
from aind_data_transfer_service import (
|
|
23
|
+
OPEN_DATA_BUCKET_NAME,
|
|
24
|
+
PRIVATE_BUCKET_NAME,
|
|
25
|
+
SCRATCH_BUCKET_NAME,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ModalityConfigs(BaseSettings):
    """Settings for a single modality within an upload job.

    Holds the modality type, the location of its raw data, and options that
    control compression and staging of that data.
    """

    # Lookup from a normalized key (upper case, dashes replaced with
    # underscores) to the canonical Modality abbreviation, so user strings can
    # be matched case-insensitively. Need some way to extract abbreviations.
    # Maybe a public method can be added to the Modality class.
    _MODALITY_MAP: ClassVar = {
        abbr.upper().replace("-", "_"): abbr
        for abbr in (m().abbreviation for m in Modality.ALL)
    }

    # Optional number id to assign to modality config. BasicUploadJobConfigs
    # assigns it to repeated occurrences of the same modality.
    _number_id: Optional[int] = PrivateAttr(default=None)
    modality: Modality.ONE_OF = Field(
        ..., description="Data collection modality", title="Modality"
    )
    source: PurePosixPath = Field(
        ...,
        description="Location of raw data to be uploaded",
        title="Data Source",
    )
    compress_raw_data: Optional[bool] = Field(
        default=None,
        description="Run compression on data",
        title="Compress Raw Data",
        validate_default=True,
    )
    extra_configs: Optional[PurePosixPath] = Field(
        default=None,
        description="Location of additional configuration file",
        title="Extra Configs",
    )
    skip_staging: bool = Field(
        default=False,
        description="Upload uncompressed directly without staging",
        title="Skip Staging",
    )

    @property
    def number_id(self):
        """The numeric id assigned to this config, or None if unassigned."""
        return self._number_id

    @property
    def default_output_folder_name(self):
        """Folder name for this modality: the modality abbreviation, with the
        number id appended when one has been assigned."""
        suffix = "" if self._number_id is None else str(self._number_id)
        return self.modality.abbreviation + suffix

    @field_validator("modality", mode="before")
    def parse_modality_string(
        cls, input_modality: Union[str, dict, Modality]
    ) -> Union[dict, Modality]:
        """Convert a modality given as a string (any case, dashes or
        underscores) into a Modality model. Non-string inputs are returned
        unchanged for pydantic to validate.

        Raises
        ------
        AttributeError
          If the string does not match any known modality abbreviation.
        """
        if not isinstance(input_modality, str):
            return input_modality
        lookup_key = input_modality.upper().replace("-", "_")
        abbreviation = cls._MODALITY_MAP.get(lookup_key)
        if abbreviation is None:
            raise AttributeError(f"Unknown Modality: {input_modality}")
        return Modality.from_abbreviation(abbreviation)

    @field_validator("compress_raw_data", mode="after")
    def get_compress_source_default(
        cls, compress_source: Optional[bool], info: ValidationInfo
    ) -> bool:
        """Resolve an unset compress flag: an explicit value is kept; when
        unset it defaults to True for ecephys data and False otherwise."""
        if compress_source is not None:
            return compress_source
        # `modality` is declared before this field, so it is already in
        # info.data when it validated successfully.
        return bool(info.data.get("modality") == Modality.ECEPHYS)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class BasicUploadJobConfigs(BaseSettings):
    """Configuration for the basic upload job"""

    # Allow users to pass in extra fields
    model_config = ConfigDict(
        extra="allow",
    )

    # Legacy way required users to input platform in screaming snake case.
    # Keys are the platform abbreviations upper-cased with dashes replaced by
    # underscores; values are the canonical abbreviations.
    _PLATFORM_MAP: ClassVar = {
        a.upper().replace("-", "_"): a
        for a in Platform.abbreviation_map.keys()
    }
    # Matches csv headers such as "modality", "modality0", "modality12".
    _MODALITY_ENTRY_PATTERN: ClassVar = re.compile(r"^modality(\d*)$")
    # e.g. "2023-04-05 13:14:15" or "2023-04-05T13:14:15": date and time
    # separated by a single space or "T".
    _DATETIME_PATTERN1: ClassVar = re.compile(
        r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}$"
    )
    # e.g. "4/5/2023 1:14:15 PM"
    _DATETIME_PATTERN2: ClassVar = re.compile(
        r"^\d{1,2}/\d{1,2}/\d{4} \d{1,2}:\d{2}:\d{2} [APap][Mm]$"
    )

    aws_param_store_name: Optional[str] = Field(None)

    project_name: str = Field(
        ..., description="Name of project", title="Project Name"
    )
    process_capsule_id: Optional[str] = Field(
        None,
        description="Use custom codeocean capsule or pipeline id",
        title="Process Capsule ID",
    )
    s3_bucket: Optional[str] = Field(
        None,
        description="Bucket where data will be uploaded",
        title="S3 Bucket",
        validate_default=True,
    )
    platform: Platform.ONE_OF = Field(
        ..., description="Platform", title="Platform"
    )
    modalities: List[ModalityConfigs] = Field(
        ...,
        description="Data collection modalities and their directory location",
        title="Modalities",
        # `min_length` is the pydantic v2 name for the deprecated `min_items`.
        min_length=1,
    )
    subject_id: str = Field(..., description="Subject ID", title="Subject ID")
    acq_datetime: datetime = Field(
        ...,
        description="Datetime data was acquired",
        title="Acquisition Datetime",
    )
    process_name: ProcessName = Field(
        default=ProcessName.OTHER,
        description="Type of processing performed on the raw data source.",
        title="Process Name",
    )
    metadata_dir: Optional[PurePosixPath] = Field(
        default=None,
        description="Directory of metadata",
        title="Metadata Directory",
    )
    # Deprecated. Will be removed in future versions.
    behavior_dir: Optional[PurePosixPath] = Field(
        default=None,
        description=(
            "Directory of behavior data. This field is deprecated and will be "
            "removed in future versions. Instead, this will be included in "
            "the modalities list."
        ),
        title="Behavior Directory",
    )
    log_level: str = Field(
        default="WARNING",
        description="Logging level. Default is WARNING.",
        title="Log Level",
    )
    metadata_dir_force: bool = Field(
        default=False,
        description=(
            "Whether to override metadata from service with metadata in "
            "optional metadata directory"
        ),
        title="Metadata Directory Force",
    )
    dry_run: bool = Field(
        default=False,
        description="Perform a dry-run of data upload",
        title="Dry Run",
    )
    force_cloud_sync: bool = Field(
        default=False,
        description=(
            "Force syncing of data folder even if location exists in cloud"
        ),
        title="Force Cloud Sync",
    )
    temp_directory: Optional[PurePosixPath] = Field(
        default=None,
        description=(
            "As default, the file systems temporary directory will be used as "
            "an intermediate location to store the compressed data before "
            "being uploaded to s3"
        ),
        title="Temp directory",
    )

    @property
    def s3_prefix(self):
        """Construct s3_prefix from configs: a data name built from the
        platform abbreviation, subject id and acquisition datetime."""
        return build_data_name(
            label=f"{self.platform.abbreviation}_{self.subject_id}",
            creation_datetime=self.acq_datetime,
        )

    @field_validator("s3_bucket", mode="before")
    def map_bucket(cls, bucket: Optional[str]) -> Optional[str]:
        """We're adding a policy that data uploaded through the service can
        only land in a handful of buckets. The open-data and scratch buckets
        are kept as given; anything else (including None) is mapped to the
        private bucket."""
        if bucket is not None and bucket in (
            OPEN_DATA_BUCKET_NAME,
            SCRATCH_BUCKET_NAME,
        ):
            return bucket
        return PRIVATE_BUCKET_NAME

    @field_validator("platform", mode="before")
    def parse_platform_string(
        cls, input_platform: Union[str, dict, Platform]
    ) -> Union[dict, Platform]:
        """Attempts to convert strings to a Platform model. Matching is
        case-insensitive and treats dashes and underscores alike, consistent
        with how the lookup keys are built. Non-string inputs are returned
        unchanged.

        Raises
        ------
        AttributeError
          If the string does not match any known platform abbreviation.
        """
        if not isinstance(input_platform, str):
            return input_platform
        platform_abbreviation = cls._PLATFORM_MAP.get(
            input_platform.upper().replace("-", "_")
        )
        if platform_abbreviation is None:
            raise AttributeError(f"Unknown Platform: {input_platform}")
        return Platform.from_abbreviation(platform_abbreviation)

    @field_validator("acq_datetime", mode="before")
    def _parse_datetime(cls, datetime_val: Any) -> datetime:
        """Parses a datetime string in either YYYY-MM-DD HH:mm:ss (space or
        "T" separator) or MM/DD/YYYY H:MM:SS AM/PM format. Non-string inputs
        are returned unchanged.

        Raises
        ------
        ValueError
          If a string matches neither supported format.
        """
        if not isinstance(datetime_val, str):
            return datetime_val
        if cls._DATETIME_PATTERN1.match(datetime_val):
            return datetime.fromisoformat(datetime_val)
        if cls._DATETIME_PATTERN2.match(datetime_val):
            return datetime.strptime(datetime_val, "%m/%d/%Y %I:%M:%S %p")
        raise ValueError(
            "Incorrect datetime format, should be"
            " YYYY-MM-DD HH:mm:ss or MM/DD/YYYY I:MM:SS P"
        )

    @field_validator("modalities", mode="after")
    def update_number_ids(
        cls, modality_list: List[ModalityConfigs]
    ) -> List[ModalityConfigs]:
        """
        Loops through the modality list and assigns a number id
        to duplicate modalities. For example, if a user inputs
        multiple behavior modalities, then it will upload them
        as behavior, behavior1, behavior2, etc. folders.

        Parameters
        ----------
        modality_list : List[ModalityConfigs]

        Returns
        -------
        List[ModalityConfigs]
          The same configs, with _number_id set on every repeat occurrence
          of a modality (the first occurrence is left untouched).

        """
        modality_counts: Dict[str, int] = {}
        updated_list = []
        for modality in modality_list:
            modality_abbreviation = modality.modality.abbreviation
            seen = modality_counts.get(modality_abbreviation, 0)
            if seen > 0:
                modality._number_id = seen
            modality_counts[modality_abbreviation] = seen + 1
            updated_list.append(modality)
        return updated_list

    @staticmethod
    def _clean_csv_entry(csv_key: str, csv_value: Any) -> Any:
        """Clean a single csv cell.

        A blank cell (None, or a string containing only whitespace) for a
        key that is a model field is replaced by that field's default.
        Other string cells are stripped of surrounding whitespace. Blank
        cells for keys that are not model fields (e.g. "modality0.source")
        yield None or an empty string; non-string values are returned
        unchanged.
        """
        is_blank = csv_value is None or (
            isinstance(csv_value, str) and csv_value.strip() == ""
        )
        field_info = BasicUploadJobConfigs.model_fields.get(csv_key)
        if is_blank and field_info is not None:
            return field_info.default
        if isinstance(csv_value, str):
            return csv_value.strip()
        # None for a non-field key (e.g. a missing trailing csv cell) is
        # passed through; downstream code treats a None modality as absent.
        return csv_value

    @staticmethod
    def _map_row_and_key_to_modality_config(
        modality_key: str,
        cleaned_row: Dict[str, Any],
        modality_counts: Dict[str, Optional[int]],
    ) -> Optional[ModalityConfigs]:
        """
        Maps a cleaned csv row and a key for a modality to process into a
        ModalityConfigs object.
        Parameters
        ----------
        modality_key : str
          The column header like modality0, or modality1, etc.
        cleaned_row : Dict[str, Any]
          The csv row that's been cleaned.
        modality_counts : Dict[str, Optional[int]]
          Running count per modality string, updated in place. Repeat
          occurrences get numerical ids so the same modality can be stored
          under folders like ecephys, ecephys1, etc.

        Returns
        -------
        Optional[ModalityConfigs]
          None if the modality cell is missing or blank.

        """
        modality: Optional[str] = cleaned_row[modality_key]
        source = cleaned_row.get(f"{modality_key}.source")
        extra_configs = cleaned_row.get(f"{modality_key}.extra_configs")

        if modality is None or modality.strip() == "":
            return None

        modality_configs = ModalityConfigs(
            modality=modality, source=source, extra_configs=extra_configs
        )
        num_id = modality_counts.get(modality)
        modality_configs._number_id = num_id
        modality_counts[modality] = 1 if num_id is None else num_id + 1
        return modality_configs

    @classmethod
    def _parse_modality_configs_from_row(cls, cleaned_row: dict) -> None:
        """
        Parses csv row into a list of ModalityConfigs. Will then process the
        cleaned_row dictionary by removing the old modality keys and replacing
        them with just modalities: List[ModalityConfigs].
        Parameters
        ----------
        cleaned_row : dict
          csv row that contains keys like modality0, modality0.source,
          modality1, modality1.source, etc.

        Returns
        -------
        None
          Modifies cleaned_row dict in-place

        """
        modalities = []
        modality_keys = [
            m
            for m in cleaned_row.keys()
            if cls._MODALITY_ENTRY_PATTERN.match(m)
        ]
        modality_counts: Dict[str, Optional[int]] = dict()
        # Check uniqueness of keys.
        # NOTE(review): modality_keys comes from dict keys, which are already
        # unique, so this cannot fail as written; duplicate headers collapse
        # earlier when the row dict is built.
        if len(modality_keys) != len(set(modality_keys)):
            raise KeyError(
                f"Modality keys need to be unique in csv "
                f"header: {modality_keys}"
            )
        for modality_key in modality_keys:
            modality_configs = cls._map_row_and_key_to_modality_config(
                modality_key=modality_key,
                cleaned_row=cleaned_row,
                modality_counts=modality_counts,
            )
            if modality_configs is not None:
                modalities.append(modality_configs)

        # Del old modality keys and replace them with list of modality_configs
        for row_key in [
            m for m in cleaned_row.keys() if m.startswith("modality")
        ]:
            del cleaned_row[row_key]
        cleaned_row["modalities"] = modalities

    @classmethod
    def from_csv_row(
        cls,
        row: dict,
        aws_param_store_name: Optional[str] = None,
        temp_directory: Optional[str] = None,
    ) -> "BasicUploadJobConfigs":
        """
        Creates a job config object from a csv row.

        Header keys are stripped and have dashes replaced with underscores,
        cells are cleaned with _clean_csv_entry, and modality columns are
        collapsed into a `modalities` list before validation.
        """
        cleaned_row = {}
        for raw_key, raw_value in row.items():
            clean_key = raw_key.strip().replace("-", "_")
            cleaned_row[clean_key] = cls._clean_csv_entry(clean_key, raw_value)
        cls._parse_modality_configs_from_row(cleaned_row=cleaned_row)
        return cls(
            **cleaned_row,
            aws_param_store_name=aws_param_store_name,
            temp_directory=temp_directory,
        )
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
# Deprecating this class
class HpcJobConfigs(BaseSettings):
    """Class to contain settings for hpc resources"""

    hpc_nodes: int = Field(default=1, description="Number of tasks")
    hpc_time_limit: int = Field(default=360, description="Timeout in minutes")
    hpc_node_memory: int = Field(
        default=50, description="Memory requested in GB"
    )
    hpc_partition: str
    hpc_current_working_directory: PurePosixPath
    hpc_logging_directory: PurePosixPath
    hpc_aws_secret_access_key: SecretStr
    hpc_aws_access_key_id: str
    hpc_aws_default_region: str
    hpc_aws_session_token: Optional[str] = Field(default=None)
    hpc_sif_location: PurePosixPath = Field(...)
    hpc_alt_exec_command: Optional[str] = Field(
        default=None,
        description=(
            "Set this value to run a different execution command then the "
            "default one built."
        ),
    )
    basic_upload_job_configs: BasicUploadJobConfigs

    def _json_args_str(self) -> str:
        """Serialize job configs to json"""
        return self.basic_upload_job_configs.model_dump_json()

    def _script_command_str(self) -> str:
        """Build the bash script sent to the hpc. It runs the basic job
        module inside the singularity image, passing the serialized job
        configs via --json-args."""
        # The json payload is embedded in a single-quoted shell word. Inside
        # single quotes nothing can be escaped, so a literal ' in the payload
        # (e.g. from a user-supplied project name) would end the quoting and
        # let the rest be interpreted by the shell. Replace each ' with
        # '"'"' (close quote, double-quoted ', reopen quote). Payloads with
        # no ' are emitted unchanged.
        quoted_json_args = self._json_args_str().replace("'", "'\"'\"'")
        command_str = [
            "#!/bin/bash",
            "\nsingularity",
            "exec",
            "--cleanenv",
            str(self.hpc_sif_location),
            "python",
            "-m",
            "aind_data_transfer.jobs.basic_job",
            "--json-args",
            "'",
            quoted_json_args,
            "'",
        ]

        return " ".join(command_str)

    def _job_name(self) -> str:
        """Construct a name for the job from the upload's s3 prefix."""
        return self.basic_upload_job_configs.s3_prefix

    @property
    def job_definition(self) -> dict:
        """
        Convert job configs to a dictionary that can be sent to the slurm
        cluster via the rest api.

        Returns
        -------
        dict
          {"job": {...job settings and environment...}, "script": str}. The
          script is hpc_alt_exec_command when set, else the generated
          singularity command.

        """
        job_name = self._job_name()
        # Minutes -> "HH:MM:00"
        time_limit_str = "{:02d}:{:02d}:00".format(
            *divmod(self.hpc_time_limit, 60)
        )
        mem_str = f"{self.hpc_node_memory}gb"
        # SINGULARITYENV_-prefixed variables carry the AWS credentials into
        # the container, since the script runs singularity with --cleanenv.
        environment = {
            "PATH": "/bin:/usr/bin/:/usr/local/bin/",
            "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
            "SINGULARITYENV_AWS_SECRET_ACCESS_KEY": (
                self.hpc_aws_secret_access_key.get_secret_value()
            ),
            "SINGULARITYENV_AWS_ACCESS_KEY_ID": self.hpc_aws_access_key_id,
            "SINGULARITYENV_AWS_DEFAULT_REGION": self.hpc_aws_default_region,
        }
        if self.hpc_aws_session_token is not None:
            environment[
                "SINGULARITYENV_AWS_SESSION_TOKEN"
            ] = self.hpc_aws_session_token

        if self.hpc_alt_exec_command is not None:
            exec_script = self.hpc_alt_exec_command
        else:
            exec_script = self._script_command_str()

        log_std_out_path = self.hpc_logging_directory / (job_name + ".out")
        log_std_err_path = self.hpc_logging_directory / (
            job_name + "_error.out"
        )

        return {
            "job": {
                "name": job_name,
                "nodes": self.hpc_nodes,
                "time_limit": time_limit_str,
                "partition": self.hpc_partition,
                "current_working_directory": (
                    str(self.hpc_current_working_directory)
                ),
                "standard_output": str(log_std_out_path),
                "standard_error": str(log_std_err_path),
                "memory_per_node": mem_str,
                "environment": environment,
            },
            "script": exec_script,
        }
|