aind-data-transfer-service 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aind-data-transfer-service might be problematic.
- aind_data_transfer_service/__init__.py +9 -0
- aind_data_transfer_service/configs/__init__.py +1 -0
- aind_data_transfer_service/configs/csv_handler.py +59 -0
- aind_data_transfer_service/configs/job_configs.py +545 -0
- aind_data_transfer_service/configs/job_upload_template.py +153 -0
- aind_data_transfer_service/hpc/__init__.py +1 -0
- aind_data_transfer_service/hpc/client.py +151 -0
- aind_data_transfer_service/hpc/models.py +492 -0
- aind_data_transfer_service/log_handler.py +58 -0
- aind_data_transfer_service/models/__init__.py +1 -0
- aind_data_transfer_service/models/core.py +300 -0
- aind_data_transfer_service/models/internal.py +277 -0
- aind_data_transfer_service/server.py +1125 -0
- aind_data_transfer_service/templates/index.html +245 -0
- aind_data_transfer_service/templates/job_params.html +194 -0
- aind_data_transfer_service/templates/job_status.html +323 -0
- aind_data_transfer_service/templates/job_tasks_table.html +146 -0
- aind_data_transfer_service/templates/task_logs.html +31 -0
- aind_data_transfer_service-1.12.0.dist-info/METADATA +49 -0
- aind_data_transfer_service-1.12.0.dist-info/RECORD +23 -0
- aind_data_transfer_service-1.12.0.dist-info/WHEEL +5 -0
- aind_data_transfer_service-1.12.0.dist-info/licenses/LICENSE +21 -0
- aind_data_transfer_service-1.12.0.dist-info/top_level.txt +1 -0
aind_data_transfer_service/models/core.py

@@ -0,0 +1,300 @@

"""Core models for using V2 of aind-data-transfer-service"""

import json
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Set, Union

from aind_data_schema_models.data_name_patterns import build_data_name
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from pydantic import (
    BaseModel,
    ConfigDict,
    EmailStr,
    Field,
    ValidationInfo,
    computed_field,
    field_validator,
    model_validator,
)
from pydantic_settings import BaseSettings

_validation_context: ContextVar[Union[Dict[str, Any], None]] = ContextVar(
    "_validation_context", default=None
)


@contextmanager
def validation_context(context: Union[Dict[str, Any], None]) -> None:
    """
    Following guide in:
    https://docs.pydantic.dev/latest/concepts/validators/#validation-context
    Parameters
    ----------
    context : Union[Dict[str, Any], None]

    Returns
    -------
    None

    """
    token = _validation_context.set(context)
    try:
        yield
    finally:
        _validation_context.reset(token)


class Task(BaseModel):
    """Configuration for a task run during a data transfer upload job."""

    skip_task: bool = Field(
        default=False,
        description=(
            "Skip running this task. If true, only task_id and skip_step are "
            "required."
        ),
        title="Skip Step",
    )
    image: Optional[str] = Field(
        default=None, description="Name of docker image to run", title="Image"
    )
    image_version: Optional[str] = Field(
        default=None,
        description="Version of docker image to run",
        title="Image Version",
    )
    image_resources: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Slurm environment. Must be json serializable.",
        title="Image Resources",
    )
    job_settings: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Settings for the job.",
        title="Job Settings",
    )
    command_script: Optional[str] = Field(
        default=None,
        description=(
            """
            Command script to run. A few strings may be replaced:
            %JOB_SETTINGS: This will be replaced with json.dumps(job_settings)
            %OUTPUT_LOCATION: Output location such as a local directory
            %S3_LOCATION: Location of S3 where to upload data to
            %INPUT_SOURCE: If a job requires a dynamic input source,
            then this may be replaced.
            %IMAGE: The containerized image.
            %IMAGE_VERSION: The image version.
            %ENV_FILE: An environment file location, such as aws configs.
            """
        ),
    )

    @field_validator(
        "image_resources",
        "job_settings",
        mode="after",
    )
    def validate_json_serializable(
        cls, v: Optional[Dict[str, Any]], info: ValidationInfo
    ) -> Optional[Dict[str, Any]]:
        """Validate that fields are json serializable."""
        if v is not None:
            try:
                json.dumps(v)
            except Exception as e:
                raise ValueError(
                    f"{info.field_name} must be json serializable! If "
                    f"converting from a Pydantic model, please use "
                    f'model.model_dump(mode="json"). {e}'
                )
        return v

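The validator above rejects job_settings or image_resources payloads that cannot be serialized to JSON. A minimal sketch of how a Task might be constructed; the image name, settings keys, and resource keys below are illustrative values, not ones shipped with the service:

from aind_data_transfer_service.models.core import Task

task = Task(
    image="example-transform-image",      # hypothetical image name
    image_version="1.0.0",
    job_settings={"chunk_size_mb": 64},   # plain JSON-friendly values pass
    image_resources={"memory_per_cpu": "8G"},
)
# Non-JSON values (e.g. a raw datetime inside job_settings) raise a
# ValidationError that points to model_dump(mode="json").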

class UploadJobConfigsV2(BaseSettings):
    """Configuration for a data transfer upload job"""

    # noinspection PyMissingConstructor
    def __init__(self, /, **data: Any) -> None:
        """Add context manager to init for validating fields."""
        self.__pydantic_validator__.validate_python(
            data,
            self_instance=self,
            context=_validation_context.get(),
        )

    model_config = ConfigDict(use_enum_values=True, extra="ignore")

    job_type: str = Field(
        default="default",
        description=(
            "Job type for the upload job. Tasks will be run based on the "
            "job_type unless otherwise specified in task_overrides."
        ),
        title="Job Type",
    )

    user_email: Optional[EmailStr] = Field(
        default=None,
        description=(
            "Optional email address to receive job status notifications"
        ),
    )
    email_notification_types: Optional[
        Set[Literal["begin", "end", "fail", "retry", "all"]]
    ] = Field(
        default=None,
        description=(
            "Types of job statuses to receive email notifications about"
        ),
    )
    s3_bucket: Literal["private", "open", "default"] = Field(
        default="default",
        description=(
            "Bucket where data will be uploaded. If not provided, will upload "
            "to default bucket."
        ),
        title="S3 Bucket",
    )

    project_name: str = Field(
        ..., description="Name of project", title="Project Name"
    )
    platform: Platform.ONE_OF = Field(
        ..., description="Platform", title="Platform"
    )
    modalities: List[Modality.ONE_OF] = Field(
        ...,
        description="Data collection modalities",
        title="Modalities",
        min_length=1,
    )
    subject_id: str = Field(..., description="Subject ID", title="Subject ID")
    acq_datetime: datetime = Field(
        ...,
        description="Datetime data was acquired",
        title="Acquisition Datetime",
    )
    tasks: Dict[str, Union[Task, Dict[str, Task]]] = Field(
        ...,
        description=(
            "Dictionary of tasks to run with custom settings. The key must be "
            "the task_id and the value must be the task or list of tasks."
        ),
        title="Tasks",
    )

    @computed_field
    def s3_prefix(self) -> str:
        """Construct s3_prefix from configs."""
        return build_data_name(
            label=f"{self.platform.abbreviation}_{self.subject_id}",
            creation_datetime=self.acq_datetime,
        )

    @field_validator("job_type", "project_name", mode="before")
    def validate_with_context(cls, v: str, info: ValidationInfo) -> str:
        """
        Validate certain fields. If a list of accepted values is provided in a
        context manager, then it will validate against the list. Otherwise, it
        won't raise any validation error.

        Parameters
        ----------
        v : str
            Value input into the field.
        info : ValidationInfo

        Returns
        -------
        str

        """
        valid_list = (info.context or dict()).get(f"{info.field_name}s")
        if valid_list is not None and v not in valid_list:
            raise ValueError(f"{v} must be one of {valid_list}")
        else:
            return v

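validate_with_context only enforces an allow-list when one is supplied through the validation_context context manager defined at the top of the module; otherwise any job_type or project_name string is accepted. A rough sketch of building a job config under such a context; the project name, task_id, and the Platform.ECEPHYS / Modality.ECEPHYS members are assumptions for illustration:

from datetime import datetime

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_transfer_service.models.core import (
    Task,
    UploadJobConfigsV2,
    validation_context,
)

# Inside the context, project_name is checked against the "project_names" list.
with validation_context({"project_names": ["Example Project"]}):
    job = UploadJobConfigsV2(
        project_name="Example Project",            # placeholder project name
        platform=Platform.ECEPHYS,                 # assumed Platform member
        modalities=[Modality.ECEPHYS],             # assumed Modality member
        subject_id="123456",
        acq_datetime=datetime(2024, 1, 1, 12, 0, 0),
        tasks={"make_modality_list": Task(skip_task=True)},  # hypothetical task_id
    )

print(job.s3_prefix)  # e.g. "ecephys_123456_2024-01-01_12-00-00"
# A project_name outside the allow-list would raise a ValidationError instead.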

class SubmitJobRequestV2(BaseSettings):
    """Main request that will be sent to the backend. Bundles jobs into a list
    and allows a user to add an email address to receive notifications."""

    # noinspection PyMissingConstructor
    def __init__(self, /, **data: Any) -> None:
        """Add context manager to init for validating upload_jobs."""
        self.__pydantic_validator__.validate_python(
            data,
            self_instance=self,
            context=_validation_context.get(),
        )

    model_config = ConfigDict(use_enum_values=True, extra="ignore")

    dag_id: Literal["transform_and_upload_v2"] = "transform_and_upload_v2"
    user_email: Optional[EmailStr] = Field(
        default=None,
        description=(
            "Optional email address to receive job status notifications"
        ),
    )
    email_notification_types: Set[
        Literal["begin", "end", "fail", "retry", "all"]
    ] = Field(
        default={"fail"},
        description=(
            "Types of job statuses to receive email notifications about"
        ),
    )
    upload_jobs: List[UploadJobConfigsV2] = Field(
        ...,
        description="List of upload jobs to process. Max of 50 at a time.",
        min_length=1,
        max_length=50,
    )

    @model_validator(mode="after")
    def propagate_email_settings(self):
        """Propagate email settings from global to individual jobs"""
        global_email_user = self.user_email
        global_email_notification_types = self.email_notification_types
        for upload_job in self.upload_jobs:
            if global_email_user is not None and upload_job.user_email is None:
                upload_job.user_email = global_email_user
            if upload_job.email_notification_types is None:
                upload_job.email_notification_types = (
                    global_email_notification_types
                )
        return self

    @model_validator(mode="after")
    def check_duplicate_upload_jobs(self, info: ValidationInfo):
        """Validate that there are no duplicate upload jobs. If a list of
        current jobs is provided in a context manager, jobs are also checked
        against the list."""
        jobs_map = dict()
        # check jobs with the same s3_prefix
        for job in self.upload_jobs:
            prefix = job.s3_prefix
            job_json = json.dumps(
                job.model_dump(mode="json", exclude_none=True), sort_keys=True
            )
            jobs_map.setdefault(prefix, set())
            if job_json in jobs_map[prefix]:
                raise ValueError(f"Duplicate jobs found for {prefix}")
            jobs_map[prefix].add(job_json)
        # check against any jobs in the context
        current_jobs = (info.context or dict()).get("current_jobs", list())
        for job in current_jobs:
            prefix = job.get("s3_prefix")
            if (
                prefix is not None
                and prefix in jobs_map
                and json.dumps(job, sort_keys=True) in jobs_map[prefix]
            ):
                raise ValueError(f"Job is already running/queued for {prefix}")
        return self

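A request body for the service is then assembled from one or more job configs. A sketch, assuming a job config built as in the previous example; the email address is a placeholder:

from aind_data_transfer_service.models.core import SubmitJobRequestV2

request = SubmitJobRequestV2(
    user_email="someone@example.org",   # placeholder address
    upload_jobs=[job],                  # UploadJobConfigsV2 from the sketch above
)
# propagate_email_settings copies the request-level email settings onto each
# job that did not set its own, and check_duplicate_upload_jobs rejects two
# jobs that serialize to the same payload under one s3_prefix.
payload = request.model_dump_json(exclude_none=True)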
aind_data_transfer_service/models/internal.py

@@ -0,0 +1,277 @@

"""Module for internal data models used in application"""

import ast
import os
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Union

from mypy_boto3_ssm.type_defs import ParameterMetadataTypeDef
from pydantic import AwareDatetime, BaseModel, Field, field_validator
from starlette.datastructures import QueryParams


class AirflowDagRun(BaseModel):
    """Data model for dag_run entry when requesting info from airflow"""

    conf: Optional[dict]
    dag_id: Optional[str]
    dag_run_id: Optional[str]
    data_interval_end: Optional[AwareDatetime]
    data_interval_start: Optional[AwareDatetime]
    end_date: Optional[AwareDatetime]
    execution_date: Optional[AwareDatetime]
    external_trigger: Optional[bool]
    last_scheduling_decision: Optional[AwareDatetime]
    logical_date: Optional[AwareDatetime]
    note: Optional[str]
    run_type: Optional[str]
    start_date: Optional[AwareDatetime]
    state: Optional[str]


class AirflowDagRunsResponse(BaseModel):
    """Data model for response when requesting info from dag_runs endpoint"""

    dag_runs: List[AirflowDagRun]
    total_entries: int


class AirflowDagRunsRequestParameters(BaseModel):
    """Model for parameters when requesting info from dag_runs endpoint"""

    dag_ids: list[str] = ["transform_and_upload", "transform_and_upload_v2"]
    page_limit: int = 100
    page_offset: int = 0
    states: Optional[list[str]] = []
    execution_date_gte: Optional[str] = (
        datetime.now(timezone.utc) - timedelta(weeks=2)
    ).isoformat()
    execution_date_lte: Optional[str] = None
    order_by: str = "-execution_date"

    @field_validator("execution_date_gte", mode="after")
    def validate_min_execution_date(cls, execution_date_gte: str):
        """Validate the earliest submit date filter is within 2 weeks"""
        min_execution_date = datetime.now(timezone.utc) - timedelta(weeks=2)
        # datetime.fromisoformat does not support Z in python < 3.11
        date_to_check = execution_date_gte.replace("Z", "+00:00")
        if datetime.fromisoformat(date_to_check) < min_execution_date:
            raise ValueError(
                "execution_date_gte must be within the last 2 weeks"
            )
        return execution_date_gte

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        if "states" in params:
            params["states"] = ast.literal_eval(params["states"])
        return cls.model_validate(params)

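Because from_query_params runs the states value through ast.literal_eval, that query parameter has to look like a Python list literal. A small sketch with an illustrative query string:

from starlette.datastructures import QueryParams

from aind_data_transfer_service.models.internal import (
    AirflowDagRunsRequestParameters,
)

params = AirflowDagRunsRequestParameters.from_query_params(
    QueryParams("page_limit=25&states=['queued','running']")
)
print(params.page_limit)  # 25
print(params.states)      # ['queued', 'running']
# execution_date_gte defaults to two weeks ago; anything older is rejected
# by validate_min_execution_date.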

class AirflowTaskInstancesRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instances
    endpoint"""

    dag_id: str = Field(..., min_length=1)
    dag_run_id: str = Field(..., min_length=1)

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)


class AirflowTaskInstance(BaseModel):
    """Data model for task_instance entry when requesting info from airflow"""

    dag_id: Optional[str]
    dag_run_id: Optional[str]
    duration: Optional[Union[int, float]]
    end_date: Optional[AwareDatetime]
    execution_date: Optional[AwareDatetime]
    executor_config: Optional[str]
    hostname: Optional[str]
    map_index: Optional[int]
    max_tries: Optional[int]
    note: Optional[str]
    operator: Optional[str]
    pid: Optional[int]
    pool: Optional[str]
    pool_slots: Optional[int]
    priority_weight: Optional[int]
    queue: Optional[str]
    queued_when: Optional[AwareDatetime]
    rendered_fields: Optional[dict]
    sla_miss: Optional[dict]
    start_date: Optional[AwareDatetime]
    state: Optional[str]
    task_id: Optional[str]
    trigger: Optional[dict]
    triggerer_job: Optional[dict]
    try_number: Optional[int]
    unixname: Optional[str]


class AirflowTaskInstancesResponse(BaseModel):
    """Data model for response when requesting info from task_instances
    endpoint"""

    task_instances: List[AirflowTaskInstance]
    total_entries: int


class AirflowTaskInstanceLogsRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instance_logs
    endpoint"""

    # excluded fields are used to build the url
    dag_id: str = Field(..., min_length=1, exclude=True)
    dag_run_id: str = Field(..., min_length=1, exclude=True)
    task_id: str = Field(..., min_length=1, exclude=True)
    try_number: int = Field(..., ge=0, exclude=True)
    map_index: int = Field(..., ge=-1)
    full_content: bool = True

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)

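The exclude=True fields above are used only to build the logs URL; the remaining fields are what a model dump forwards as query parameters. A sketch with made-up identifiers:

from aind_data_transfer_service.models.internal import (
    AirflowTaskInstanceLogsRequestParameters,
)

log_params = AirflowTaskInstanceLogsRequestParameters(
    dag_id="transform_and_upload_v2",
    dag_run_id="manual__2024-01-01T00:00:00+00:00",  # hypothetical run id
    task_id="send_job_start_email",                  # hypothetical task id
    try_number=1,
    map_index=-1,
)
print(log_params.model_dump())
# {'map_index': -1, 'full_content': True} -- excluded fields are dropped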

class JobStatus(BaseModel):
    """Model for what we want to render to the user."""

    dag_id: Optional[str] = Field(None)
    end_time: Optional[datetime] = Field(None)
    job_id: Optional[str] = Field(None)
    job_state: Optional[str] = Field(None)
    name: Optional[str] = Field(None)
    job_type: Optional[str] = Field(None)
    comment: Optional[str] = Field(None)
    start_time: Optional[datetime] = Field(None)
    submit_time: Optional[datetime] = Field(None)

    @classmethod
    def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun):
        """Maps the fields from the HpcJobStatusResponse to this model"""
        name = airflow_dag_run.conf.get("s3_prefix", "")
        job_type = airflow_dag_run.conf.get("job_type", "")
        # v1 job_type is in CO configs
        if job_type == "":
            job_type = airflow_dag_run.conf.get("codeocean_configs", {}).get(
                "job_type", ""
            )
        return cls(
            dag_id=airflow_dag_run.dag_id,
            end_time=airflow_dag_run.end_date,
            job_id=airflow_dag_run.dag_run_id,
            job_state=airflow_dag_run.state,
            name=name,
            job_type=job_type,
            comment=airflow_dag_run.note,
            start_time=airflow_dag_run.start_date,
            submit_time=airflow_dag_run.execution_date,
        )

    @property
    def jinja_dict(self):
        """Map model to a dictionary that jinja can render"""
        return self.model_dump(exclude_none=True)


class JobTasks(BaseModel):
    """Model for what is rendered to the user for each task."""

    dag_id: Optional[str] = Field(None)
    job_id: Optional[str] = Field(None)
    task_id: Optional[str] = Field(None)
    try_number: Optional[int] = Field(None)
    task_state: Optional[str] = Field(None)
    priority_weight: Optional[int] = Field(None)
    map_index: Optional[int] = Field(None)
    submit_time: Optional[datetime] = Field(None)
    start_time: Optional[datetime] = Field(None)
    end_time: Optional[datetime] = Field(None)
    duration: Optional[Union[int, float]] = Field(None)
    comment: Optional[str] = Field(None)

    @classmethod
    def from_airflow_task_instance(
        cls, airflow_task_instance: AirflowTaskInstance
    ):
        """Maps the fields from the HpcJobStatusResponse to this model"""
        return cls(
            dag_id=airflow_task_instance.dag_id,
            job_id=airflow_task_instance.dag_run_id,
            task_id=airflow_task_instance.task_id,
            try_number=airflow_task_instance.try_number,
            task_state=airflow_task_instance.state,
            priority_weight=airflow_task_instance.priority_weight,
            map_index=airflow_task_instance.map_index,
            submit_time=airflow_task_instance.execution_date,
            start_time=airflow_task_instance.start_date,
            end_time=airflow_task_instance.end_date,
            duration=airflow_task_instance.duration,
            comment=airflow_task_instance.note,
        )

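All AirflowDagRun fields are Optional but have no defaults, so a full payload (as Airflow's REST API returns) is still required; the sketch below fills the unused keys with None. The run id, prefix, and job_type values are illustrative:

from aind_data_transfer_service.models.internal import AirflowDagRun, JobStatus

payload = dict.fromkeys(AirflowDagRun.model_fields, None)
payload.update(
    dag_id="transform_and_upload_v2",
    dag_run_id="manual__2024-01-01T00:00:00+00:00",  # hypothetical run id
    state="success",
    conf={"s3_prefix": "ecephys_123456_2024-01-01_12-00-00", "job_type": "default"},
)
dag_run = AirflowDagRun.model_validate(payload)
status = JobStatus.from_airflow_dag_run(dag_run)
print(status.jinja_dict)
# Only the non-None fields survive: dag_id, job_id, job_state, name, job_type.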

class JobParamInfo(BaseModel):
    """Model for job parameter info from AWS Parameter Store"""

    name: Optional[str]
    last_modified: Optional[datetime]
    job_type: str
    task_id: str
    modality: Optional[str]

    @classmethod
    def from_aws_describe_parameter(
        cls,
        parameter: ParameterMetadataTypeDef,
        job_type: str,
        task_id: str,
        modality: Optional[str],
    ):
        """Map the parameter to the model"""
        return cls(
            name=parameter.get("Name"),
            last_modified=parameter.get("LastModifiedDate"),
            job_type=job_type,
            task_id=task_id,
            modality=modality,
        )

    @staticmethod
    def get_parameter_prefix(version: Optional[str] = None) -> str:
        """Get the prefix for job_type parameters"""
        prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
        if version is None:
            return prefix
        return f"{prefix}/{version}"

    @staticmethod
    def get_parameter_regex(version: Optional[str] = None) -> str:
        """Create the regex pattern to match the parameter name"""
        prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
        regex = (
            "(?P<job_type>[^/]+)/tasks/(?P<task_id>[^/]+)"
            "(?:/(?P<modality>[^/]+))?"
        )
        if version is None:
            return f"{prefix}/{regex}"
        return f"{prefix}/{version}/{regex}"

    @staticmethod
    def get_parameter_name(
        job_type: str, task_id: str, version: Optional[str] = None
    ) -> str:
        """Create the parameter name from job_type and task_id"""
        prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
        if version is None:
            return f"{prefix}/{job_type}/tasks/{task_id}"
        return f"{prefix}/{version}/{job_type}/tasks/{task_id}"

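The three static helpers build and parse AWS Parameter Store names under the AIND_AIRFLOW_PARAM_PREFIX environment variable. A sketch with an invented prefix and task_id:

import os
import re

from aind_data_transfer_service.models.internal import JobParamInfo

os.environ["AIND_AIRFLOW_PARAM_PREFIX"] = "/aind/airflow/jobs"  # invented prefix

name = JobParamInfo.get_parameter_name("default", "compress_data", version="v2")
print(name)  # /aind/airflow/jobs/v2/default/tasks/compress_data

match = re.match(JobParamInfo.get_parameter_regex(version="v2"), name)
print(match.group("job_type"), match.group("task_id"), match.group("modality"))
# default compress_data None  (modality is only present for per-modality params)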