aind-data-transfer-service 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aind-data-transfer-service might be problematic.
- aind_data_transfer_service/__init__.py +2 -1
- aind_data_transfer_service/configs/csv_handler.py +10 -5
- aind_data_transfer_service/configs/job_upload_template.py +2 -1
- aind_data_transfer_service/configs/platforms_v1.py +177 -0
- aind_data_transfer_service/log_handler.py +3 -3
- aind_data_transfer_service/models/core.py +25 -4
- aind_data_transfer_service/server.py +225 -487
- {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/METADATA +4 -6
- aind_data_transfer_service-1.18.0.dist-info/RECORD +15 -0
- aind_data_transfer_service/configs/job_configs.py +0 -545
- aind_data_transfer_service/hpc/__init__.py +0 -1
- aind_data_transfer_service/hpc/client.py +0 -151
- aind_data_transfer_service/hpc/models.py +0 -492
- aind_data_transfer_service/templates/admin.html +0 -45
- aind_data_transfer_service/templates/index.html +0 -258
- aind_data_transfer_service/templates/job_params.html +0 -405
- aind_data_transfer_service/templates/job_status.html +0 -324
- aind_data_transfer_service/templates/job_tasks_table.html +0 -146
- aind_data_transfer_service/templates/task_logs.html +0 -31
- aind_data_transfer_service-1.17.0.dist-info/RECORD +0 -24
- {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/WHEEL +0 -0
- {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/licenses/LICENSE +0 -0
- {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/top_level.txt +0 -0
@@ -5,54 +5,35 @@ import io
 import json
 import os
 import re
-from asyncio import gather
-from pathlib import PurePosixPath
+from asyncio import gather
 from typing import Any, List, Optional, Union

 import boto3
-import requests
-from aind_data_transfer_models import (
-    __version__ as aind_data_transfer_models_version,
-)
-from aind_data_transfer_models.core import SubmitJobRequest, validation_context
 from authlib.integrations.starlette_client import OAuth
 from botocore.exceptions import ClientError
 from fastapi import Request
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 from fastapi.templating import Jinja2Templates
-from httpx import AsyncClient
+from httpx import AsyncClient, RequestError
 from openpyxl import load_workbook
-from pydantic import
+from pydantic import ValidationError
 from starlette.applications import Starlette
 from starlette.config import Config
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import RedirectResponse
 from starlette.routing import Route

-from aind_data_transfer_service import (
-    OPEN_DATA_BUCKET_NAME,
-)
 from aind_data_transfer_service import (
     __version__ as aind_data_transfer_service_version,
 )
 from aind_data_transfer_service.configs.csv_handler import map_csv_row_to_job
-from aind_data_transfer_service.configs.job_configs import (
-    BasicUploadJobConfigs as LegacyBasicUploadJobConfigs,
-)
-from aind_data_transfer_service.configs.job_configs import (
-    HpcJobConfigs,
-)
 from aind_data_transfer_service.configs.job_upload_template import (
     JobUploadTemplate,
 )
-from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
-from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
 from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
 from aind_data_transfer_service.models.core import (
     SubmitJobRequestV2,
-)
-from aind_data_transfer_service.models.core import (
-    validation_context as validation_context_v2,
+    validation_context,
 )
 from aind_data_transfer_service.models.internal import (
     AirflowDagRunsRequestParameters,
@@ -91,17 +72,184 @@ templates = Jinja2Templates(directory=template_directory)
 # LOKI_URI
 # ENV_NAME
 # LOG_LEVEL
+# AIND_DATA_TRANSFER_SERVICE_V1_URL

 logger = get_logger(log_configs=LoggingConfigs())
 project_names_url = os.getenv("AIND_METADATA_SERVICE_PROJECT_NAMES_URL")
+aind_dts_v1_url = os.getenv(
+    "AIND_DATA_TRANSFER_SERVICE_V1_URL",
+    "http://aind-data-transfer-service-v1:5000",
+)
+
+
+async def proxy(
+    request: Request,
+    path: str,
+    async_client: AsyncClient,
+) -> Response:
+    """
+    Proxy request to v1 aind-metadata-service-server
+    Parameters
+    ----------
+    request : Request
+    path : str
+    async_client : AsyncClient
+
+    Returns
+    -------
+    Response
+
+    """
+
+    # Prepare headers to forward (excluding hop-by-hop headers)
+    headers = {
+        key: value
+        for key, value in request.headers.items()
+        if key.lower()
+        not in [
+            "host",
+            "connection",
+            "keep-alive",
+            "proxy-authenticate",
+            "proxy-authorization",
+            "te",
+            "trailers",
+            "transfer-encoding",
+            "upgrade",
+        ]
+    }
+
+    try:
+        body = await request.body()
+        backend_response = await async_client.request(
+            method=request.method,
+            url=path,
+            headers=headers,
+            content=body,
+            timeout=120,  # Adjust timeout as needed
+        )
+        # Create a FastAPI Response from the backend's response
+        response_headers = {
+            key: value
+            for key, value in backend_response.headers.items()
+            if key.lower() not in ["content-encoding", "content-length"]
+        }
+        return Response(
+            content=backend_response.content,
+            status_code=backend_response.status_code,
+            headers=response_headers,
+            media_type=backend_response.headers.get("content-type"),
+        )
+    except RequestError as exc:
+        return Response(f"Proxy request failed: {exc}", status_code=500)
+
+
+async def submit_basic_jobs(
+    request: Request,
+):
+    """submit_basic_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/submit_basic_jobs", session)
+
+
+async def submit_jobs(
+    request: Request,
+):
+    """submit_basic_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/v1/submit_jobs", session)
+
+
+async def submit_hpc_jobs(
+    request: Request,
+):
+    """submit_hpc_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/submit_hpc_jobs", session)
+
+
+async def validate_json(request: Request):
+    """validate_json_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/v1/validate_json", session)
+
+
+async def validate_csv(request: Request):
+    """Validate a csv or xlsx file. Return parsed contents as json."""
+    logger.info("Received request to validate csv")
+    async with request.form() as form:
+        basic_jobs = []
+        errors = []
+        if not form["file"].filename.endswith((".csv", ".xlsx")):
+            errors.append("Invalid input file type")
+        else:
+            content = await form["file"].read()
+            if form["file"].filename.endswith(".csv"):
+                # A few csv files created from excel have extra unicode
+                # byte chars. Adding "utf-8-sig" should remove them.
+                data = content.decode("utf-8-sig")
+            else:
+                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
+                xlsx_sheet = xlsx_book.active
+                csv_io = io.StringIO()
+                csv_writer = csv.writer(csv_io)
+                for r in xlsx_sheet.iter_rows(values_only=True):
+                    if any(r):
+                        csv_writer.writerow(r)
+                xlsx_book.close()
+                data = csv_io.getvalue()
+            csv_reader = csv.DictReader(io.StringIO(data))
+            params = AirflowDagRunsRequestParameters(
+                dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
+                states=["running", "queued"],
+            )
+            _, current_jobs = await get_airflow_jobs(
+                params=params, get_confs=True
+            )
+            context = {
+                "job_types": get_job_types("v2"),
+                "project_names": await get_project_names(),
+                "current_jobs": current_jobs,
+            }
+            for row in csv_reader:
+                if not any(row.values()):
+                    continue
+                try:
+                    with validation_context(context):
+                        job = map_csv_row_to_job(row=row)
+                    # Construct hpc job setting most of the vars from the env
+                    basic_jobs.append(
+                        json.loads(
+                            job.model_dump_json(
+                                round_trip=True,
+                                exclude_none=True,
+                                warnings=False,
+                            )
+                        )
+                    )
+                except ValidationError as e:
+                    errors.append(e.json())
+                except Exception as e:
+                    errors.append(f"{str(e.args)}")
+        message = "There were errors" if len(errors) > 0 else "Valid Data"
+        status_code = 406 if len(errors) > 0 else 200
+        content = {
+            "message": message,
+            "data": {"jobs": basic_jobs, "errors": errors},
+        }
+        return JSONResponse(
+            content=content,
+            status_code=status_code,
+        )


-def get_project_names() -> List[str]:
+async def get_project_names() -> List[str]:
     """Get a list of project_names"""
     # TODO: Cache response for 5 minutes
-
-
-
+    async with AsyncClient() as async_client:
+        response = await async_client.get(project_names_url)
+        response.raise_for_status()
+        project_names = response.json()["data"]
     return project_names

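The hunk above adds a generic proxy() helper plus thin wrappers (submit_basic_jobs, submit_jobs, submit_hpc_jobs, validate_json) that forward the legacy endpoints to a separate v1 service at AIND_DATA_TRANSFER_SERVICE_V1_URL. A minimal client-side sketch of what that forwarding means for callers; the base URL and request body below are assumptions for illustration, not values from this release:

    # Hypothetical caller: the legacy route is still exposed by this service,
    # but the handler shown in the diff relays the request to the v1 backend.
    import asyncio
    from httpx import AsyncClient

    async def main() -> None:
        # base_url is an assumption; point it at wherever the service is deployed.
        async with AsyncClient(base_url="http://localhost:5000") as client:
            # Payload shape is illustrative only; the real schema is defined
            # by the v1 service that receives the proxied request.
            resp = await client.post("/api/submit_basic_jobs", json={"jobs": []})
            print(resp.status_code, resp.text)

    asyncio.run(main())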
@@ -247,121 +395,6 @@ async def get_airflow_jobs(
     return (total_entries, jobs_list)


-async def validate_csv(request: Request):
-    """Validate a csv or xlsx file. Return parsed contents as json."""
-    logger.info("Received request to validate csv")
-    async with request.form() as form:
-        basic_jobs = []
-        errors = []
-        if not form["file"].filename.endswith((".csv", ".xlsx")):
-            errors.append("Invalid input file type")
-        else:
-            content = await form["file"].read()
-            if form["file"].filename.endswith(".csv"):
-                # A few csv files created from excel have extra unicode
-                # byte chars. Adding "utf-8-sig" should remove them.
-                data = content.decode("utf-8-sig")
-            else:
-                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                xlsx_sheet = xlsx_book.active
-                csv_io = io.StringIO()
-                csv_writer = csv.writer(csv_io)
-                for r in xlsx_sheet.iter_rows(values_only=True):
-                    if any(r):
-                        csv_writer.writerow(r)
-                xlsx_book.close()
-                data = csv_io.getvalue()
-            csv_reader = csv.DictReader(io.StringIO(data))
-            params = AirflowDagRunsRequestParameters(
-                dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
-                states=["running", "queued"],
-            )
-            _, current_jobs = await get_airflow_jobs(
-                params=params, get_confs=True
-            )
-            context = {
-                "job_types": get_job_types("v2"),
-                "project_names": get_project_names(),
-                "current_jobs": current_jobs,
-            }
-            for row in csv_reader:
-                if not any(row.values()):
-                    continue
-                try:
-                    with validation_context_v2(context):
-                        job = map_csv_row_to_job(row=row)
-                    # Construct hpc job setting most of the vars from the env
-                    basic_jobs.append(
-                        json.loads(
-                            job.model_dump_json(
-                                round_trip=True,
-                                exclude_none=True,
-                                warnings=False,
-                            )
-                        )
-                    )
-                except ValidationError as e:
-                    errors.append(e.json())
-                except Exception as e:
-                    errors.append(f"{str(e.args)}")
-        message = "There were errors" if len(errors) > 0 else "Valid Data"
-        status_code = 406 if len(errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"jobs": basic_jobs, "errors": errors},
-        }
-        return JSONResponse(
-            content=content,
-            status_code=status_code,
-        )
-
-
-# TODO: Deprecate this endpoint
-async def validate_csv_legacy(request: Request):
-    """Validate a csv or xlsx file. Return parsed contents as json."""
-    async with request.form() as form:
-        basic_jobs = []
-        errors = []
-        if not form["file"].filename.endswith((".csv", ".xlsx")):
-            errors.append("Invalid input file type")
-        else:
-            content = await form["file"].read()
-            if form["file"].filename.endswith(".csv"):
-                # A few csv files created from excel have extra unicode
-                # byte chars. Adding "utf-8-sig" should remove them.
-                data = content.decode("utf-8-sig")
-            else:
-                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                xlsx_sheet = xlsx_book.active
-                csv_io = io.StringIO()
-                csv_writer = csv.writer(csv_io)
-                for r in xlsx_sheet.iter_rows(values_only=True):
-                    if any(r):
-                        csv_writer.writerow(r)
-                xlsx_book.close()
-                data = csv_io.getvalue()
-            csv_reader = csv.DictReader(io.StringIO(data))
-            for row in csv_reader:
-                if not any(row.values()):
-                    continue
-                try:
-                    job = LegacyBasicUploadJobConfigs.from_csv_row(row=row)
-                    # Construct hpc job setting most of the vars from the env
-                    basic_jobs.append(job.model_dump_json())
-                except Exception as e:
-                    errors.append(f"{e.__class__.__name__}{e.args}")
-        message = "There were errors" if len(errors) > 0 else "Valid Data"
-        status_code = 406 if len(errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"jobs": basic_jobs, "errors": errors},
-        }
-        return JSONResponse(
-            content=content,
-            status_code=status_code,
-        )
-
-
 async def validate_json_v2(request: Request):
     """Validate raw json against data transfer models. Returns validated
     json or errors if request is invalid."""
@@ -375,10 +408,10 @@ async def validate_json_v2(request: Request):
         _, current_jobs = await get_airflow_jobs(params=params, get_confs=True)
         context = {
             "job_types": get_job_types("v2"),
-            "project_names": get_project_names(),
+            "project_names": await get_project_names(),
             "current_jobs": current_jobs,
         }
-        with
+        with validation_context(context):
             validated_model = SubmitJobRequestV2.model_validate_json(
                 json.dumps(content)
             )
@@ -425,60 +458,6 @@ async def validate_json_v2(request: Request):
         )


-async def validate_json(request: Request):
-    """Validate raw json against aind-data-transfer-models. Returns validated
-    json or errors if request is invalid."""
-    logger.info("Received request to validate json")
-    content = await request.json()
-    try:
-        project_names = get_project_names()
-        with validation_context({"project_names": project_names}):
-            validated_model = SubmitJobRequest.model_validate_json(
-                json.dumps(content)
-            )
-        validated_content = json.loads(
-            validated_model.model_dump_json(warnings=False, exclude_none=True)
-        )
-        logger.info("Valid model detected")
-        return JSONResponse(
-            status_code=200,
-            content={
-                "message": "Valid model",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "validated_model_json": validated_content,
-                },
-            },
-        )
-    except ValidationError as e:
-        logger.warning(f"There were validation errors processing {content}")
-        return JSONResponse(
-            status_code=406,
-            content={
-                "message": "There were validation errors",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "errors": e.json(),
-                },
-            },
-        )
-    except Exception as e:
-        logger.exception("Internal Server Error.")
-        return JSONResponse(
-            status_code=500,
-            content={
-                "message": "There was an internal server error",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "errors": str(e.args),
-                },
-            },
-        )
-
-
 async def submit_jobs_v2(request: Request):
     """Post SubmitJobRequestV2 raw json to hpc server to process."""
     logger.info("Received request to submit jobs v2")
@@ -491,15 +470,14 @@ async def submit_jobs_v2(request: Request):
         _, current_jobs = await get_airflow_jobs(params=params, get_confs=True)
         context = {
             "job_types": get_job_types("v2"),
-            "project_names": get_project_names(),
+            "project_names": await get_project_names(),
             "current_jobs": current_jobs,
         }
-        with
+        with validation_context(context):
             model = SubmitJobRequestV2.model_validate_json(json.dumps(content))
         full_content = json.loads(
             model.model_dump_json(warnings=False, exclude_none=True)
         )
-        # TODO: Replace with httpx async client
         logger.info(
             f"Valid request detected. Sending list of jobs. "
             f"dag_id: {model.dag_id}"
@@ -511,80 +489,25 @@ async def submit_jobs_v2(request: Request):
                 f"{job_index} of {total_jobs}."
             )

-
-            url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
+        async with AsyncClient(
             auth=(
                 os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                 os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-            ),
-            json={"conf": full_content},
-        )
-        return JSONResponse(
-            status_code=response.status_code,
-            content={
-                "message": "Submitted request to airflow",
-                "data": {"responses": [response.json()], "errors": []},
-            },
-        )
-    except ValidationError as e:
-        logger.warning(f"There were validation errors processing {content}")
-        return JSONResponse(
-            status_code=406,
-            content={
-                "message": "There were validation errors",
-                "data": {"responses": [], "errors": e.json()},
-            },
-        )
-    except Exception as e:
-        logger.exception("Internal Server Error.")
-        return JSONResponse(
-            status_code=500,
-            content={
-                "message": "There was an internal server error",
-                "data": {"responses": [], "errors": str(e.args)},
-            },
-        )
-
-
-async def submit_jobs(request: Request):
-    """Post BasicJobConfigs raw json to hpc server to process."""
-    logger.info("Received request to submit jobs")
-    content = await request.json()
-    try:
-        project_names = get_project_names()
-        with validation_context({"project_names": project_names}):
-            model = SubmitJobRequest.model_validate_json(json.dumps(content))
-        full_content = json.loads(
-            model.model_dump_json(warnings=False, exclude_none=True)
-        )
-        # TODO: Replace with httpx async client
-        logger.info(
-            f"Valid request detected. Sending list of jobs. "
-            f"Job Type: {model.job_type}"
-        )
-        total_jobs = len(model.upload_jobs)
-        for job_index, job in enumerate(model.upload_jobs, 1):
-            logger.info(
-                f"{job.s3_prefix} sending to airflow. "
-                f"{job_index} of {total_jobs}."
             )
-
-
-
-
-
-
-            )
-            json={"conf": full_content},
-        )
+        ) as async_client:
+            response = await async_client.post(
+                url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
+                json={"conf": full_content},
+            )
+            status_code = response.status_code
+            response_json = response.json()
         return JSONResponse(
-            status_code=
+            status_code=status_code,
             content={
                 "message": "Submitted request to airflow",
-                "data": {"responses": [
+                "data": {"responses": [response_json], "errors": []},
             },
         )
-
     except ValidationError as e:
         logger.warning(f"There were validation errors processing {content}")
         return JSONResponse(
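In the hunk above, submit_jobs_v2 swaps the old blocking requests.post call for an httpx AsyncClient used as an async context manager with HTTP basic auth, capturing the status code and body before the client closes. A standalone sketch of that pattern, with placeholder URL and credentials rather than the env-driven values used in the service:

    # Sketch of the async submission pattern; URL, credentials, and conf are placeholders.
    import asyncio
    from httpx import AsyncClient

    async def trigger_dag(conf: dict) -> tuple[int, dict]:
        async with AsyncClient(auth=("airflow-user", "airflow-password")) as client:
            response = await client.post(
                url="https://airflow.example.com/api/v1/dags/transform_and_upload_v2/dagRuns",
                json={"conf": conf},
            )
            # Capture status and body inside the context manager, mirroring the diff.
            return response.status_code, response.json()

    status, body = asyncio.run(trigger_dag({"jobs": []}))
    print(status, body)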
@@ -605,196 +528,6 @@ async def submit_jobs(request: Request):
         )


-# TODO: Deprecate this endpoint
-async def submit_basic_jobs(request: Request):
-    """Post BasicJobConfigs raw json to hpc server to process."""
-    content = await request.json()
-    hpc_client_conf = HpcClientConfigs()
-    hpc_client = HpcClient(configs=hpc_client_conf)
-    basic_jobs = content["jobs"]
-    hpc_jobs = []
-    parsing_errors = []
-    for job in basic_jobs:
-        try:
-            basic_upload_job = LegacyBasicUploadJobConfigs.model_validate_json(
-                job
-            )
-            # Add aws_param_store_name and temp_dir
-            basic_upload_job.aws_param_store_name = os.getenv(
-                "HPC_AWS_PARAM_STORE_NAME"
-            )
-            basic_upload_job.temp_directory = os.getenv(
-                "HPC_STAGING_DIRECTORY"
-            )
-            hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job)
-            hpc_jobs.append(hpc_job)
-        except Exception as e:
-            parsing_errors.append(
-                f"Error parsing {job}: {e.__class__.__name__}"
-            )
-    if parsing_errors:
-        status_code = 406
-        message = "There were errors parsing the basic job configs"
-        content = {
-            "message": message,
-            "data": {"responses": [], "errors": parsing_errors},
-        }
-    else:
-        responses = []
-        hpc_errors = []
-        for hpc_job in hpc_jobs:
-            try:
-                job_def = hpc_job.job_definition
-                response = hpc_client.submit_job(job_def)
-                response_json = response.json()
-                responses.append(response_json)
-                # Add pause to stagger job requests to the hpc
-                await sleep(0.2)
-            except Exception as e:
-                logger.error(f"{e.__class__.__name__}{e.args}")
-                hpc_errors.append(
-                    f"Error processing "
-                    f"{hpc_job.basic_upload_job_configs.s3_prefix}"
-                )
-        message = (
-            "There were errors submitting jobs to the hpc."
-            if len(hpc_errors) > 0
-            else "Submitted Jobs."
-        )
-        status_code = 500 if len(hpc_errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"responses": responses, "errors": hpc_errors},
-        }
-    return JSONResponse(
-        content=content,
-        status_code=status_code,
-    )
-
-
-# TODO: Deprecate this endpoint
-async def submit_hpc_jobs(request: Request):  # noqa: C901
-    """Post HpcJobSubmitSettings to hpc server to process."""
-
-    content = await request.json()
-    # content should have
-    # {
-    # "jobs": [{"hpc_settings": str, upload_job_settings: str, script: str}]
-    # }
-    hpc_client_conf = HpcClientConfigs()
-    hpc_client = HpcClient(configs=hpc_client_conf)
-    job_configs = content["jobs"]
-    hpc_jobs = []
-    parsing_errors = []
-    for job in job_configs:
-        try:
-            base_script = job.get("script")
-            # If script is empty, assume that the job type is a basic job
-            basic_job_name = None
-            if base_script is None or base_script == "":
-                base_script = HpcJobSubmitSettings.script_command_str(
-                    sif_loc_str=os.getenv("HPC_SIF_LOCATION")
-                )
-                basic_job_name = (
-                    LegacyBasicUploadJobConfigs.model_validate_json(
-                        job["upload_job_settings"]
-                    ).s3_prefix
-                )
-            upload_job_configs = json.loads(job["upload_job_settings"])
-            # This will set the bucket to the private data one
-            if upload_job_configs.get("s3_bucket") is not None:
-                upload_job_configs = json.loads(
-                    LegacyBasicUploadJobConfigs.model_validate(
-                        upload_job_configs
-                    ).model_dump_json()
-                )
-            # The aws creds to use are different for aind-open-data and
-            # everything else
-            if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
-                aws_secret_access_key = SecretStr(
-                    os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
-                )
-                aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
-            else:
-                aws_secret_access_key = SecretStr(
-                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
-                )
-                aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
-            hpc_settings = json.loads(job["hpc_settings"])
-            if basic_job_name is not None:
-                hpc_settings["name"] = basic_job_name
-            hpc_job = HpcJobSubmitSettings.from_upload_job_configs(
-                logging_directory=PurePosixPath(
-                    os.getenv("HPC_LOGGING_DIRECTORY")
-                ),
-                aws_secret_access_key=aws_secret_access_key,
-                aws_access_key_id=aws_access_key_id,
-                aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
-                aws_session_token=(
-                    (
-                        None
-                        if os.getenv("HPC_AWS_SESSION_TOKEN") is None
-                        else SecretStr(os.getenv("HPC_AWS_SESSION_TOKEN"))
-                    )
-                ),
-                **hpc_settings,
-            )
-            if not upload_job_configs:
-                script = base_script
-            else:
-                script = hpc_job.attach_configs_to_script(
-                    script=base_script,
-                    base_configs=upload_job_configs,
-                    upload_configs_aws_param_store_name=os.getenv(
-                        "HPC_AWS_PARAM_STORE_NAME"
-                    ),
-                    staging_directory=os.getenv("HPC_STAGING_DIRECTORY"),
-                )
-            hpc_jobs.append((hpc_job, script))
-        except Exception as e:
-            parsing_errors.append(
-                f"Error parsing {job['upload_job_settings']}: {repr(e)}"
-            )
-    if parsing_errors:
-        status_code = 406
-        message = "There were errors parsing the job configs"
-        content = {
-            "message": message,
-            "data": {"responses": [], "errors": parsing_errors},
-        }
-    else:
-        responses = []
-        hpc_errors = []
-        for hpc_job in hpc_jobs:
-            hpc_job_def = hpc_job[0]
-            try:
-                script = hpc_job[1]
-                response = hpc_client.submit_hpc_job(
-                    job=hpc_job_def, script=script
-                )
-                response_json = response.json()
-                responses.append(response_json)
-                # Add pause to stagger job requests to the hpc
-                await sleep(0.2)
-            except Exception as e:
-                logger.error(repr(e))
-                hpc_errors.append(f"Error processing " f"{hpc_job_def.name}")
-        message = (
-            "There were errors submitting jobs to the hpc."
-            if len(hpc_errors) > 0
-            else "Submitted Jobs."
-        )
-        status_code = 500 if len(hpc_errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"responses": responses, "errors": hpc_errors},
-        }
-    return JSONResponse(
-        content=content,
-        status_code=status_code,
-    )
-
-
 async def get_job_status_list(request: Request):
     """Get status of jobs using input query params."""

@@ -842,20 +575,23 @@ async def get_tasks_list(request: Request):
             request.query_params
         )
         params_dict = json.loads(params.model_dump_json())
-
-            url=(
-                f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}/"
-                "taskInstances"
-            ),
+        async with AsyncClient(
             auth=(
                 os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                 os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-            )
-        )
-
-
+            )
+        ) as async_client:
+            response_tasks = await async_client.get(
+                url=(
+                    f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}/"
+                    "taskInstances"
+                ),
+            )
+            status_code = response_tasks.status_code
+            response_json = response_tasks.json()
+        if status_code == 200:
             task_instances = AirflowTaskInstancesResponse.model_validate_json(
-                json.dumps(
+                json.dumps(response_json)
             )
             job_tasks_list = sorted(
                 [
@@ -876,7 +612,7 @@ async def get_tasks_list(request: Request):
             message = "Error retrieving job tasks list from airflow"
             data = {
                 "params": params_dict,
-                "errors": [
+                "errors": [response_json],
             }
     except ValidationError as e:
         logger.warning(f"There was a validation error process task_list: {e}")
@@ -906,27 +642,29 @@ async def get_task_logs(request: Request):
         )
         params_dict = json.loads(params.model_dump_json())
         params_full = dict(params)
-
-            url=(
-                f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}"
-                f"/taskInstances/{params.task_id}/logs/{params.try_number}"
-            ),
+        async with AsyncClient(
             auth=(
                 os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                 os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-            )
-
-
-
-
-
-
-
-
-
-"
-
+            )
+        ) as async_client:
+            response_logs = await async_client.get(
+                url=(
+                    f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}"
+                    f"/taskInstances/{params.task_id}/logs/{params.try_number}"
+                ),
+                params=params_dict,
+            )
+            status_code = response_logs.status_code
+        if status_code == 200:
+            message = "Retrieved task logs from airflow"
+            data = {"params": params_full, "logs": response_logs.text}
+        else:
+            message = "Error retrieving task logs from airflow"
+            data = {
+                "params": params_full,
+                "errors": [response_logs.json()],
+            }
     except ValidationError as e:
         logger.warning(f"Error validating request parameters: {e}")
         status_code = 406
@@ -1266,7 +1004,7 @@ async def auth(request: Request):

 routes = [
     Route("/", endpoint=index, methods=["GET", "POST"]),
-    Route("/api/validate_csv", endpoint=
+    Route("/api/validate_csv", endpoint=validate_csv, methods=["POST"]),
     Route(
         "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
    ),
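The route table now binds /api/validate_csv to the consolidated validate_csv handler shown earlier, which reads the upload from the "file" multipart form field. A hedged example of exercising that route; the base URL and file name are assumptions:

    # Hypothetical client for the /api/validate_csv route; adjust base_url and path.
    import asyncio
    from httpx import AsyncClient

    async def validate_csv_file(path: str) -> None:
        async with AsyncClient(base_url="http://localhost:5000") as client:
            with open(path, "rb") as f:
                # The handler expects a multipart form field named "file".
                resp = await client.post("/api/validate_csv", files={"file": f})
        print(resp.status_code, resp.json())

    asyncio.run(validate_csv_file("jobs.csv"))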