aind-data-transfer-service 1.12.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,300 @@
+ """Core models for using V2 of aind-data-transfer-service"""
+
+ import json
+ from contextlib import contextmanager
+ from contextvars import ContextVar
+ from datetime import datetime
+ from typing import Any, Dict, List, Literal, Optional, Set, Union
+
+ from aind_data_schema_models.data_name_patterns import build_data_name
+ from aind_data_schema_models.modalities import Modality
+ from aind_data_schema_models.platforms import Platform
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     EmailStr,
+     Field,
+     ValidationInfo,
+     computed_field,
+     field_validator,
+     model_validator,
+ )
+ from pydantic_settings import BaseSettings
+
+ _validation_context: ContextVar[Union[Dict[str, Any], None]] = ContextVar(
+     "_validation_context", default=None
+ )
+
+
+ @contextmanager
+ def validation_context(context: Union[Dict[str, Any], None]) -> None:
+     """
+     Following guide in:
+     https://docs.pydantic.dev/latest/concepts/validators/#validation-context
+     Parameters
+     ----------
+     context : Union[Dict[str, Any], None]
+
+     Returns
+     -------
+     None
+
+     """
+     token = _validation_context.set(context)
+     try:
+         yield
+     finally:
+         _validation_context.reset(token)
+
+
+ class Task(BaseModel):
+     """Configuration for a task run during a data transfer upload job."""
+
+     skip_task: bool = Field(
+         default=False,
+         description=(
+             "Skip running this task. If true, only task_id and skip_step are "
+             "required."
+         ),
+         title="Skip Step",
+     )
+     image: Optional[str] = Field(
+         default=None, description="Name of docker image to run", title="Image"
+     )
+     image_version: Optional[str] = Field(
+         default=None,
+         description="Version of docker image to run",
+         title="Image Version",
+     )
+     image_resources: Optional[Dict[str, Any]] = Field(
+         default=None,
+         description="Slurm environment. Must be json serializable.",
+         title="Image Resources",
+     )
+     job_settings: Optional[Dict[str, Any]] = Field(
+         default=None,
+         description="Settings for the job.",
+         title="Job Settings",
+     )
+     command_script: Optional[str] = Field(
+         default=None,
+         description=(
+             """
+             Command script to run. A few strings may be replaced:
+             %JOB_SETTINGS: This will be replaced with json.dumps(job_settings)
+             %OUTPUT_LOCATION: Output location such as a local directory
+             %S3_LOCATION: Location of S3 where to upload data to
+             %INPUT_SOURCE: If a job requires a dynamic input source,
+             then this may be replaced.
+             %IMAGE: The containerized image.
+             %IMAGE_VERSION: The image version.
+             %ENV_FILE: An environment file location, such as aws configs.
+             """
+         ),
+     )
+
+     @field_validator(
+         "image_resources",
+         "job_settings",
+         mode="after",
+     )
+     def validate_json_serializable(
+         cls, v: Optional[Dict[str, Any]], info: ValidationInfo
+     ) -> Optional[Dict[str, Any]]:
+         """Validate that fields are json serializable."""
+         if v is not None:
+             try:
+                 json.dumps(v)
+             except Exception as e:
+                 raise ValueError(
+                     f"{info.field_name} must be json serializable! If "
+                     f"converting from a Pydantic model, please use "
+                     f'model.model_dump(mode="json"). {e}'
+                 )
+         return v
+
+
+ class UploadJobConfigsV2(BaseSettings):
+     """Configuration for a data transfer upload job"""
+
+     # noinspection PyMissingConstructor
+     def __init__(self, /, **data: Any) -> None:
+         """Add context manager to init for validating fields."""
+         self.__pydantic_validator__.validate_python(
+             data,
+             self_instance=self,
+             context=_validation_context.get(),
+         )
+
+     model_config = ConfigDict(use_enum_values=True, extra="ignore")
+
+     job_type: str = Field(
+         default="default",
+         description=(
+             "Job type for the upload job. Tasks will be run based on the "
+             "job_type unless otherwise specified in task_overrides."
+         ),
+         title="Job Type",
+     )
+
+     user_email: Optional[EmailStr] = Field(
+         default=None,
+         description=(
+             "Optional email address to receive job status notifications"
+         ),
+     )
+     email_notification_types: Optional[
+         Set[Literal["begin", "end", "fail", "retry", "all"]]
+     ] = Field(
+         default=None,
+         description=(
+             "Types of job statuses to receive email notifications about"
+         ),
+     )
+     s3_bucket: Literal["private", "open", "default"] = Field(
+         default="default",
+         description=(
+             "Bucket where data will be uploaded. If not provided, will upload "
+             "to default bucket."
+         ),
+         title="S3 Bucket",
+     )
+
+     project_name: str = Field(
+         ..., description="Name of project", title="Project Name"
+     )
+     platform: Platform.ONE_OF = Field(
+         ..., description="Platform", title="Platform"
+     )
+     modalities: List[Modality.ONE_OF] = Field(
+         ...,
+         description="Data collection modalities",
+         title="Modalities",
+         min_length=1,
+     )
+     subject_id: str = Field(..., description="Subject ID", title="Subject ID")
+     acq_datetime: datetime = Field(
+         ...,
+         description="Datetime data was acquired",
+         title="Acquisition Datetime",
+     )
+     tasks: Dict[str, Union[Task, Dict[str, Task]]] = Field(
+         ...,
+         description=(
+             "Dictionary of tasks to run with custom settings. The key must be "
+             "the task_id and the value must be the task or list of tasks."
+         ),
+         title="Tasks",
+     )
+
+     @computed_field
+     def s3_prefix(self) -> str:
+         """Construct s3_prefix from configs."""
+         return build_data_name(
+             label=f"{self.platform.abbreviation}_{self.subject_id}",
+             creation_datetime=self.acq_datetime,
+         )
+
+     @field_validator("job_type", "project_name", mode="before")
+     def validate_with_context(cls, v: str, info: ValidationInfo) -> str:
+         """
+         Validate certain fields. If a list of accepted values is provided in a
+         context manager, then it will validate against the list. Otherwise, it
+         won't raise any validation error.
+
+         Parameters
+         ----------
+         v : str
+             Value input into the field.
+         info : ValidationInfo
+
+         Returns
+         -------
+         str
+
+         """
+         valid_list = (info.context or dict()).get(f"{info.field_name}s")
+         if valid_list is not None and v not in valid_list:
+             raise ValueError(f"{v} must be one of {valid_list}")
+         else:
+             return v
+
+
+ class SubmitJobRequestV2(BaseSettings):
+     """Main request that will be sent to the backend. Bundles jobs into a list
+     and allows a user to add an email address to receive notifications."""
+
+     # noinspection PyMissingConstructor
+     def __init__(self, /, **data: Any) -> None:
+         """Add context manager to init for validating upload_jobs."""
+         self.__pydantic_validator__.validate_python(
+             data,
+             self_instance=self,
+             context=_validation_context.get(),
+         )
+
+     model_config = ConfigDict(use_enum_values=True, extra="ignore")
+
+     dag_id: Literal["transform_and_upload_v2"] = "transform_and_upload_v2"
+     user_email: Optional[EmailStr] = Field(
+         default=None,
+         description=(
+             "Optional email address to receive job status notifications"
+         ),
+     )
+     email_notification_types: Set[
+         Literal["begin", "end", "fail", "retry", "all"]
+     ] = Field(
+         default={"fail"},
+         description=(
+             "Types of job statuses to receive email notifications about"
+         ),
+     )
+     upload_jobs: List[UploadJobConfigsV2] = Field(
+         ...,
+         description="List of upload jobs to process. Max of 50 at a time.",
+         min_length=1,
+         max_length=50,
+     )
+
+     @model_validator(mode="after")
+     def propagate_email_settings(self):
+         """Propagate email settings from global to individual jobs"""
+         global_email_user = self.user_email
+         global_email_notification_types = self.email_notification_types
+         for upload_job in self.upload_jobs:
+             if global_email_user is not None and upload_job.user_email is None:
+                 upload_job.user_email = global_email_user
+             if upload_job.email_notification_types is None:
+                 upload_job.email_notification_types = (
+                     global_email_notification_types
+                 )
+         return self
+
+     @model_validator(mode="after")
+     def check_duplicate_upload_jobs(self, info: ValidationInfo):
+         """Validate that there are no duplicate upload jobs. If a list of
+         current jobs is provided in a context manager, jobs are also checked
+         against the list."""
+         jobs_map = dict()
+         # check jobs with the same s3_prefix
+         for job in self.upload_jobs:
+             prefix = job.s3_prefix
+             job_json = json.dumps(
+                 job.model_dump(mode="json", exclude_none=True), sort_keys=True
+             )
+             jobs_map.setdefault(prefix, set())
+             if job_json in jobs_map[prefix]:
+                 raise ValueError(f"Duplicate jobs found for {prefix}")
+             jobs_map[prefix].add(job_json)
+         # check against any jobs in the context
+         current_jobs = (info.context or dict()).get("current_jobs", list())
+         for job in current_jobs:
+             prefix = job.get("s3_prefix")
+             if (
+                 prefix is not None
+                 and prefix in jobs_map
+                 and json.dumps(job, sort_keys=True) in jobs_map[prefix]
+             ):
+                 raise ValueError(f"Job is already running/queued for {prefix}")
+         return self
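
Taken together, the models in this file let a client compose an upload request programmatically. The sketch below is a minimal, hypothetical usage example: the project name, subject ID, task ID, task settings, and the Platform/Modality members are illustrative rather than taken from this diff, and it simply assumes the classes defined above are importable.

    from datetime import datetime

    from aind_data_schema_models.modalities import Modality
    from aind_data_schema_models.platforms import Platform

    # Hypothetical task settings; job_settings must be json serializable,
    # so Pydantic models should be passed through model_dump(mode="json").
    modality_task = Task(job_settings={"input_source": "/data/ecephys_123456"})

    # If the caller knows the accepted project names (e.g. fetched from the
    # service), validation can be made strict via the context manager; without
    # a context, project_name and job_type are accepted as-is.
    with validation_context({"project_names": ["Example Project"]}):
        upload_job = UploadJobConfigsV2(
            project_name="Example Project",
            platform=Platform.ECEPHYS,
            modalities=[Modality.ECEPHYS],
            subject_id="123456",
            acq_datetime=datetime(2025, 1, 2, 3, 4, 5),
            tasks={"make_modality_list": modality_task},
        )
        request = SubmitJobRequestV2(upload_jobs=[upload_job])

    print(upload_job.s3_prefix)  # e.g. "ecephys_123456_2025-01-02_03-04-05"

Because the job above sets no user_email, any user_email and email_notification_types supplied on SubmitJobRequestV2 would be copied onto it by propagate_email_settings, and check_duplicate_upload_jobs would reject a second identical job with the same s3_prefix.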
@@ -0,0 +1,277 @@
+ """Module for internal data models used in application"""
+
+ import ast
+ import os
+ from datetime import datetime, timedelta, timezone
+ from typing import List, Optional, Union
+
+ from mypy_boto3_ssm.type_defs import ParameterMetadataTypeDef
+ from pydantic import AwareDatetime, BaseModel, Field, field_validator
+ from starlette.datastructures import QueryParams
+
+
+ class AirflowDagRun(BaseModel):
+     """Data model for dag_run entry when requesting info from airflow"""
+
+     conf: Optional[dict]
+     dag_id: Optional[str]
+     dag_run_id: Optional[str]
+     data_interval_end: Optional[AwareDatetime]
+     data_interval_start: Optional[AwareDatetime]
+     end_date: Optional[AwareDatetime]
+     execution_date: Optional[AwareDatetime]
+     external_trigger: Optional[bool]
+     last_scheduling_decision: Optional[AwareDatetime]
+     logical_date: Optional[AwareDatetime]
+     note: Optional[str]
+     run_type: Optional[str]
+     start_date: Optional[AwareDatetime]
+     state: Optional[str]
+
+
+ class AirflowDagRunsResponse(BaseModel):
+     """Data model for response when requesting info from dag_runs endpoint"""
+
+     dag_runs: List[AirflowDagRun]
+     total_entries: int
+
+
+ class AirflowDagRunsRequestParameters(BaseModel):
+     """Model for parameters when requesting info from dag_runs endpoint"""
+
+     dag_ids: list[str] = ["transform_and_upload", "transform_and_upload_v2"]
+     page_limit: int = 100
+     page_offset: int = 0
+     states: Optional[list[str]] = []
+     execution_date_gte: Optional[str] = (
+         datetime.now(timezone.utc) - timedelta(weeks=2)
+     ).isoformat()
+     execution_date_lte: Optional[str] = None
+     order_by: str = "-execution_date"
+
+     @field_validator("execution_date_gte", mode="after")
+     def validate_min_execution_date(cls, execution_date_gte: str):
+         """Validate the earliest submit date filter is within 2 weeks"""
+         min_execution_date = datetime.now(timezone.utc) - timedelta(weeks=2)
+         # datetime.fromisoformat does not support Z in python < 3.11
+         date_to_check = execution_date_gte.replace("Z", "+00:00")
+         if datetime.fromisoformat(date_to_check) < min_execution_date:
+             raise ValueError(
+                 "execution_date_gte must be within the last 2 weeks"
+             )
+         return execution_date_gte
+
+     @classmethod
+     def from_query_params(cls, query_params: QueryParams):
+         """Maps the query parameters to the model"""
+         params = dict(query_params)
+         if "states" in params:
+             params["states"] = ast.literal_eval(params["states"])
+         return cls.model_validate(params)
+
+
+ class AirflowTaskInstancesRequestParameters(BaseModel):
+     """Model for parameters when requesting info from task_instances
+     endpoint"""
+
+     dag_id: str = Field(..., min_length=1)
+     dag_run_id: str = Field(..., min_length=1)
+
+     @classmethod
+     def from_query_params(cls, query_params: QueryParams):
+         """Maps the query parameters to the model"""
+         params = dict(query_params)
+         return cls.model_validate(params)
+
+
+ class AirflowTaskInstance(BaseModel):
+     """Data model for task_instance entry when requesting info from airflow"""
+
+     dag_id: Optional[str]
+     dag_run_id: Optional[str]
+     duration: Optional[Union[int, float]]
+     end_date: Optional[AwareDatetime]
+     execution_date: Optional[AwareDatetime]
+     executor_config: Optional[str]
+     hostname: Optional[str]
+     map_index: Optional[int]
+     max_tries: Optional[int]
+     note: Optional[str]
+     operator: Optional[str]
+     pid: Optional[int]
+     pool: Optional[str]
+     pool_slots: Optional[int]
+     priority_weight: Optional[int]
+     queue: Optional[str]
+     queued_when: Optional[AwareDatetime]
+     rendered_fields: Optional[dict]
+     sla_miss: Optional[dict]
+     start_date: Optional[AwareDatetime]
+     state: Optional[str]
+     task_id: Optional[str]
+     trigger: Optional[dict]
+     triggerer_job: Optional[dict]
+     try_number: Optional[int]
+     unixname: Optional[str]
+
+
+ class AirflowTaskInstancesResponse(BaseModel):
+     """Data model for response when requesting info from task_instances
+     endpoint"""
+
+     task_instances: List[AirflowTaskInstance]
+     total_entries: int
+
+
+ class AirflowTaskInstanceLogsRequestParameters(BaseModel):
+     """Model for parameters when requesting info from task_instance_logs
+     endpoint"""
+
+     # excluded fields are used to build the url
+     dag_id: str = Field(..., min_length=1, exclude=True)
+     dag_run_id: str = Field(..., min_length=1, exclude=True)
+     task_id: str = Field(..., min_length=1, exclude=True)
+     try_number: int = Field(..., ge=0, exclude=True)
+     map_index: int = Field(..., ge=-1)
+     full_content: bool = True
+
+     @classmethod
+     def from_query_params(cls, query_params: QueryParams):
+         """Maps the query parameters to the model"""
+         params = dict(query_params)
+         return cls.model_validate(params)
+
+
+ class JobStatus(BaseModel):
+     """Model for what we want to render to the user."""
+
+     dag_id: Optional[str] = Field(None)
+     end_time: Optional[datetime] = Field(None)
+     job_id: Optional[str] = Field(None)
+     job_state: Optional[str] = Field(None)
+     name: Optional[str] = Field(None)
+     job_type: Optional[str] = Field(None)
+     comment: Optional[str] = Field(None)
+     start_time: Optional[datetime] = Field(None)
+     submit_time: Optional[datetime] = Field(None)
+
+     @classmethod
+     def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun):
+         """Maps the fields from the HpcJobStatusResponse to this model"""
+         name = airflow_dag_run.conf.get("s3_prefix", "")
+         job_type = airflow_dag_run.conf.get("job_type", "")
+         # v1 job_type is in CO configs
+         if job_type == "":
+             job_type = airflow_dag_run.conf.get("codeocean_configs", {}).get(
+                 "job_type", ""
+             )
+         return cls(
+             dag_id=airflow_dag_run.dag_id,
+             end_time=airflow_dag_run.end_date,
+             job_id=airflow_dag_run.dag_run_id,
+             job_state=airflow_dag_run.state,
+             name=name,
+             job_type=job_type,
+             comment=airflow_dag_run.note,
+             start_time=airflow_dag_run.start_date,
+             submit_time=airflow_dag_run.execution_date,
+         )
+
+     @property
+     def jinja_dict(self):
+         """Map model to a dictionary that jinja can render"""
+         return self.model_dump(exclude_none=True)
+
+
+ class JobTasks(BaseModel):
+     """Model for what is rendered to the user for each task."""
+
+     dag_id: Optional[str] = Field(None)
+     job_id: Optional[str] = Field(None)
+     task_id: Optional[str] = Field(None)
+     try_number: Optional[int] = Field(None)
+     task_state: Optional[str] = Field(None)
+     priority_weight: Optional[int] = Field(None)
+     map_index: Optional[int] = Field(None)
+     submit_time: Optional[datetime] = Field(None)
+     start_time: Optional[datetime] = Field(None)
+     end_time: Optional[datetime] = Field(None)
+     duration: Optional[Union[int, float]] = Field(None)
+     comment: Optional[str] = Field(None)
+
+     @classmethod
+     def from_airflow_task_instance(
+         cls, airflow_task_instance: AirflowTaskInstance
+     ):
+         """Maps the fields from the HpcJobStatusResponse to this model"""
+         return cls(
+             dag_id=airflow_task_instance.dag_id,
+             job_id=airflow_task_instance.dag_run_id,
+             task_id=airflow_task_instance.task_id,
+             try_number=airflow_task_instance.try_number,
+             task_state=airflow_task_instance.state,
+             priority_weight=airflow_task_instance.priority_weight,
+             map_index=airflow_task_instance.map_index,
+             submit_time=airflow_task_instance.execution_date,
+             start_time=airflow_task_instance.start_date,
+             end_time=airflow_task_instance.end_date,
+             duration=airflow_task_instance.duration,
+             comment=airflow_task_instance.note,
+         )
+
+
+ class JobParamInfo(BaseModel):
+     """Model for job parameter info from AWS Parameter Store"""
+
+     name: Optional[str]
+     last_modified: Optional[datetime]
+     job_type: str
+     task_id: str
+     modality: Optional[str]
+
+     @classmethod
+     def from_aws_describe_parameter(
+         cls,
+         parameter: ParameterMetadataTypeDef,
+         job_type: str,
+         task_id: str,
+         modality: Optional[str],
+     ):
+         """Map the parameter to the model"""
+         return cls(
+             name=parameter.get("Name"),
+             last_modified=parameter.get("LastModifiedDate"),
+             job_type=job_type,
+             task_id=task_id,
+             modality=modality,
+         )
+
+     @staticmethod
+     def get_parameter_prefix(version: Optional[str] = None) -> str:
+         """Get the prefix for job_type parameters"""
+         prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
+         if version is None:
+             return prefix
+         return f"{prefix}/{version}"
+
+     @staticmethod
+     def get_parameter_regex(version: Optional[str] = None) -> str:
+         """Create the regex pattern to match the parameter name"""
+         prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
+         regex = (
+             "(?P<job_type>[^/]+)/tasks/(?P<task_id>[^/]+)"
+             "(?:/(?P<modality>[^/]+))?"
+         )
+         if version is None:
+             return f"{prefix}/{regex}"
+         return f"{prefix}/{version}/{regex}"
+
+     @staticmethod
+     def get_parameter_name(
+         job_type: str, task_id: str, version: Optional[str] = None
+     ) -> str:
+         """Create the parameter name from job_type and task_id"""
+         prefix = os.getenv("AIND_AIRFLOW_PARAM_PREFIX")
+         if version is None:
+             return f"{prefix}/{job_type}/tasks/{task_id}"
+         return f"{prefix}/{version}/{job_type}/tasks/{task_id}"
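
For reference, a short sketch of how the request-parameter and Parameter Store helpers above might be exercised. The environment value, job_type, and task_id below are made up for illustration; only the classes defined in this file and starlette's QueryParams are assumed.

    import os

    from starlette.datastructures import QueryParams

    # Hypothetical SSM prefix; the real value is deployment-specific.
    os.environ["AIND_AIRFLOW_PARAM_PREFIX"] = "/aind/data_transfer_service"

    # -> "/aind/data_transfer_service/v2/my_job_type/tasks/my_task"
    print(JobParamInfo.get_parameter_name("my_job_type", "my_task", version="v2"))

    # Query strings from the UI are mapped onto the request model; the
    # "states" value is a Python-literal list parsed with ast.literal_eval.
    params = AirflowDagRunsRequestParameters.from_query_params(
        QueryParams("states=['queued','running']&page_limit=25")
    )
    print(params.page_limit, params.states)  # 25 ['queued', 'running']

Names built by get_parameter_name are matched by the pattern from get_parameter_regex, whose named groups (job_type, task_id, and an optional modality) feed JobParamInfo.from_aws_describe_parameter.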