aind-data-transfer-service 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of aind-data-transfer-service has been flagged as a potentially problematic release.

Files changed (23)
  1. aind_data_transfer_service/__init__.py +2 -1
  2. aind_data_transfer_service/configs/csv_handler.py +10 -5
  3. aind_data_transfer_service/configs/job_upload_template.py +2 -1
  4. aind_data_transfer_service/configs/platforms_v1.py +177 -0
  5. aind_data_transfer_service/log_handler.py +3 -3
  6. aind_data_transfer_service/models/core.py +25 -4
  7. aind_data_transfer_service/server.py +225 -487
  8. {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/METADATA +4 -6
  9. aind_data_transfer_service-1.18.0.dist-info/RECORD +15 -0
  10. aind_data_transfer_service/configs/job_configs.py +0 -545
  11. aind_data_transfer_service/hpc/__init__.py +0 -1
  12. aind_data_transfer_service/hpc/client.py +0 -151
  13. aind_data_transfer_service/hpc/models.py +0 -492
  14. aind_data_transfer_service/templates/admin.html +0 -45
  15. aind_data_transfer_service/templates/index.html +0 -258
  16. aind_data_transfer_service/templates/job_params.html +0 -405
  17. aind_data_transfer_service/templates/job_status.html +0 -324
  18. aind_data_transfer_service/templates/job_tasks_table.html +0 -146
  19. aind_data_transfer_service/templates/task_logs.html +0 -31
  20. aind_data_transfer_service-1.17.0.dist-info/RECORD +0 -24
  21. {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/WHEEL +0 -0
  22. {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/licenses/LICENSE +0 -0
  23. {aind_data_transfer_service-1.17.0.dist-info → aind_data_transfer_service-1.18.0.dist-info}/top_level.txt +0 -0
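
The server.py changes below replace blocking requests calls with httpx.AsyncClient and forward the legacy v1 endpoints to a separate service addressed by the AIND_DATA_TRANSFER_SERVICE_V1_URL environment variable. The snippet below is a minimal, self-contained sketch of that forwarding pattern, not code from the package: the app object, the forward_submit_basic_jobs name, and the localhost default URL are illustrative assumptions.

# Sketch of the async request-forwarding pattern introduced in this release.
# Names (app, forward_submit_basic_jobs, the default URL) are illustrative.
import os

import httpx
from fastapi import FastAPI, Request, Response

app = FastAPI()
legacy_url = os.getenv("AIND_DATA_TRANSFER_SERVICE_V1_URL", "http://localhost:5000")

# Hop-by-hop headers that should not be forwarded to the upstream service.
HOP_BY_HOP = {
    "host", "connection", "keep-alive", "proxy-authenticate",
    "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade",
}


@app.post("/api/submit_basic_jobs")
async def forward_submit_basic_jobs(request: Request) -> Response:
    """Forward the body and safe headers to the legacy v1 service."""
    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in HOP_BY_HOP
    }
    async with httpx.AsyncClient(base_url=legacy_url) as client:
        try:
            upstream = await client.request(
                method=request.method,
                url="/api/submit_basic_jobs",
                headers=headers,
                content=await request.body(),
                timeout=120,
            )
        except httpx.RequestError as exc:
            return Response(f"Proxy request failed: {exc}", status_code=500)
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        media_type=upstream.headers.get("content-type"),
    )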
aind_data_transfer_service/server.py
@@ -5,54 +5,35 @@ import io
  import json
  import os
  import re
- from asyncio import gather, sleep
- from pathlib import PurePosixPath
+ from asyncio import gather
  from typing import Any, List, Optional, Union

  import boto3
- import requests
- from aind_data_transfer_models import (
-     __version__ as aind_data_transfer_models_version,
- )
- from aind_data_transfer_models.core import SubmitJobRequest, validation_context
  from authlib.integrations.starlette_client import OAuth
  from botocore.exceptions import ClientError
  from fastapi import Request
- from fastapi.responses import JSONResponse, StreamingResponse
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
  from fastapi.templating import Jinja2Templates
- from httpx import AsyncClient
+ from httpx import AsyncClient, RequestError
  from openpyxl import load_workbook
- from pydantic import SecretStr, ValidationError
+ from pydantic import ValidationError
  from starlette.applications import Starlette
  from starlette.config import Config
  from starlette.middleware.sessions import SessionMiddleware
  from starlette.responses import RedirectResponse
  from starlette.routing import Route

- from aind_data_transfer_service import (
-     OPEN_DATA_BUCKET_NAME,
- )
  from aind_data_transfer_service import (
      __version__ as aind_data_transfer_service_version,
  )
  from aind_data_transfer_service.configs.csv_handler import map_csv_row_to_job
- from aind_data_transfer_service.configs.job_configs import (
-     BasicUploadJobConfigs as LegacyBasicUploadJobConfigs,
- )
- from aind_data_transfer_service.configs.job_configs import (
-     HpcJobConfigs,
- )
  from aind_data_transfer_service.configs.job_upload_template import (
      JobUploadTemplate,
  )
- from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
- from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
  from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
  from aind_data_transfer_service.models.core import (
      SubmitJobRequestV2,
- )
- from aind_data_transfer_service.models.core import (
-     validation_context as validation_context_v2,
+     validation_context,
  )
  from aind_data_transfer_service.models.internal import (
      AirflowDagRunsRequestParameters,
@@ -91,17 +72,184 @@ templates = Jinja2Templates(directory=template_directory)
  # LOKI_URI
  # ENV_NAME
  # LOG_LEVEL
+ # AIND_DATA_TRANSFER_SERVICE_V1_URL

  logger = get_logger(log_configs=LoggingConfigs())
  project_names_url = os.getenv("AIND_METADATA_SERVICE_PROJECT_NAMES_URL")
+ aind_dts_v1_url = os.getenv(
+     "AIND_DATA_TRANSFER_SERVICE_V1_URL",
+     "http://aind-data-transfer-service-v1:5000",
+ )
+
+
+ async def proxy(
+     request: Request,
+     path: str,
+     async_client: AsyncClient,
+ ) -> Response:
+     """
+     Proxy request to v1 aind-metadata-service-server
+     Parameters
+     ----------
+     request : Request
+     path : str
+     async_client : AsyncClient
+
+     Returns
+     -------
+     Response
+
+     """
+
+     # Prepare headers to forward (excluding hop-by-hop headers)
+     headers = {
+         key: value
+         for key, value in request.headers.items()
+         if key.lower()
+         not in [
+             "host",
+             "connection",
+             "keep-alive",
+             "proxy-authenticate",
+             "proxy-authorization",
+             "te",
+             "trailers",
+             "transfer-encoding",
+             "upgrade",
+         ]
+     }
+
+     try:
+         body = await request.body()
+         backend_response = await async_client.request(
+             method=request.method,
+             url=path,
+             headers=headers,
+             content=body,
+             timeout=120,  # Adjust timeout as needed
+         )
+         # Create a FastAPI Response from the backend's response
+         response_headers = {
+             key: value
+             for key, value in backend_response.headers.items()
+             if key.lower() not in ["content-encoding", "content-length"]
+         }
+         return Response(
+             content=backend_response.content,
+             status_code=backend_response.status_code,
+             headers=response_headers,
+             media_type=backend_response.headers.get("content-type"),
+         )
+     except RequestError as exc:
+         return Response(f"Proxy request failed: {exc}", status_code=500)
+
+
+ async def submit_basic_jobs(
+     request: Request,
+ ):
+     """submit_basic_jobs_legacy"""
+     async with AsyncClient(base_url=aind_dts_v1_url) as session:
+         return await proxy(request, "/api/submit_basic_jobs", session)
+
+
+ async def submit_jobs(
+     request: Request,
+ ):
+     """submit_basic_jobs_legacy"""
+     async with AsyncClient(base_url=aind_dts_v1_url) as session:
+         return await proxy(request, "/api/v1/submit_jobs", session)
+
+
+ async def submit_hpc_jobs(
+     request: Request,
+ ):
+     """submit_hpc_jobs_legacy"""
+     async with AsyncClient(base_url=aind_dts_v1_url) as session:
+         return await proxy(request, "/api/submit_hpc_jobs", session)
+
+
+ async def validate_json(request: Request):
+     """validate_json_legacy"""
+     async with AsyncClient(base_url=aind_dts_v1_url) as session:
+         return await proxy(request, "/api/v1/validate_json", session)
+
+
+ async def validate_csv(request: Request):
+     """Validate a csv or xlsx file. Return parsed contents as json."""
+     logger.info("Received request to validate csv")
+     async with request.form() as form:
+         basic_jobs = []
+         errors = []
+         if not form["file"].filename.endswith((".csv", ".xlsx")):
+             errors.append("Invalid input file type")
+         else:
+             content = await form["file"].read()
+             if form["file"].filename.endswith(".csv"):
+                 # A few csv files created from excel have extra unicode
+                 # byte chars. Adding "utf-8-sig" should remove them.
+                 data = content.decode("utf-8-sig")
+             else:
+                 xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
+                 xlsx_sheet = xlsx_book.active
+                 csv_io = io.StringIO()
+                 csv_writer = csv.writer(csv_io)
+                 for r in xlsx_sheet.iter_rows(values_only=True):
+                     if any(r):
+                         csv_writer.writerow(r)
+                 xlsx_book.close()
+                 data = csv_io.getvalue()
+             csv_reader = csv.DictReader(io.StringIO(data))
+             params = AirflowDagRunsRequestParameters(
+                 dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
+                 states=["running", "queued"],
+             )
+             _, current_jobs = await get_airflow_jobs(
+                 params=params, get_confs=True
+             )
+             context = {
+                 "job_types": get_job_types("v2"),
+                 "project_names": await get_project_names(),
+                 "current_jobs": current_jobs,
+             }
+             for row in csv_reader:
+                 if not any(row.values()):
+                     continue
+                 try:
+                     with validation_context(context):
+                         job = map_csv_row_to_job(row=row)
+                     # Construct hpc job setting most of the vars from the env
+                     basic_jobs.append(
+                         json.loads(
+                             job.model_dump_json(
+                                 round_trip=True,
+                                 exclude_none=True,
+                                 warnings=False,
+                             )
+                         )
+                     )
+                 except ValidationError as e:
+                     errors.append(e.json())
+                 except Exception as e:
+                     errors.append(f"{str(e.args)}")
+         message = "There were errors" if len(errors) > 0 else "Valid Data"
+         status_code = 406 if len(errors) > 0 else 200
+         content = {
+             "message": message,
+             "data": {"jobs": basic_jobs, "errors": errors},
+         }
+         return JSONResponse(
+             content=content,
+             status_code=status_code,
+         )


- def get_project_names() -> List[str]:
+ async def get_project_names() -> List[str]:
      """Get a list of project_names"""
      # TODO: Cache response for 5 minutes
-     response = requests.get(project_names_url)
-     response.raise_for_status()
-     project_names = response.json()["data"]
+     async with AsyncClient() as async_client:
+         response = await async_client.get(project_names_url)
+         response.raise_for_status()
+         project_names = response.json()["data"]
      return project_names

@@ -247,121 +395,6 @@ async def get_airflow_jobs(
      return (total_entries, jobs_list)


- async def validate_csv(request: Request):
-     """Validate a csv or xlsx file. Return parsed contents as json."""
-     logger.info("Received request to validate csv")
-     async with request.form() as form:
-         basic_jobs = []
-         errors = []
-         if not form["file"].filename.endswith((".csv", ".xlsx")):
-             errors.append("Invalid input file type")
-         else:
-             content = await form["file"].read()
-             if form["file"].filename.endswith(".csv"):
-                 # A few csv files created from excel have extra unicode
-                 # byte chars. Adding "utf-8-sig" should remove them.
-                 data = content.decode("utf-8-sig")
-             else:
-                 xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                 xlsx_sheet = xlsx_book.active
-                 csv_io = io.StringIO()
-                 csv_writer = csv.writer(csv_io)
-                 for r in xlsx_sheet.iter_rows(values_only=True):
-                     if any(r):
-                         csv_writer.writerow(r)
-                 xlsx_book.close()
-                 data = csv_io.getvalue()
-             csv_reader = csv.DictReader(io.StringIO(data))
-             params = AirflowDagRunsRequestParameters(
-                 dag_ids=["transform_and_upload_v2", "run_list_of_jobs"],
-                 states=["running", "queued"],
-             )
-             _, current_jobs = await get_airflow_jobs(
-                 params=params, get_confs=True
-             )
-             context = {
-                 "job_types": get_job_types("v2"),
-                 "project_names": get_project_names(),
-                 "current_jobs": current_jobs,
-             }
-             for row in csv_reader:
-                 if not any(row.values()):
-                     continue
-                 try:
-                     with validation_context_v2(context):
-                         job = map_csv_row_to_job(row=row)
-                     # Construct hpc job setting most of the vars from the env
-                     basic_jobs.append(
-                         json.loads(
-                             job.model_dump_json(
-                                 round_trip=True,
-                                 exclude_none=True,
-                                 warnings=False,
-                             )
-                         )
-                     )
-                 except ValidationError as e:
-                     errors.append(e.json())
-                 except Exception as e:
-                     errors.append(f"{str(e.args)}")
-         message = "There were errors" if len(errors) > 0 else "Valid Data"
-         status_code = 406 if len(errors) > 0 else 200
-         content = {
-             "message": message,
-             "data": {"jobs": basic_jobs, "errors": errors},
-         }
-         return JSONResponse(
-             content=content,
-             status_code=status_code,
-         )
-
-
- # TODO: Deprecate this endpoint
- async def validate_csv_legacy(request: Request):
-     """Validate a csv or xlsx file. Return parsed contents as json."""
-     async with request.form() as form:
-         basic_jobs = []
-         errors = []
-         if not form["file"].filename.endswith((".csv", ".xlsx")):
-             errors.append("Invalid input file type")
-         else:
-             content = await form["file"].read()
-             if form["file"].filename.endswith(".csv"):
-                 # A few csv files created from excel have extra unicode
-                 # byte chars. Adding "utf-8-sig" should remove them.
-                 data = content.decode("utf-8-sig")
-             else:
-                 xlsx_book = load_workbook(io.BytesIO(content), read_only=True)
-                 xlsx_sheet = xlsx_book.active
-                 csv_io = io.StringIO()
-                 csv_writer = csv.writer(csv_io)
-                 for r in xlsx_sheet.iter_rows(values_only=True):
-                     if any(r):
-                         csv_writer.writerow(r)
-                 xlsx_book.close()
-                 data = csv_io.getvalue()
-             csv_reader = csv.DictReader(io.StringIO(data))
-             for row in csv_reader:
-                 if not any(row.values()):
-                     continue
-                 try:
-                     job = LegacyBasicUploadJobConfigs.from_csv_row(row=row)
-                     # Construct hpc job setting most of the vars from the env
-                     basic_jobs.append(job.model_dump_json())
-                 except Exception as e:
-                     errors.append(f"{e.__class__.__name__}{e.args}")
-         message = "There were errors" if len(errors) > 0 else "Valid Data"
-         status_code = 406 if len(errors) > 0 else 200
-         content = {
-             "message": message,
-             "data": {"jobs": basic_jobs, "errors": errors},
-         }
-         return JSONResponse(
-             content=content,
-             status_code=status_code,
-         )
-
-
  async def validate_json_v2(request: Request):
      """Validate raw json against data transfer models. Returns validated
      json or errors if request is invalid."""
@@ -375,10 +408,10 @@ async def validate_json_v2(request: Request):
          _, current_jobs = await get_airflow_jobs(params=params, get_confs=True)
          context = {
              "job_types": get_job_types("v2"),
-             "project_names": get_project_names(),
+             "project_names": await get_project_names(),
              "current_jobs": current_jobs,
          }
-         with validation_context_v2(context):
+         with validation_context(context):
              validated_model = SubmitJobRequestV2.model_validate_json(
                  json.dumps(content)
              )
@@ -425,60 +458,6 @@ async def validate_json_v2(request: Request):
          )


- async def validate_json(request: Request):
-     """Validate raw json against aind-data-transfer-models. Returns validated
-     json or errors if request is invalid."""
-     logger.info("Received request to validate json")
-     content = await request.json()
-     try:
-         project_names = get_project_names()
-         with validation_context({"project_names": project_names}):
-             validated_model = SubmitJobRequest.model_validate_json(
-                 json.dumps(content)
-             )
-         validated_content = json.loads(
-             validated_model.model_dump_json(warnings=False, exclude_none=True)
-         )
-         logger.info("Valid model detected")
-         return JSONResponse(
-             status_code=200,
-             content={
-                 "message": "Valid model",
-                 "data": {
-                     "version": aind_data_transfer_models_version,
-                     "model_json": content,
-                     "validated_model_json": validated_content,
-                 },
-             },
-         )
-     except ValidationError as e:
-         logger.warning(f"There were validation errors processing {content}")
-         return JSONResponse(
-             status_code=406,
-             content={
-                 "message": "There were validation errors",
-                 "data": {
-                     "version": aind_data_transfer_models_version,
-                     "model_json": content,
-                     "errors": e.json(),
-                 },
-             },
-         )
-     except Exception as e:
-         logger.exception("Internal Server Error.")
-         return JSONResponse(
-             status_code=500,
-             content={
-                 "message": "There was an internal server error",
-                 "data": {
-                     "version": aind_data_transfer_models_version,
-                     "model_json": content,
-                     "errors": str(e.args),
-                 },
-             },
-         )
-
-
  async def submit_jobs_v2(request: Request):
      """Post SubmitJobRequestV2 raw json to hpc server to process."""
      logger.info("Received request to submit jobs v2")
@@ -491,15 +470,14 @@ async def submit_jobs_v2(request: Request):
          _, current_jobs = await get_airflow_jobs(params=params, get_confs=True)
          context = {
              "job_types": get_job_types("v2"),
-             "project_names": get_project_names(),
+             "project_names": await get_project_names(),
              "current_jobs": current_jobs,
          }
-         with validation_context_v2(context):
+         with validation_context(context):
              model = SubmitJobRequestV2.model_validate_json(json.dumps(content))
          full_content = json.loads(
              model.model_dump_json(warnings=False, exclude_none=True)
          )
-         # TODO: Replace with httpx async client
          logger.info(
              f"Valid request detected. Sending list of jobs. "
              f"dag_id: {model.dag_id}"
@@ -511,80 +489,25 @@ async def submit_jobs_v2(request: Request):
                  f"{job_index} of {total_jobs}."
              )

-         response = requests.post(
-             url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
+         async with AsyncClient(
              auth=(
                  os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                  os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-             ),
-             json={"conf": full_content},
-         )
-         return JSONResponse(
-             status_code=response.status_code,
-             content={
-                 "message": "Submitted request to airflow",
-                 "data": {"responses": [response.json()], "errors": []},
-             },
-         )
-     except ValidationError as e:
-         logger.warning(f"There were validation errors processing {content}")
-         return JSONResponse(
-             status_code=406,
-             content={
-                 "message": "There were validation errors",
-                 "data": {"responses": [], "errors": e.json()},
-             },
-         )
-     except Exception as e:
-         logger.exception("Internal Server Error.")
-         return JSONResponse(
-             status_code=500,
-             content={
-                 "message": "There was an internal server error",
-                 "data": {"responses": [], "errors": str(e.args)},
-             },
-         )
-
-
- async def submit_jobs(request: Request):
-     """Post BasicJobConfigs raw json to hpc server to process."""
-     logger.info("Received request to submit jobs")
-     content = await request.json()
-     try:
-         project_names = get_project_names()
-         with validation_context({"project_names": project_names}):
-             model = SubmitJobRequest.model_validate_json(json.dumps(content))
-         full_content = json.loads(
-             model.model_dump_json(warnings=False, exclude_none=True)
-         )
-         # TODO: Replace with httpx async client
-         logger.info(
-             f"Valid request detected. Sending list of jobs. "
-             f"Job Type: {model.job_type}"
-         )
-         total_jobs = len(model.upload_jobs)
-         for job_index, job in enumerate(model.upload_jobs, 1):
-             logger.info(
-                 f"{job.s3_prefix} sending to airflow. "
-                 f"{job_index} of {total_jobs}."
              )
-
-             response = requests.post(
-                 url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
-                 auth=(
-                     os.getenv("AIND_AIRFLOW_SERVICE_USER"),
-                     os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-                 ),
-                 json={"conf": full_content},
-             )
+         ) as async_client:
+             response = await async_client.post(
+                 url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
+                 json={"conf": full_content},
+             )
+             status_code = response.status_code
+             response_json = response.json()
          return JSONResponse(
-             status_code=response.status_code,
+             status_code=status_code,
              content={
                  "message": "Submitted request to airflow",
-                 "data": {"responses": [response.json()], "errors": []},
+                 "data": {"responses": [response_json], "errors": []},
              },
          )
-
      except ValidationError as e:
          logger.warning(f"There were validation errors processing {content}")
          return JSONResponse(
@@ -605,196 +528,6 @@ async def submit_jobs(request: Request):
          )


- # TODO: Deprecate this endpoint
- async def submit_basic_jobs(request: Request):
-     """Post BasicJobConfigs raw json to hpc server to process."""
-     content = await request.json()
-     hpc_client_conf = HpcClientConfigs()
-     hpc_client = HpcClient(configs=hpc_client_conf)
-     basic_jobs = content["jobs"]
-     hpc_jobs = []
-     parsing_errors = []
-     for job in basic_jobs:
-         try:
-             basic_upload_job = LegacyBasicUploadJobConfigs.model_validate_json(
-                 job
-             )
-             # Add aws_param_store_name and temp_dir
-             basic_upload_job.aws_param_store_name = os.getenv(
-                 "HPC_AWS_PARAM_STORE_NAME"
-             )
-             basic_upload_job.temp_directory = os.getenv(
-                 "HPC_STAGING_DIRECTORY"
-             )
-             hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job)
-             hpc_jobs.append(hpc_job)
-         except Exception as e:
-             parsing_errors.append(
-                 f"Error parsing {job}: {e.__class__.__name__}"
-             )
-     if parsing_errors:
-         status_code = 406
-         message = "There were errors parsing the basic job configs"
-         content = {
-             "message": message,
-             "data": {"responses": [], "errors": parsing_errors},
-         }
-     else:
-         responses = []
-         hpc_errors = []
-         for hpc_job in hpc_jobs:
-             try:
-                 job_def = hpc_job.job_definition
-                 response = hpc_client.submit_job(job_def)
-                 response_json = response.json()
-                 responses.append(response_json)
-                 # Add pause to stagger job requests to the hpc
-                 await sleep(0.2)
-             except Exception as e:
-                 logger.error(f"{e.__class__.__name__}{e.args}")
-                 hpc_errors.append(
-                     f"Error processing "
-                     f"{hpc_job.basic_upload_job_configs.s3_prefix}"
-                 )
-         message = (
-             "There were errors submitting jobs to the hpc."
-             if len(hpc_errors) > 0
-             else "Submitted Jobs."
-         )
-         status_code = 500 if len(hpc_errors) > 0 else 200
-         content = {
-             "message": message,
-             "data": {"responses": responses, "errors": hpc_errors},
-         }
-     return JSONResponse(
-         content=content,
-         status_code=status_code,
-     )
-
-
- # TODO: Deprecate this endpoint
- async def submit_hpc_jobs(request: Request):  # noqa: C901
-     """Post HpcJobSubmitSettings to hpc server to process."""
-
-     content = await request.json()
-     # content should have
-     # {
-     #   "jobs": [{"hpc_settings": str, upload_job_settings: str, script: str}]
-     # }
-     hpc_client_conf = HpcClientConfigs()
-     hpc_client = HpcClient(configs=hpc_client_conf)
-     job_configs = content["jobs"]
-     hpc_jobs = []
-     parsing_errors = []
-     for job in job_configs:
-         try:
-             base_script = job.get("script")
-             # If script is empty, assume that the job type is a basic job
-             basic_job_name = None
-             if base_script is None or base_script == "":
-                 base_script = HpcJobSubmitSettings.script_command_str(
-                     sif_loc_str=os.getenv("HPC_SIF_LOCATION")
-                 )
-                 basic_job_name = (
-                     LegacyBasicUploadJobConfigs.model_validate_json(
-                         job["upload_job_settings"]
-                     ).s3_prefix
-                 )
-             upload_job_configs = json.loads(job["upload_job_settings"])
-             # This will set the bucket to the private data one
-             if upload_job_configs.get("s3_bucket") is not None:
-                 upload_job_configs = json.loads(
-                     LegacyBasicUploadJobConfigs.model_validate(
-                         upload_job_configs
-                     ).model_dump_json()
-                 )
-             # The aws creds to use are different for aind-open-data and
-             # everything else
-             if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
-                 aws_secret_access_key = SecretStr(
-                     os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
-                 )
-                 aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
-             else:
-                 aws_secret_access_key = SecretStr(
-                     os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
-                 )
-                 aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
-             hpc_settings = json.loads(job["hpc_settings"])
-             if basic_job_name is not None:
-                 hpc_settings["name"] = basic_job_name
-             hpc_job = HpcJobSubmitSettings.from_upload_job_configs(
-                 logging_directory=PurePosixPath(
-                     os.getenv("HPC_LOGGING_DIRECTORY")
-                 ),
-                 aws_secret_access_key=aws_secret_access_key,
-                 aws_access_key_id=aws_access_key_id,
-                 aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
-                 aws_session_token=(
-                     (
-                         None
-                         if os.getenv("HPC_AWS_SESSION_TOKEN") is None
-                         else SecretStr(os.getenv("HPC_AWS_SESSION_TOKEN"))
-                     )
-                 ),
-                 **hpc_settings,
-             )
-             if not upload_job_configs:
-                 script = base_script
-             else:
-                 script = hpc_job.attach_configs_to_script(
-                     script=base_script,
-                     base_configs=upload_job_configs,
-                     upload_configs_aws_param_store_name=os.getenv(
-                         "HPC_AWS_PARAM_STORE_NAME"
-                     ),
-                     staging_directory=os.getenv("HPC_STAGING_DIRECTORY"),
-                 )
-             hpc_jobs.append((hpc_job, script))
-         except Exception as e:
-             parsing_errors.append(
-                 f"Error parsing {job['upload_job_settings']}: {repr(e)}"
-             )
-     if parsing_errors:
-         status_code = 406
-         message = "There were errors parsing the job configs"
-         content = {
-             "message": message,
-             "data": {"responses": [], "errors": parsing_errors},
-         }
-     else:
-         responses = []
-         hpc_errors = []
-         for hpc_job in hpc_jobs:
-             hpc_job_def = hpc_job[0]
-             try:
-                 script = hpc_job[1]
-                 response = hpc_client.submit_hpc_job(
-                     job=hpc_job_def, script=script
-                 )
-                 response_json = response.json()
-                 responses.append(response_json)
-                 # Add pause to stagger job requests to the hpc
-                 await sleep(0.2)
-             except Exception as e:
-                 logger.error(repr(e))
-                 hpc_errors.append(f"Error processing " f"{hpc_job_def.name}")
-         message = (
-             "There were errors submitting jobs to the hpc."
-             if len(hpc_errors) > 0
-             else "Submitted Jobs."
-         )
-         status_code = 500 if len(hpc_errors) > 0 else 200
-         content = {
-             "message": message,
-             "data": {"responses": responses, "errors": hpc_errors},
-         }
-     return JSONResponse(
-         content=content,
-         status_code=status_code,
-     )
-
-
  async def get_job_status_list(request: Request):
      """Get status of jobs using input query params."""

@@ -842,20 +575,23 @@ async def get_tasks_list(request: Request):
              request.query_params
          )
          params_dict = json.loads(params.model_dump_json())
-         response_tasks = requests.get(
-             url=(
-                 f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}/"
-                 "taskInstances"
-             ),
+         async with AsyncClient(
              auth=(
                  os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                  os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-             ),
-         )
-         status_code = response_tasks.status_code
-         if response_tasks.status_code == 200:
+             )
+         ) as async_client:
+             response_tasks = await async_client.get(
+                 url=(
+                     f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}/"
+                     "taskInstances"
+                 ),
+             )
+         status_code = response_tasks.status_code
+         response_json = response_tasks.json()
+         if status_code == 200:
              task_instances = AirflowTaskInstancesResponse.model_validate_json(
-                 json.dumps(response_tasks.json())
+                 json.dumps(response_json)
              )
              job_tasks_list = sorted(
                  [
@@ -876,7 +612,7 @@ async def get_tasks_list(request: Request):
              message = "Error retrieving job tasks list from airflow"
              data = {
                  "params": params_dict,
-                 "errors": [response_tasks.json()],
+                 "errors": [response_json],
              }
      except ValidationError as e:
          logger.warning(f"There was a validation error process task_list: {e}")
@@ -906,27 +642,29 @@ async def get_task_logs(request: Request):
          )
          params_dict = json.loads(params.model_dump_json())
          params_full = dict(params)
-         response_logs = requests.get(
-             url=(
-                 f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}"
-                 f"/taskInstances/{params.task_id}/logs/{params.try_number}"
-             ),
+         async with AsyncClient(
              auth=(
                  os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                  os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-             ),
-             params=params_dict,
-         )
-         status_code = response_logs.status_code
-         if response_logs.status_code == 200:
-             message = "Retrieved task logs from airflow"
-             data = {"params": params_full, "logs": response_logs.text}
-         else:
-             message = "Error retrieving task logs from airflow"
-             data = {
-                 "params": params_full,
-                 "errors": [response_logs.json()],
-             }
+             )
+         ) as async_client:
+             response_logs = await async_client.get(
+                 url=(
+                     f"{url}/{params.dag_id}/dagRuns/{params.dag_run_id}"
+                     f"/taskInstances/{params.task_id}/logs/{params.try_number}"
+                 ),
+                 params=params_dict,
+             )
+             status_code = response_logs.status_code
+             if status_code == 200:
+                 message = "Retrieved task logs from airflow"
+                 data = {"params": params_full, "logs": response_logs.text}
+             else:
+                 message = "Error retrieving task logs from airflow"
+                 data = {
+                     "params": params_full,
+                     "errors": [response_logs.json()],
+                 }
      except ValidationError as e:
          logger.warning(f"Error validating request parameters: {e}")
          status_code = 406
@@ -1266,7 +1004,7 @@ async def auth(request: Request):

  routes = [
      Route("/", endpoint=index, methods=["GET", "POST"]),
-     Route("/api/validate_csv", endpoint=validate_csv_legacy, methods=["POST"]),
+     Route("/api/validate_csv", endpoint=validate_csv, methods=["POST"]),
      Route(
          "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
      ),