nmdc-runtime 2.8.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nmdc-runtime might be problematic. Click here for more details.
- nmdc_runtime/api/__init__.py +0 -0
- nmdc_runtime/api/analytics.py +70 -0
- nmdc_runtime/api/boot/__init__.py +0 -0
- nmdc_runtime/api/boot/capabilities.py +9 -0
- nmdc_runtime/api/boot/object_types.py +126 -0
- nmdc_runtime/api/boot/triggers.py +84 -0
- nmdc_runtime/api/boot/workflows.py +116 -0
- nmdc_runtime/api/core/__init__.py +0 -0
- nmdc_runtime/api/core/auth.py +208 -0
- nmdc_runtime/api/core/idgen.py +170 -0
- nmdc_runtime/api/core/metadata.py +788 -0
- nmdc_runtime/api/core/util.py +109 -0
- nmdc_runtime/api/db/__init__.py +0 -0
- nmdc_runtime/api/db/mongo.py +447 -0
- nmdc_runtime/api/db/s3.py +37 -0
- nmdc_runtime/api/endpoints/__init__.py +0 -0
- nmdc_runtime/api/endpoints/capabilities.py +25 -0
- nmdc_runtime/api/endpoints/find.py +794 -0
- nmdc_runtime/api/endpoints/ids.py +192 -0
- nmdc_runtime/api/endpoints/jobs.py +143 -0
- nmdc_runtime/api/endpoints/lib/__init__.py +0 -0
- nmdc_runtime/api/endpoints/lib/helpers.py +274 -0
- nmdc_runtime/api/endpoints/lib/path_segments.py +165 -0
- nmdc_runtime/api/endpoints/metadata.py +260 -0
- nmdc_runtime/api/endpoints/nmdcschema.py +581 -0
- nmdc_runtime/api/endpoints/object_types.py +38 -0
- nmdc_runtime/api/endpoints/objects.py +277 -0
- nmdc_runtime/api/endpoints/operations.py +105 -0
- nmdc_runtime/api/endpoints/queries.py +679 -0
- nmdc_runtime/api/endpoints/runs.py +98 -0
- nmdc_runtime/api/endpoints/search.py +38 -0
- nmdc_runtime/api/endpoints/sites.py +229 -0
- nmdc_runtime/api/endpoints/triggers.py +25 -0
- nmdc_runtime/api/endpoints/users.py +214 -0
- nmdc_runtime/api/endpoints/util.py +774 -0
- nmdc_runtime/api/endpoints/workflows.py +353 -0
- nmdc_runtime/api/main.py +401 -0
- nmdc_runtime/api/middleware.py +43 -0
- nmdc_runtime/api/models/__init__.py +0 -0
- nmdc_runtime/api/models/capability.py +14 -0
- nmdc_runtime/api/models/id.py +92 -0
- nmdc_runtime/api/models/job.py +37 -0
- nmdc_runtime/api/models/lib/__init__.py +0 -0
- nmdc_runtime/api/models/lib/helpers.py +78 -0
- nmdc_runtime/api/models/metadata.py +11 -0
- nmdc_runtime/api/models/minter.py +0 -0
- nmdc_runtime/api/models/nmdc_schema.py +146 -0
- nmdc_runtime/api/models/object.py +180 -0
- nmdc_runtime/api/models/object_type.py +20 -0
- nmdc_runtime/api/models/operation.py +66 -0
- nmdc_runtime/api/models/query.py +246 -0
- nmdc_runtime/api/models/query_continuation.py +111 -0
- nmdc_runtime/api/models/run.py +161 -0
- nmdc_runtime/api/models/site.py +87 -0
- nmdc_runtime/api/models/trigger.py +13 -0
- nmdc_runtime/api/models/user.py +140 -0
- nmdc_runtime/api/models/util.py +253 -0
- nmdc_runtime/api/models/workflow.py +15 -0
- nmdc_runtime/api/openapi.py +242 -0
- nmdc_runtime/config.py +55 -4
- nmdc_runtime/core/db/Database.py +1 -3
- nmdc_runtime/infrastructure/database/models/user.py +0 -9
- nmdc_runtime/lib/extract_nmdc_data.py +0 -8
- nmdc_runtime/lib/nmdc_dataframes.py +3 -7
- nmdc_runtime/lib/nmdc_etl_class.py +1 -7
- nmdc_runtime/minter/adapters/repository.py +1 -2
- nmdc_runtime/minter/config.py +2 -0
- nmdc_runtime/minter/domain/model.py +35 -1
- nmdc_runtime/minter/entrypoints/fastapi_app.py +1 -1
- nmdc_runtime/mongo_util.py +1 -2
- nmdc_runtime/site/backup/nmdcdb_mongodump.py +1 -1
- nmdc_runtime/site/backup/nmdcdb_mongoexport.py +1 -3
- nmdc_runtime/site/export/ncbi_xml.py +1 -2
- nmdc_runtime/site/export/ncbi_xml_utils.py +1 -1
- nmdc_runtime/site/graphs.py +33 -28
- nmdc_runtime/site/ops.py +97 -237
- nmdc_runtime/site/repair/database_updater.py +8 -0
- nmdc_runtime/site/repository.py +7 -117
- nmdc_runtime/site/resources.py +4 -4
- nmdc_runtime/site/translation/gold_translator.py +22 -21
- nmdc_runtime/site/translation/neon_benthic_translator.py +0 -1
- nmdc_runtime/site/translation/neon_soil_translator.py +4 -5
- nmdc_runtime/site/translation/neon_surface_water_translator.py +0 -2
- nmdc_runtime/site/translation/submission_portal_translator.py +64 -54
- nmdc_runtime/site/translation/translator.py +63 -1
- nmdc_runtime/site/util.py +8 -3
- nmdc_runtime/site/validation/util.py +10 -5
- nmdc_runtime/util.py +9 -321
- {nmdc_runtime-2.8.0.dist-info → nmdc_runtime-2.10.0.dist-info}/METADATA +57 -6
- nmdc_runtime-2.10.0.dist-info/RECORD +138 -0
- nmdc_runtime/site/translation/emsl.py +0 -43
- nmdc_runtime/site/translation/gold.py +0 -53
- nmdc_runtime/site/translation/jgi.py +0 -32
- nmdc_runtime/site/translation/util.py +0 -132
- nmdc_runtime/site/validation/jgi.py +0 -43
- nmdc_runtime-2.8.0.dist-info/RECORD +0 -84
- {nmdc_runtime-2.8.0.dist-info → nmdc_runtime-2.10.0.dist-info}/WHEEL +0 -0
- {nmdc_runtime-2.8.0.dist-info → nmdc_runtime-2.10.0.dist-info}/entry_points.txt +0 -0
- {nmdc_runtime-2.8.0.dist-info → nmdc_runtime-2.10.0.dist-info}/licenses/LICENSE +0 -0
- {nmdc_runtime-2.8.0.dist-info → nmdc_runtime-2.10.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A *query continuation* is a means to effectively resume a query, i.e. a `find` or `aggregate` MongoDB database command.
|
|
3
|
+
|
|
4
|
+
A *query continuation* document represents a *continuation* (cf. <https://en.wikipedia.org/wiki/Continuation>) for a
|
|
5
|
+
query and uses a stored value ("cursor") for MongoDB's guaranteed unique-valued document field, `_id`,
|
|
6
|
+
such that the documents returned by the command are guaranteed to be sorted in ascending order by `_id`.
|
|
7
|
+
|
|
8
|
+
In this way, an API client may retrieve all documents defined by a `find` or `aggregate` command over multiple HTTP
|
|
9
|
+
requests. One can think of this process as akin to pagination; however, with "cursor-based" pagination, there are no
|
|
10
|
+
guarantees with respect to a fixed "page size".
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import datetime
|
|
15
|
+
import logging
|
|
16
|
+
import json
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel, Field
|
|
19
|
+
from pymongo.database import Database as MongoDatabase
|
|
20
|
+
|
|
21
|
+
from nmdc_runtime.api.core.idgen import generate_one_id
|
|
22
|
+
from nmdc_runtime.api.core.util import now
|
|
23
|
+
from nmdc_runtime.api.db.mongo import get_mongo_db
|
|
24
|
+
from nmdc_runtime.api.models.query import (
|
|
25
|
+
CommandResponse,
|
|
26
|
+
QueryCmd,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Name of the MongoDB collection in which query continuation documents are stored.
COLLECTION_NAME_FOR_QUERY_CONTINUATIONS = "_runtime.query_continuations"

# NOTE(review): the statements below run at import time, so importing this module
# obtains a database handle and issues a `create_index` call — confirm this is intended.
_mdb: MongoDatabase = get_mongo_db()
_qc_collection = _mdb[COLLECTION_NAME_FOR_QUERY_CONTINUATIONS]

# Ensure one-hour TTL on `_runtime.query_continuations` documents via TTL Index.
# The index is on `last_modified`; `expireAfterSeconds=3600` is the one-hour lifetime.
# Reference: https://www.mongodb.com/docs/manual/core/index-ttl/
_qc_collection.create_index({"last_modified": 1}, expireAfterSeconds=3600)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def not_empty(lst: list) -> bool:
    """Report whether `lst` holds at least one element."""
    return len(lst) != 0
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class QueryContinuation(BaseModel):
    """A query that has not completed, and that may be resumed, using `cursor` to modify `query_cmd`.

    This model is intended to represent the state of a logical "session" to "page" through a query's results
    over several HTTP requests, and may be discarded after fetching all "batches" of documents.

    Thus, a mongo collection tracking query continuations may be reasonably given e.g. a so-called "TTL Index"
    for the `last_modified` field, assuming that `last_modified` is updated each time `query` is updated.
    """

    # Identifier of the continuation; the alias makes it read from / dump to the `_id` key.
    id: str = Field(..., alias="_id")
    # The query command that this continuation resumes.
    query_cmd: QueryCmd
    # JSON-encoded `_id` of the last document already returned (as written by `create_qc`).
    cursor: str
    # Timestamp set when the continuation is created; the module's TTL index is defined on this field.
    last_modified: datetime.datetime
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class QueryContinuationError(Exception):
    """Error raised for query-continuation failures.

    For example, it is raised when no stored continuation matches a requested `_id`.

    Attributes:
        detail: Human-readable explanation of the failure.
    """

    def __init__(self, detail: str):
        # Forward `detail` to `Exception` so that `str(err)`, `err.args`, and
        # tracebacks carry the message instead of being empty.
        super().__init__(detail)
        self.detail = detail

    def __repr__(self):
        # The original format string ended with an unbalanced ")"; it is removed here.
        return f"{self.__class__.__name__}: {self.detail}"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def dump_qc(m: BaseModel):
    """Serialize model `m` to a plain dict.

    Keys use field aliases (e.g. a field aliased to `_id` is emitted under `_id`),
    and fields that were never explicitly set are omitted.
    """
    dump_options = {"by_alias": True, "exclude_unset": True}
    return m.model_dump(**dump_options)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def create_qc(query_cmd: QueryCmd, cmd_response: CommandResponse) -> QueryContinuation:
    """Build a query continuation for `query_cmd`, store it in the database, and return it.

    The `_id` of the last document in the response's batch is JSON-encoded and kept
    as the continuation's `cursor`.

    NOTE(review): an empty batch raises `IndexError` here; presumably callers check
    for a non-empty batch first — confirm.
    """
    logging.info(f"cmd_response: {cmd_response}")
    final_doc = cmd_response.cursor.batch[-1]
    last_id = json.dumps(final_doc["_id"])
    logging.info(f"Last document ID for query continuation: {last_id}")

    continuation = QueryContinuation(
        _id=generate_one_id(_mdb, "query_continuation"),
        query_cmd=query_cmd,
        cursor=last_id,
        last_modified=now(),
    )
    record = dump_qc(continuation)
    _qc_collection.insert_one(record)
    return continuation
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_qc_by__id(_id: str) -> QueryContinuation | None:
    r"""
    Look up the stored `QueryContinuation` whose `_id` equals the given value.

    Raises `QueryContinuationError` when no such document exists in the database
    (so, despite the annotation, this function does not return `None`).
    """
    stored = _qc_collection.find_one({"_id": _id})
    if stored is not None:
        return QueryContinuation(**stored)
    raise QueryContinuationError(f"cannot find cc with id {_id}")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_last_doc__id_for_qc(query_continuation: QueryContinuation) -> str:
    """
    Decode and return the last-document `_id` held in the continuation's `cursor`.

    The cursor is stored as JSON text, so the value returned is whatever that
    JSON decodes to.
    """
    raw_cursor = query_continuation.cursor
    logging.info(f"Cursor for last doc query continuation: {raw_cursor}")
    return json.loads(raw_cursor)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_initial_query_for_qc(query_continuation: QueryContinuation) -> QueryCmd:
    """
    Return the query command stored on the given `QueryContinuation`.
    """
    initial_cmd = query_continuation.query_cmd
    return initial_cmd
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
import os
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from dagster_graphql import DagsterGraphQLClient
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from pymongo.database import Database as MongoDatabase
|
|
9
|
+
from toolz import merge
|
|
10
|
+
|
|
11
|
+
from nmdc_runtime.api.core.idgen import generate_one_id
|
|
12
|
+
from nmdc_runtime.api.core.util import now, raise404_if_none, pick
|
|
13
|
+
from nmdc_runtime.api.models.user import User
|
|
14
|
+
|
|
15
|
+
# Base URLs pointing at the `main` branch of the runtime's source tree.
PRODUCER_URL_BASE_DEFAULT = "https://github.com/microbiomedata/nmdc-runtime/tree/main/nmdc_runtime/"
SCHEMA_URL_BASE_DEFAULT = "https://github.com/microbiomedata/nmdc-runtime/tree/main/nmdc_runtime/"

# Versioned URLs: the `main` path segment is swapped for `v0-0-1` and a leaf name is appended.
PRODUCER_URL = f'{PRODUCER_URL_BASE_DEFAULT.replace("/main/", "/v0-0-1/")}producer'
SCHEMA_URL = f'{SCHEMA_URL_BASE_DEFAULT.replace("/main/", "/v0-0-1/")}schema.json'
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Base model supplying the `producer` and `schemaURL` fields shared by
# `JobSummary`, `RunSummary`, and `RunEvent`.
# NOTE(review): the name suggests these mirror OpenLineage's common fields — confirm.
class OpenLineageBase(BaseModel):
    # Identifies what produced the record (a URL or, for REQUESTED events, a username).
    producer: str
    # URL of the schema the record is meant to follow.
    schemaURL: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# User-supplied description of a run to request.
class RunUserSpec(BaseModel):
    # `id` of the workflow document to run (looked up in `mdb.workflows`).
    job_id: str
    # Configuration stored in the run's facets under "nmdcRuntime_runConfig".
    run_config: dict = {}
    # Input identifiers copied onto the REQUESTED run event.
    inputs: List[str] = []
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Summary of the job a run belongs to; adds `id` and `description`
# to the `producer`/`schemaURL` fields inherited from `OpenLineageBase`.
class JobSummary(OpenLineageBase):
    id: str
    description: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# A run, identified by `id`, with optional free-form facets.
class Run(BaseModel):
    id: str
    # Optional extra data about the run, e.g. {"nmdcRuntime_runConfig": ...}.
    facets: Optional[dict] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# String-valued enumeration of the kinds of run event recorded in `run_events`.
class RunEventType(str, Enum):
    REQUESTED = "REQUESTED"
    STARTED = "STARTED"
    FAIL = "FAIL"
    COMPLETE = "COMPLETE"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Summary view of a run: its current status, who started it and when,
# its inputs/outputs, and the job it belongs to.
class RunSummary(OpenLineageBase):
    id: str
    status: RunEventType
    started_at_time: str
    was_started_by: str
    inputs: List[str]
    outputs: List[str]
    job: JobSummary
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# A single event in a run's lifecycle, as stored in the `run_events` collection.
class RunEvent(OpenLineageBase):
    run: Run
    job: JobSummary
    # Which lifecycle transition this event records.
    type: RunEventType
    # Event time as a string (produced elsewhere in this module by `now(as_str=True)`).
    time: str
    inputs: Optional[List[str]] = []
    outputs: Optional[List[str]] = []
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@lru_cache
def get_dagster_graphql_client() -> DagsterGraphQLClient:
    """Return a cached Dagster GraphQL client for the host named by `DAGIT_HOST`.

    `DAGIT_HOST` may be written as `host:port` or `scheme://host:port`. A leading
    scheme and any trailing path are ignored. When no port is given, the client is
    created with `port_number=None`.

    Raises:
        RuntimeError: If `DAGIT_HOST` is unset or empty.
        ValueError: If the port portion is present but is not an integer.
    """
    dagit_host = os.getenv("DAGIT_HOST")
    if not dagit_host:
        # Fail with an explicit message rather than an AttributeError on `None`.
        raise RuntimeError("The DAGIT_HOST environment variable is not set.")

    # Keep only the "host[:port]" authority: drop the scheme, then any path.
    authority = dagit_host.split("://", 1)[-1].split("/", 1)[0]
    hostname, _, port_str = authority.partition(":")
    port_number = int(port_str) if port_str else None
    return DagsterGraphQLClient(hostname=hostname, port_number=port_number)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _add_run_requested_event(run_spec: RunUserSpec, mdb: MongoDatabase, user: User):
    """Record a REQUESTED run event for the job named in `run_spec` and return the new run id.

    The job is looked up in `mdb.workflows` by `run_spec.job_id`; the lookup result is
    passed through `raise404_if_none` (presumably raising a 404 when absent — confirm).
    The event's `producer` is the requesting user's username.
    """
    # XXX what we consider a "job" here, is currently a "workflow" elsewhere...
    workflow_doc = mdb.workflows.find_one({"id": run_spec.job_id})
    job = raise404_if_none(workflow_doc)
    run_id = generate_one_id(mdb, "runs")

    requested_run = Run(id=run_id, facets={"nmdcRuntime_runConfig": run_spec.run_config})
    job_summary = merge(
        pick(["id", "description"], job),
        {"producer": PRODUCER_URL, "schemaURL": SCHEMA_URL},
    )
    event = RunEvent(
        producer=user.username,
        schemaURL=SCHEMA_URL,
        run=requested_run,
        job=job_summary,
        type=RunEventType.REQUESTED,
        time=now(as_str=True),
        inputs=run_spec.inputs,
    )
    mdb.run_events.insert_one(event.model_dump())
    return run_id
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _add_run_started_event(run_id: str, mdb: MongoDatabase):
    """Record a STARTED event for `run_id`, based on its latest REQUESTED event.

    The new event reuses the REQUESTED event's `run` and `job`. Returns `run_id`.
    """
    latest_requested = mdb.run_events.find_one(
        {"run.id": run_id, "type": "REQUESTED"}, sort=[("time", -1)]
    )
    requested: RunEvent = RunEvent(**raise404_if_none(latest_requested))
    started_event = RunEvent(
        producer=PRODUCER_URL,
        schemaURL=SCHEMA_URL,
        run=requested.run,
        job=requested.job,
        type=RunEventType.STARTED,
        time=now(as_str=True),
    )
    mdb.run_events.insert_one(started_event.model_dump())
    return run_id
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _add_run_fail_event(run_id: str, mdb: MongoDatabase):
    """Record a FAIL event for `run_id`, based on its latest REQUESTED event.

    The new event reuses the REQUESTED event's `run` and `job`. Returns `run_id`.
    """
    latest_requested = mdb.run_events.find_one(
        {"run.id": run_id, "type": "REQUESTED"}, sort=[("time", -1)]
    )
    requested: RunEvent = RunEvent(**raise404_if_none(latest_requested))
    fail_event = RunEvent(
        producer=PRODUCER_URL,
        schemaURL=SCHEMA_URL,
        run=requested.run,
        job=requested.job,
        type=RunEventType.FAIL,
        time=now(as_str=True),
    )
    mdb.run_events.insert_one(fail_event.model_dump())
    return run_id
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _add_run_complete_event(run_id: str, mdb: MongoDatabase, outputs: List[str]):
    """Record a COMPLETE event, carrying `outputs`, for `run_id`.

    The event is based on the run's latest STARTED event, whose `run` and `job`
    it reuses. Returns `run_id`.
    """
    latest_started = mdb.run_events.find_one(
        {"run.id": run_id, "type": "STARTED"}, sort=[("time", -1)]
    )
    started: RunEvent = RunEvent(**raise404_if_none(latest_started))
    complete_event = RunEvent(
        producer=PRODUCER_URL,
        schemaURL=SCHEMA_URL,
        run=started.run,
        job=started.job,
        type=RunEventType.COMPLETE,
        time=now(as_str=True),
        outputs=outputs,
    )
    mdb.run_events.insert_one(complete_event.model_dump())
    return run_id
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
import pymongo.database
|
|
4
|
+
from fastapi import Depends
|
|
5
|
+
from jose import JWTError, jwt
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from nmdc_runtime.api.core.auth import (
|
|
9
|
+
verify_password,
|
|
10
|
+
TokenData,
|
|
11
|
+
optional_oauth2_scheme,
|
|
12
|
+
)
|
|
13
|
+
from nmdc_runtime.api.db.mongo import get_mongo_db
|
|
14
|
+
from nmdc_runtime.api.models.user import (
|
|
15
|
+
oauth2_scheme,
|
|
16
|
+
credentials_exception,
|
|
17
|
+
SECRET_KEY,
|
|
18
|
+
ALGORITHM,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# A site, identified by `id`, with the ids of its capabilities.
class Site(BaseModel):
    id: str
    capability_ids: List[str] = []
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# A site client as stored in the database: its id and the hash of its secret.
class SiteClientInDB(BaseModel):
    id: str
    # Hash checked against a presented client secret via `verify_password`.
    hashed_secret: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# A site as stored in the database, including its embedded clients.
class SiteInDB(Site):
    clients: List[SiteClientInDB] = []
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_site(mdb, client_id: str) -> Optional[SiteInDB]:
    r"""
    Look up the site that has a client whose `id` is `client_id`.

    Returns the site as a `SiteInDB`, or `None` when no site document matches.
    """
    site_doc = mdb.sites.find_one({"clients.id": client_id})
    return None if site_doc is None else SiteInDB(**site_doc)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def authenticate_site_client(mdb, client_id: str, client_secret: str):
    """Return the site owning `client_id` if `client_secret` is valid, otherwise `False`.

    `False` is returned when no site has such a client, when the site's client
    list holds no entry with that id, or when the secret does not verify against
    the stored hash.
    """
    site = get_site(mdb, client_id)
    if not site:
        return False
    # Pass a default to `next` so a missing client yields `False`
    # instead of letting `StopIteration` escape to the caller.
    hashed_secret = next(
        (client.hashed_secret for client in site.clients if client.id == client_id),
        None,
    )
    if hashed_secret is None or not verify_password(client_secret, hashed_secret):
        return False
    return site
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
async def get_current_client_site(
    token: str = Depends(oauth2_scheme),
    mdb: pymongo.database.Database = Depends(get_mongo_db),
):
    """Resolve the site of the site client identified by the bearer `token`.

    Raises `credentials_exception` when the token has been invalidated, cannot be
    decoded, has no `sub` claim, has a `sub` not of the form "client:<id>", or
    names a client that belongs to no site.
    """
    revoked = mdb.invalidated_tokens.find_one({"_id": token})
    if revoked:
        raise credentials_exception

    client_prefix = "client:"
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = payload.get("sub")
        if subject is None or not subject.startswith(client_prefix):
            raise credentials_exception
        client_id = subject.split(client_prefix, 1)[1]
        token_data = TokenData(subject=client_id)
    except JWTError:
        raise credentials_exception

    site = get_site(mdb, client_id=token_data.subject)
    if site is None:
        raise credentials_exception
    return site
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
async def maybe_get_current_client_site(
    token: str = Depends(optional_oauth2_scheme),
    mdb: pymongo.database.Database = Depends(get_mongo_db),
):
    """Like `get_current_client_site`, but return `None` when no token was supplied."""
    if token is not None:
        return await get_current_client_site(token, mdb)
    return None
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
|
|
3
|
+
import pymongo.database
|
|
4
|
+
from fastapi import Depends, HTTPException
|
|
5
|
+
from jose import JWTError, jwt
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from nmdc_runtime.api.core.auth import (
|
|
9
|
+
verify_password,
|
|
10
|
+
SECRET_KEY,
|
|
11
|
+
ALGORITHM,
|
|
12
|
+
oauth2_scheme,
|
|
13
|
+
credentials_exception,
|
|
14
|
+
TokenData,
|
|
15
|
+
bearer_scheme,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from nmdc_runtime.api.models.site import get_site
|
|
19
|
+
|
|
20
|
+
from nmdc_runtime.api.db.mongo import get_mongo_db
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Public attributes of a user account.
class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    # NOTE(review): presumably the ids of the sites this user administers — confirm.
    site_admin: Optional[List[str]] = []
    # When true, `get_current_active_user` rejects the account as inactive.
    disabled: Optional[bool] = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# A user together with a password field.
# NOTE(review): presumably the plaintext password supplied on input — confirm.
class UserIn(User):
    password: str
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# A user as stored in the database, with the hash of the password.
class UserInDB(User):
    # Hash checked against a presented password via `verify_password`.
    hashed_password: str
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_user(mdb, username: str) -> Optional[UserInDB]:
    r"""
    Look up the user whose `username` matches the given value.

    Returns the user as a `UserInDB`, or `None` when no user document matches.
    """
    user_doc = mdb.users.find_one({"username": username})
    return None if user_doc is None else UserInDB(**user_doc)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def authenticate_user(mdb, username: str, password: str) -> Union[UserInDB, bool]:
    r"""
    Check a username/password combination.

    Returns the matching user when the username exists and the password verifies
    against the stored hash; otherwise returns `False`.
    """
    user = get_user(mdb, username)
    if user and verify_password(password, user.hashed_password):
        return user
    return False
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    bearer_credentials: str = Depends(bearer_scheme),
    mdb: pymongo.database.Database = Depends(get_mongo_db),
) -> UserInDB:
    r"""
    Resolve the user identified by the provided token.

    The token's `sub` claim must have the form "user:<username>" or
    "client:<client_id>". For a site client, the result is an ephemeral "user"
    whose username is the client's `client_id`.

    Raises `credentials_exception` if the token has been invalidated, cannot be
    decoded, carries no usable subject, or names an unknown user or client.
    """

    if mdb.invalidated_tokens.find_one({"_id": token}):
        raise credentials_exception

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = payload.get("sub")
        if subject is None:
            raise credentials_exception
        is_user = subject.startswith("user:")
        is_client = subject.startswith("client:")
        if not (is_user or is_client):
            raise credentials_exception

        # subject is in the form "user:foo" or "client:bar"
        username = subject.split(":", 1)[1]
        token_data = TokenData(subject=username)
    except (JWTError, AttributeError) as e:
        print(f"jwt error: {e}")
        raise credentials_exception

    # Coerce a "client" into a "user"
    # TODO: consolidate the client/user distinction.
    if is_user:
        user = get_user(mdb, username=token_data.subject)
    else:
        # Exactly one prefix matched above, so this is the "client:" case:
        # construct a user from the client_id.
        user = get_client_user(mdb, client_id=token_data.subject)

    if user is None:
        raise credentials_exception
    return user
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_client_user(mdb, client_id: str) -> UserInDB:
    r"""
    Build an ephemeral "user" that stands in for the site client `client_id`.

    The returned `UserInDB` has the client's `id` as its `username` and the
    client's `hashed_secret` as its `hashed_password`.

    Raises `credentials_exception` if no site is associated with `client_id`,
    or if that site's client list holds no entry with that id.
    """

    # Get the site associated with the identified client.
    site = get_site(mdb, client_id)
    if site is None:
        raise credentials_exception

    # Get the client, itself, via the site. The `None` default matters: without
    # it, `next` raises `StopIteration` and the check below could never trigger.
    client = next((c for c in site.clients if c.id == client_id), None)
    if client is None:
        raise credentials_exception

    # Make an ephemeral "user" whose username matches the client's `id`.
    return UserInDB(username=client.id, hashed_password=client.hashed_secret)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user),
) -> UserInDB:
    r"""
    Pass the current user through unless the account is disabled.

    Raises an `HTTPException` with status 400 ("Inactive user") for a disabled account.
    """

    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|