aind-data-transfer-service 1.17.2__py3-none-any.whl → 1.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aind-data-transfer-service might be problematic.
- aind_data_transfer_service/__init__.py +2 -1
- aind_data_transfer_service/configs/csv_handler.py +10 -5
- aind_data_transfer_service/configs/job_upload_template.py +2 -1
- aind_data_transfer_service/configs/platforms_v1.py +177 -0
- aind_data_transfer_service/log_handler.py +3 -3
- aind_data_transfer_service/models/core.py +25 -4
- aind_data_transfer_service/server.py +174 -447
- {aind_data_transfer_service-1.17.2.dist-info → aind_data_transfer_service-1.18.1.dist-info}/METADATA +4 -6
- aind_data_transfer_service-1.18.1.dist-info/RECORD +15 -0
- aind_data_transfer_service/configs/job_configs.py +0 -545
- aind_data_transfer_service/hpc/__init__.py +0 -1
- aind_data_transfer_service/hpc/client.py +0 -154
- aind_data_transfer_service/hpc/models.py +0 -492
- aind_data_transfer_service-1.17.2.dist-info/RECORD +0 -18
- {aind_data_transfer_service-1.17.2.dist-info → aind_data_transfer_service-1.18.1.dist-info}/WHEEL +0 -0
- {aind_data_transfer_service-1.17.2.dist-info → aind_data_transfer_service-1.18.1.dist-info}/licenses/LICENSE +0 -0
- {aind_data_transfer_service-1.17.2.dist-info → aind_data_transfer_service-1.18.1.dist-info}/top_level.txt +0 -0
aind_data_transfer_service/server.py
@@ -5,53 +5,35 @@ import io
 import json
 import os
 import re
-from asyncio import gather
-from pathlib import PurePosixPath
+from asyncio import gather
 from typing import Any, List, Optional, Union

 import boto3
-from aind_data_transfer_models import (
-    __version__ as aind_data_transfer_models_version,
-)
-from aind_data_transfer_models.core import SubmitJobRequest, validation_context
 from authlib.integrations.starlette_client import OAuth
 from botocore.exceptions import ClientError
 from fastapi import Request
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 from fastapi.templating import Jinja2Templates
-from httpx import AsyncClient
+from httpx import AsyncClient, RequestError
 from openpyxl import load_workbook
-from pydantic import
+from pydantic import ValidationError
 from starlette.applications import Starlette
 from starlette.config import Config
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import RedirectResponse
 from starlette.routing import Route

-from aind_data_transfer_service import (
-    OPEN_DATA_BUCKET_NAME,
-)
 from aind_data_transfer_service import (
     __version__ as aind_data_transfer_service_version,
 )
 from aind_data_transfer_service.configs.csv_handler import map_csv_row_to_job
-from aind_data_transfer_service.configs.job_configs import (
-    BasicUploadJobConfigs as LegacyBasicUploadJobConfigs,
-)
-from aind_data_transfer_service.configs.job_configs import (
-    HpcJobConfigs,
-)
 from aind_data_transfer_service.configs.job_upload_template import (
     JobUploadTemplate,
 )
-from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
-from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
 from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
 from aind_data_transfer_service.models.core import (
     SubmitJobRequestV2,
-)
-from aind_data_transfer_service.models.core import (
-    validation_context as validation_context_v2,
+    validation_context,
 )
 from aind_data_transfer_service.models.internal import (
     AirflowDagRunsRequestParameters,
@@ -90,9 +72,175 @@ templates = Jinja2Templates(directory=template_directory)
 # LOKI_URI
 # ENV_NAME
 # LOG_LEVEL
+# AIND_DATA_TRANSFER_SERVICE_V1_URL

 logger = get_logger(log_configs=LoggingConfigs())
 project_names_url = os.getenv("AIND_METADATA_SERVICE_PROJECT_NAMES_URL")
+aind_dts_v1_url = os.getenv(
+    "AIND_DATA_TRANSFER_SERVICE_V1_URL",
+    "http://aind-data-transfer-service-v1:5000",
+)
+
+
+async def proxy(
+    request: Request,
+    path: str,
+    async_client: AsyncClient,
+) -> Response:
+    """
+    Proxy request to v1 aind-metadata-service-server
+    Parameters
+    ----------
+    request : Request
+    path : str
+    async_client : AsyncClient
+
+    Returns
+    -------
+    Response
+
+    """
+
+    # Prepare headers to forward (excluding hop-by-hop headers)
+    headers = {
+        key: value
+        for key, value in request.headers.items()
+        if key.lower()
+        not in [
+            "host",
+            "connection",
+            "keep-alive",
+            "proxy-authenticate",
+            "proxy-authorization",
+            "te",
+            "trailers",
+            "transfer-encoding",
+            "upgrade",
+        ]
+    }
+
+    try:
+        body = await request.body()
+        backend_response = await async_client.request(
+            method=request.method,
+            url=path,
+            headers=headers,
+            content=body,
+            timeout=120,  # Adjust timeout as needed
+        )
+        # Create a FastAPI Response from the backend's response
+        response_headers = {
+            key: value
+            for key, value in backend_response.headers.items()
+            if key.lower() not in ["content-encoding", "content-length"]
+        }
+        return Response(
+            content=backend_response.content,
+            status_code=backend_response.status_code,
+            headers=response_headers,
+            media_type=backend_response.headers.get("content-type"),
+        )
+    except RequestError as exc:
+        return Response(f"Proxy request failed: {exc}", status_code=500)
+
+
+async def submit_basic_jobs(
+    request: Request,
+):
+    """submit_basic_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/submit_basic_jobs", session)
+
+
+async def submit_jobs(
+    request: Request,
+):
+    """submit_basic_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/v1/submit_jobs", session)
+
+
+async def submit_hpc_jobs(
+    request: Request,
+):
+    """submit_hpc_jobs_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/submit_hpc_jobs", session)
+
+
+async def validate_json(request: Request):
+    """validate_json_legacy"""
+    async with AsyncClient(base_url=aind_dts_v1_url) as session:
+        return await proxy(request, "/api/v1/validate_json", session)
+
+
+async def validate_csv(request: Request):
+    """Validate a csv or xlsx file. Return parsed contents as json."""
+    logger.info("Received request to validate csv")
+    async with request.form() as form:
+        basic_jobs = []
+        errors = []
+        if not form["file"].filename.endswith((".csv", ".xlsx")):
+            errors.append("Invalid input file type")
+        else:
+            content = await form["file"].read()
+            if form["file"].filename.endswith(".csv"):
+                # A few csv files created from excel have extra unicode
+                # byte chars. Adding "utf-8-sig" should remove them.
+                data = content.decode("utf-8-sig")
+            else:
+                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
+                xlsx_sheet = xlsx_book.active
+                csv_io = io.StringIO()
+                csv_writer = csv.writer(csv_io)
+                for r in xlsx_sheet.iter_rows(values_only=True):
+                    if any(r):
+                        csv_writer.writerow(r)
+                xlsx_book.close()
+                data = csv_io.getvalue()
+            csv_reader = csv.DictReader(io.StringIO(data))
+            params = AirflowDagRunsRequestParameters(
+                dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
+                states=["running", "queued"],
+            )
+            _, current_jobs = await get_airflow_jobs(
+                params=params, get_confs=True
+            )
+            context = {
+                "job_types": get_job_types("v2"),
+                "project_names": await get_project_names(),
+                "current_jobs": current_jobs,
+            }
+            for row in csv_reader:
+                if not any(row.values()):
+                    continue
+                try:
+                    with validation_context(context):
+                        job = map_csv_row_to_job(row=row)
+                    # Construct hpc job setting most of the vars from the env
+                    basic_jobs.append(
+                        json.loads(
+                            job.model_dump_json(
+                                round_trip=True,
+                                exclude_none=True,
+                                warnings=False,
+                            )
+                        )
+                    )
+                except ValidationError as e:
+                    errors.append(e.json())
+                except Exception as e:
+                    errors.append(f"{str(e.args)}")
+        message = "There were errors" if len(errors) > 0 else "Valid Data"
+        status_code = 406 if len(errors) > 0 else 200
+        content = {
+            "message": message,
+            "data": {"jobs": basic_jobs, "errors": errors},
+        }
+        return JSONResponse(
+            content=content,
+            status_code=status_code,
+        )


 async def get_project_names() -> List[str]:
@@ -247,121 +395,6 @@ async def get_airflow_jobs(
     return (total_entries, jobs_list)


-async def validate_csv(request: Request):
-    """Validate a csv or xlsx file. Return parsed contents as json."""
-    logger.info("Received request to validate csv")
-    async with request.form() as form:
-        basic_jobs = []
-        errors = []
-        if not form["file"].filename.endswith((".csv", ".xlsx")):
-            errors.append("Invalid input file type")
-        else:
-            content = await form["file"].read()
-            if form["file"].filename.endswith(".csv"):
-                # A few csv files created from excel have extra unicode
-                # byte chars. Adding "utf-8-sig" should remove them.
-                data = content.decode("utf-8-sig")
-            else:
-                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                xlsx_sheet = xlsx_book.active
-                csv_io = io.StringIO()
-                csv_writer = csv.writer(csv_io)
-                for r in xlsx_sheet.iter_rows(values_only=True):
-                    if any(r):
-                        csv_writer.writerow(r)
-                xlsx_book.close()
-                data = csv_io.getvalue()
-            csv_reader = csv.DictReader(io.StringIO(data))
-            params = AirflowDagRunsRequestParameters(
-                dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
-                states=["running", "queued"],
-            )
-            _, current_jobs = await get_airflow_jobs(
-                params=params, get_confs=True
-            )
-            context = {
-                "job_types": get_job_types("v2"),
-                "project_names": await get_project_names(),
-                "current_jobs": current_jobs,
-            }
-            for row in csv_reader:
-                if not any(row.values()):
-                    continue
-                try:
-                    with validation_context_v2(context):
-                        job = map_csv_row_to_job(row=row)
-                    # Construct hpc job setting most of the vars from the env
-                    basic_jobs.append(
-                        json.loads(
-                            job.model_dump_json(
-                                round_trip=True,
-                                exclude_none=True,
-                                warnings=False,
-                            )
-                        )
-                    )
-                except ValidationError as e:
-                    errors.append(e.json())
-                except Exception as e:
-                    errors.append(f"{str(e.args)}")
-        message = "There were errors" if len(errors) > 0 else "Valid Data"
-        status_code = 406 if len(errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"jobs": basic_jobs, "errors": errors},
-        }
-        return JSONResponse(
-            content=content,
-            status_code=status_code,
-        )
-
-
-# TODO: Deprecate this endpoint
-async def validate_csv_legacy(request: Request):
-    """Validate a csv or xlsx file. Return parsed contents as json."""
-    async with request.form() as form:
-        basic_jobs = []
-        errors = []
-        if not form["file"].filename.endswith((".csv", ".xlsx")):
-            errors.append("Invalid input file type")
-        else:
-            content = await form["file"].read()
-            if form["file"].filename.endswith(".csv"):
-                # A few csv files created from excel have extra unicode
-                # byte chars. Adding "utf-8-sig" should remove them.
-                data = content.decode("utf-8-sig")
-            else:
-                xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                xlsx_sheet = xlsx_book.active
-                csv_io = io.StringIO()
-                csv_writer = csv.writer(csv_io)
-                for r in xlsx_sheet.iter_rows(values_only=True):
-                    if any(r):
-                        csv_writer.writerow(r)
-                xlsx_book.close()
-                data = csv_io.getvalue()
-            csv_reader = csv.DictReader(io.StringIO(data))
-            for row in csv_reader:
-                if not any(row.values()):
-                    continue
-                try:
-                    job = LegacyBasicUploadJobConfigs.from_csv_row(row=row)
-                    # Construct hpc job setting most of the vars from the env
-                    basic_jobs.append(job.model_dump_json())
-                except Exception as e:
-                    errors.append(f"{e.__class__.__name__}{e.args}")
-        message = "There were errors" if len(errors) > 0 else "Valid Data"
-        status_code = 406 if len(errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"jobs": basic_jobs, "errors": errors},
-        }
-        return JSONResponse(
-            content=content,
-            status_code=status_code,
-        )
-
-
 async def validate_json_v2(request: Request):
     """Validate raw json against data transfer models. Returns validated
     json or errors if request is invalid."""
@@ -378,7 +411,7 @@ async def validate_json_v2(request: Request):
             "project_names": await get_project_names(),
             "current_jobs": current_jobs,
         }
-        with
+        with validation_context(context):
             validated_model = SubmitJobRequestV2.model_validate_json(
                 json.dumps(content)
             )
@@ -425,60 +458,6 @@ async def validate_json_v2(request: Request):
         )


-async def validate_json(request: Request):
-    """Validate raw json against aind-data-transfer-models. Returns validated
-    json or errors if request is invalid."""
-    logger.info("Received request to validate json")
-    content = await request.json()
-    try:
-        project_names = await get_project_names()
-        with validation_context({"project_names": project_names}):
-            validated_model = SubmitJobRequest.model_validate_json(
-                json.dumps(content)
-            )
-        validated_content = json.loads(
-            validated_model.model_dump_json(warnings=False, exclude_none=True)
-        )
-        logger.info("Valid model detected")
-        return JSONResponse(
-            status_code=200,
-            content={
-                "message": "Valid model",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "validated_model_json": validated_content,
-                },
-            },
-        )
-    except ValidationError as e:
-        logger.warning(f"There were validation errors processing {content}")
-        return JSONResponse(
-            status_code=406,
-            content={
-                "message": "There were validation errors",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "errors": e.json(),
-                },
-            },
-        )
-    except Exception as e:
-        logger.exception("Internal Server Error.")
-        return JSONResponse(
-            status_code=500,
-            content={
-                "message": "There was an internal server error",
-                "data": {
-                    "version": aind_data_transfer_models_version,
-                    "model_json": content,
-                    "errors": str(e.args),
-                },
-            },
-        )
-
-
 async def submit_jobs_v2(request: Request):
     """Post SubmitJobRequestV2 raw json to hpc server to process."""
     logger.info("Received request to submit jobs v2")
@@ -494,7 +473,7 @@ async def submit_jobs_v2(request: Request):
             "project_names": await get_project_names(),
             "current_jobs": current_jobs,
         }
-        with
+        with validation_context(context):
            model = SubmitJobRequestV2.model_validate_json(json.dumps(content))
         full_content = json.loads(
             model.model_dump_json(warnings=False, exclude_none=True)
@@ -549,258 +528,6 @@ async def submit_jobs_v2(request: Request):
         )


-async def submit_jobs(request: Request):
-    """Post BasicJobConfigs raw json to hpc server to process."""
-    logger.info("Received request to submit jobs")
-    content = await request.json()
-    try:
-        project_names = await get_project_names()
-        with validation_context({"project_names": project_names}):
-            model = SubmitJobRequest.model_validate_json(json.dumps(content))
-        full_content = json.loads(
-            model.model_dump_json(warnings=False, exclude_none=True)
-        )
-        logger.info(
-            f"Valid request detected. Sending list of jobs. "
-            f"Job Type: {model.job_type}"
-        )
-        total_jobs = len(model.upload_jobs)
-        for job_index, job in enumerate(model.upload_jobs, 1):
-            logger.info(
-                f"{job.s3_prefix} sending to airflow. "
-                f"{job_index} of {total_jobs}."
-            )
-
-        async with AsyncClient(
-            auth=(
-                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
-                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-            )
-        ) as async_client:
-            response = await async_client.post(
-                url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
-                json={"conf": full_content},
-            )
-            status_code = response.status_code
-            response_json = response.json()
-        return JSONResponse(
-            status_code=status_code,
-            content={
-                "message": "Submitted request to airflow",
-                "data": {"responses": [response_json], "errors": []},
-            },
-        )
-
-    except ValidationError as e:
-        logger.warning(f"There were validation errors processing {content}")
-        return JSONResponse(
-            status_code=406,
-            content={
-                "message": "There were validation errors",
-                "data": {"responses": [], "errors": e.json()},
-            },
-        )
-    except Exception as e:
-        logger.exception("Internal Server Error.")
-        return JSONResponse(
-            status_code=500,
-            content={
-                "message": "There was an internal server error",
-                "data": {"responses": [], "errors": str(e.args)},
-            },
-        )
-
-
-# TODO: Deprecate this endpoint
-async def submit_basic_jobs(request: Request):
-    """Post BasicJobConfigs raw json to hpc server to process."""
-    content = await request.json()
-    hpc_client_conf = HpcClientConfigs()
-    hpc_client = HpcClient(configs=hpc_client_conf)
-    basic_jobs = content["jobs"]
-    hpc_jobs = []
-    parsing_errors = []
-    for job in basic_jobs:
-        try:
-            basic_upload_job = LegacyBasicUploadJobConfigs.model_validate_json(
-                job
-            )
-            # Add aws_param_store_name and temp_dir
-            basic_upload_job.aws_param_store_name = os.getenv(
-                "HPC_AWS_PARAM_STORE_NAME"
-            )
-            basic_upload_job.temp_directory = os.getenv(
-                "HPC_STAGING_DIRECTORY"
-            )
-            hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job)
-            hpc_jobs.append(hpc_job)
-        except Exception as e:
-            parsing_errors.append(
-                f"Error parsing {job}: {e.__class__.__name__}"
-            )
-    if parsing_errors:
-        status_code = 406
-        message = "There were errors parsing the basic job configs"
-        content = {
-            "message": message,
-            "data": {"responses": [], "errors": parsing_errors},
-        }
-    else:
-        responses = []
-        hpc_errors = []
-        for hpc_job in hpc_jobs:
-            try:
-                job_def = hpc_job.job_definition
-                response = await hpc_client.submit_job(job_def)
-                response_json = response.json()
-                responses.append(response_json)
-                # Add pause to stagger job requests to the hpc
-                await sleep(0.2)
-            except Exception as e:
-                logger.error(f"{e.__class__.__name__}{e.args}")
-                hpc_errors.append(
-                    f"Error processing "
-                    f"{hpc_job.basic_upload_job_configs.s3_prefix}"
-                )
-        message = (
-            "There were errors submitting jobs to the hpc."
-            if len(hpc_errors) > 0
-            else "Submitted Jobs."
-        )
-        status_code = 500 if len(hpc_errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"responses": responses, "errors": hpc_errors},
-        }
-    return JSONResponse(
-        content=content,
-        status_code=status_code,
-    )
-
-
-# TODO: Deprecate this endpoint
-async def submit_hpc_jobs(request: Request):  # noqa: C901
-    """Post HpcJobSubmitSettings to hpc server to process."""
-
-    content = await request.json()
-    # content should have
-    # {
-    #   "jobs": [{"hpc_settings": str, upload_job_settings: str, script: str}]
-    # }
-    hpc_client_conf = HpcClientConfigs()
-    hpc_client = HpcClient(configs=hpc_client_conf)
-    job_configs = content["jobs"]
-    hpc_jobs = []
-    parsing_errors = []
-    for job in job_configs:
-        try:
-            base_script = job.get("script")
-            # If script is empty, assume that the job type is a basic job
-            basic_job_name = None
-            if base_script is None or base_script == "":
-                base_script = HpcJobSubmitSettings.script_command_str(
-                    sif_loc_str=os.getenv("HPC_SIF_LOCATION")
-                )
-                basic_job_name = (
-                    LegacyBasicUploadJobConfigs.model_validate_json(
-                        job["upload_job_settings"]
-                    ).s3_prefix
-                )
-            upload_job_configs = json.loads(job["upload_job_settings"])
-            # This will set the bucket to the private data one
-            if upload_job_configs.get("s3_bucket") is not None:
-                upload_job_configs = json.loads(
-                    LegacyBasicUploadJobConfigs.model_validate(
-                        upload_job_configs
-                    ).model_dump_json()
-                )
-            # The aws creds to use are different for aind-open-data and
-            # everything else
-            if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
-                aws_secret_access_key = SecretStr(
-                    os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
-                )
-                aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
-            else:
-                aws_secret_access_key = SecretStr(
-                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
-                )
-                aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
-            hpc_settings = json.loads(job["hpc_settings"])
-            if basic_job_name is not None:
-                hpc_settings["name"] = basic_job_name
-            hpc_job = HpcJobSubmitSettings.from_upload_job_configs(
-                logging_directory=PurePosixPath(
-                    os.getenv("HPC_LOGGING_DIRECTORY")
-                ),
-                aws_secret_access_key=aws_secret_access_key,
-                aws_access_key_id=aws_access_key_id,
-                aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
-                aws_session_token=(
-                    (
-                        None
-                        if os.getenv("HPC_AWS_SESSION_TOKEN") is None
-                        else SecretStr(os.getenv("HPC_AWS_SESSION_TOKEN"))
-                    )
-                ),
-                **hpc_settings,
-            )
-            if not upload_job_configs:
-                script = base_script
-            else:
-                script = hpc_job.attach_configs_to_script(
-                    script=base_script,
-                    base_configs=upload_job_configs,
-                    upload_configs_aws_param_store_name=os.getenv(
-                        "HPC_AWS_PARAM_STORE_NAME"
-                    ),
-                    staging_directory=os.getenv("HPC_STAGING_DIRECTORY"),
-                )
-            hpc_jobs.append((hpc_job, script))
-        except Exception as e:
-            parsing_errors.append(
-                f"Error parsing {job['upload_job_settings']}: {repr(e)}"
-            )
-    if parsing_errors:
-        status_code = 406
-        message = "There were errors parsing the job configs"
-        content = {
-            "message": message,
-            "data": {"responses": [], "errors": parsing_errors},
-        }
-    else:
-        responses = []
-        hpc_errors = []
-        for hpc_job in hpc_jobs:
-            hpc_job_def = hpc_job[0]
-            try:
-                script = hpc_job[1]
-                response = await hpc_client.submit_hpc_job(
-                    job=hpc_job_def, script=script
-                )
-                response_json = response.json()
-                responses.append(response_json)
-                # Add pause to stagger job requests to the hpc
-                await sleep(0.2)
-            except Exception as e:
-                logger.error(repr(e))
-                hpc_errors.append(f"Error processing " f"{hpc_job_def.name}")
-        message = (
-            "There were errors submitting jobs to the hpc."
-            if len(hpc_errors) > 0
-            else "Submitted Jobs."
-        )
-        status_code = 500 if len(hpc_errors) > 0 else 200
-        content = {
-            "message": message,
-            "data": {"responses": responses, "errors": hpc_errors},
-        }
-    return JSONResponse(
-        content=content,
-        status_code=status_code,
-    )
-
-
 async def get_job_status_list(request: Request):
     """Get status of jobs using input query params."""

@@ -1277,7 +1004,7 @@ async def auth(request: Request):

 routes = [
     Route("/", endpoint=index, methods=["GET", "POST"]),
-    Route("/api/validate_csv", endpoint=
+    Route("/api/validate_csv", endpoint=validate_csv, methods=["POST"]),
     Route(
         "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
     ),