nmdc-runtime 2.9.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nmdc-runtime might be problematic; consult the registry's advisory page for more details.

Files changed (98) hide show
  1. nmdc_runtime/api/__init__.py +0 -0
  2. nmdc_runtime/api/analytics.py +70 -0
  3. nmdc_runtime/api/boot/__init__.py +0 -0
  4. nmdc_runtime/api/boot/capabilities.py +9 -0
  5. nmdc_runtime/api/boot/object_types.py +126 -0
  6. nmdc_runtime/api/boot/triggers.py +84 -0
  7. nmdc_runtime/api/boot/workflows.py +116 -0
  8. nmdc_runtime/api/core/__init__.py +0 -0
  9. nmdc_runtime/api/core/auth.py +208 -0
  10. nmdc_runtime/api/core/idgen.py +170 -0
  11. nmdc_runtime/api/core/metadata.py +788 -0
  12. nmdc_runtime/api/core/util.py +109 -0
  13. nmdc_runtime/api/db/__init__.py +0 -0
  14. nmdc_runtime/api/db/mongo.py +447 -0
  15. nmdc_runtime/api/db/s3.py +37 -0
  16. nmdc_runtime/api/endpoints/__init__.py +0 -0
  17. nmdc_runtime/api/endpoints/capabilities.py +25 -0
  18. nmdc_runtime/api/endpoints/find.py +794 -0
  19. nmdc_runtime/api/endpoints/ids.py +192 -0
  20. nmdc_runtime/api/endpoints/jobs.py +143 -0
  21. nmdc_runtime/api/endpoints/lib/__init__.py +0 -0
  22. nmdc_runtime/api/endpoints/lib/helpers.py +274 -0
  23. nmdc_runtime/api/endpoints/lib/path_segments.py +165 -0
  24. nmdc_runtime/api/endpoints/metadata.py +260 -0
  25. nmdc_runtime/api/endpoints/nmdcschema.py +581 -0
  26. nmdc_runtime/api/endpoints/object_types.py +38 -0
  27. nmdc_runtime/api/endpoints/objects.py +277 -0
  28. nmdc_runtime/api/endpoints/operations.py +105 -0
  29. nmdc_runtime/api/endpoints/queries.py +679 -0
  30. nmdc_runtime/api/endpoints/runs.py +98 -0
  31. nmdc_runtime/api/endpoints/search.py +38 -0
  32. nmdc_runtime/api/endpoints/sites.py +229 -0
  33. nmdc_runtime/api/endpoints/triggers.py +25 -0
  34. nmdc_runtime/api/endpoints/users.py +214 -0
  35. nmdc_runtime/api/endpoints/util.py +774 -0
  36. nmdc_runtime/api/endpoints/workflows.py +353 -0
  37. nmdc_runtime/api/main.py +401 -0
  38. nmdc_runtime/api/middleware.py +43 -0
  39. nmdc_runtime/api/models/__init__.py +0 -0
  40. nmdc_runtime/api/models/capability.py +14 -0
  41. nmdc_runtime/api/models/id.py +92 -0
  42. nmdc_runtime/api/models/job.py +37 -0
  43. nmdc_runtime/api/models/lib/__init__.py +0 -0
  44. nmdc_runtime/api/models/lib/helpers.py +78 -0
  45. nmdc_runtime/api/models/metadata.py +11 -0
  46. nmdc_runtime/api/models/minter.py +0 -0
  47. nmdc_runtime/api/models/nmdc_schema.py +146 -0
  48. nmdc_runtime/api/models/object.py +180 -0
  49. nmdc_runtime/api/models/object_type.py +20 -0
  50. nmdc_runtime/api/models/operation.py +66 -0
  51. nmdc_runtime/api/models/query.py +246 -0
  52. nmdc_runtime/api/models/query_continuation.py +111 -0
  53. nmdc_runtime/api/models/run.py +161 -0
  54. nmdc_runtime/api/models/site.py +87 -0
  55. nmdc_runtime/api/models/trigger.py +13 -0
  56. nmdc_runtime/api/models/user.py +140 -0
  57. nmdc_runtime/api/models/util.py +253 -0
  58. nmdc_runtime/api/models/workflow.py +15 -0
  59. nmdc_runtime/api/openapi.py +242 -0
  60. nmdc_runtime/config.py +7 -8
  61. nmdc_runtime/core/db/Database.py +1 -3
  62. nmdc_runtime/infrastructure/database/models/user.py +0 -9
  63. nmdc_runtime/lib/extract_nmdc_data.py +0 -8
  64. nmdc_runtime/lib/nmdc_dataframes.py +3 -7
  65. nmdc_runtime/lib/nmdc_etl_class.py +1 -7
  66. nmdc_runtime/minter/adapters/repository.py +1 -2
  67. nmdc_runtime/minter/config.py +2 -0
  68. nmdc_runtime/minter/domain/model.py +35 -1
  69. nmdc_runtime/minter/entrypoints/fastapi_app.py +1 -1
  70. nmdc_runtime/mongo_util.py +1 -2
  71. nmdc_runtime/site/backup/nmdcdb_mongodump.py +1 -1
  72. nmdc_runtime/site/backup/nmdcdb_mongoexport.py +1 -3
  73. nmdc_runtime/site/export/ncbi_xml.py +1 -2
  74. nmdc_runtime/site/export/ncbi_xml_utils.py +1 -1
  75. nmdc_runtime/site/graphs.py +1 -22
  76. nmdc_runtime/site/ops.py +60 -152
  77. nmdc_runtime/site/repository.py +0 -112
  78. nmdc_runtime/site/translation/gold_translator.py +4 -12
  79. nmdc_runtime/site/translation/neon_benthic_translator.py +0 -1
  80. nmdc_runtime/site/translation/neon_soil_translator.py +4 -5
  81. nmdc_runtime/site/translation/neon_surface_water_translator.py +0 -2
  82. nmdc_runtime/site/translation/submission_portal_translator.py +2 -54
  83. nmdc_runtime/site/translation/translator.py +63 -1
  84. nmdc_runtime/site/util.py +8 -3
  85. nmdc_runtime/site/validation/util.py +10 -5
  86. nmdc_runtime/util.py +3 -47
  87. {nmdc_runtime-2.9.0.dist-info → nmdc_runtime-2.10.0.dist-info}/METADATA +57 -6
  88. nmdc_runtime-2.10.0.dist-info/RECORD +138 -0
  89. nmdc_runtime/site/translation/emsl.py +0 -43
  90. nmdc_runtime/site/translation/gold.py +0 -53
  91. nmdc_runtime/site/translation/jgi.py +0 -32
  92. nmdc_runtime/site/translation/util.py +0 -132
  93. nmdc_runtime/site/validation/jgi.py +0 -43
  94. nmdc_runtime-2.9.0.dist-info/RECORD +0 -84
  95. {nmdc_runtime-2.9.0.dist-info → nmdc_runtime-2.10.0.dist-info}/WHEEL +0 -0
  96. {nmdc_runtime-2.9.0.dist-info → nmdc_runtime-2.10.0.dist-info}/entry_points.txt +0 -0
  97. {nmdc_runtime-2.9.0.dist-info → nmdc_runtime-2.10.0.dist-info}/licenses/LICENSE +0 -0
  98. {nmdc_runtime-2.9.0.dist-info → nmdc_runtime-2.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ import inspect
2
+ from enum import Enum
3
+ from typing import List, Any, Dict, Optional
4
+
5
+ from pydantic import BaseModel, Field, create_model
6
+ from refscan.lib.helpers import get_collection_names_from_schema
7
+
8
+ from nmdc_runtime.util import nmdc_schema_view
9
+
10
+
11
+ class FileTypeEnum(str, Enum):
12
+ ft_icr_ms_analysis_results = "FT ICR-MS Analysis Results"
13
+ gc_ms_metabolomics_results = "GC-MS Metabolomics Results"
14
+ metaproteomics_workflow_statistics = "Metaproteomics Workflow Statistics"
15
+ protein_report = "Protein Report"
16
+ peptide_report = "Peptide Report"
17
+ unfiltered_metaproteomics_results = "Unfiltered Metaproteomics Results"
18
+ read_count_and_rpkm = "Read Count and RPKM"
19
+ qc_non_rrna_r2 = "QC non-rRNA R2"
20
+ qc_non_rrna_r1 = "QC non-rRNA R1"
21
+ metagenome_bins = "Metagenome Bins"
22
+ checkm_statistics = "CheckM Statistics"
23
+ gottcha2_krona_plot = "GOTTCHA2 Krona Plot"
24
+ kraken2_krona_plot = "Kraken2 Krona Plot"
25
+ centrifuge_krona_plot = "Centrifuge Krona Plot"
26
+ kraken2_classification_report = "Kraken2 Classification Report"
27
+ kraken2_taxonomic_classification = "Kraken2 Taxonomic Classification"
28
+ centrifuge_classification_report = "Centrifuge Classification Report"
29
+ centrifuge_taxonomic_classification = "Centrifuge Taxonomic Classification"
30
+ structural_annotation_gff = "Structural Annotation GFF"
31
+ functional_annotation_gff = "Functional Annotation GFF"
32
+ annotation_amino_acid_fasta = "Annotation Amino Acid FASTA"
33
+ annotation_enzyme_commission = "Annotation Enzyme Commission"
34
+ annotation_kegg_orthology = "Annotation KEGG Orthology"
35
+ assembly_coverage_bam = "Assembly Coverage BAM"
36
+ assembly_agp = "Assembly AGP"
37
+ assembly_scaffolds = "Assembly Scaffolds"
38
+ assembly_contigs = "Assembly Contigs"
39
+ assembly_coverage_stats = "Assembly Coverage Stats"
40
+ filtered_sequencing_reads = "Filtered Sequencing Reads"
41
+ qc_statistics = "QC Statistics"
42
+ tigrfam_annotation_gff = "TIGRFam Annotation GFF"
43
+ clusters_of_orthologous_groups__cog__annotation_gff = (
44
+ "Clusters of Orthologous Groups (COG) Annotation GFF"
45
+ )
46
+ cath_funfams__functional_families__annotation_gff = (
47
+ "CATH FunFams (Functional Families) Annotation GFF"
48
+ )
49
+ superfam_annotation_gff = "SUPERFam Annotation GFF"
50
+ smart_annotation_gff = "SMART Annotation GFF"
51
+ pfam_annotation_gff = "Pfam Annotation GFF"
52
+ direct_infusion_ft_icr_ms_raw_data = "Direct Infusion FT ICR-MS Raw Data"
53
+
54
+
55
+ class DataObject(BaseModel):
56
+ id: str = Field(None)
57
+ name: str = Field(None, description="A human readable label for an entity")
58
+ description: str = Field(
59
+ None, description="a human-readable description of a thing"
60
+ )
61
+ alternative_identifiers: List[str] = Field(
62
+ None, description="A list of alternative identifiers for the entity."
63
+ )
64
+ compression_type: str = Field(
65
+ None, description="If provided, specifies the compression type"
66
+ )
67
+ data_object_type: FileTypeEnum = Field(None)
68
+ file_size_bytes: int = Field(None, description="Size of the file in bytes")
69
+ md5_checksum: str = Field(None, description="MD5 checksum of file (pre-compressed)")
70
+ type: str = Field(
71
+ "nmdc:DataObject",
72
+ description="An optional string that specifies the type object. This is used to allow for searches for different kinds of objects.",
73
+ )
74
+ url: str = Field(None)
75
+ was_generated_by: str = Field(None)
76
+
77
+
78
+ list_request_ops_per_field_type = {
79
+ (object,): ["$eq", "$neq", "$in", "$nin"],
80
+ (str,): ["$regex"],
81
+ (int, float): ["$gt", "$gte", "$lt", "$lte"],
82
+ }
83
+
84
+ list_request_ops_with_many_args = {"$in", "$nin"}
85
+
86
+
87
+ def create_list_request_model_for(cls):
88
+ field_filters = []
89
+ sig = inspect.signature(cls)
90
+ for p in sig.parameters.values():
91
+ field_name, field_type = p.name, p.annotation
92
+ if hasattr(field_type, "__origin__"): # a GenericAlias object, e.g. List[str].
93
+ field_type = field_type.__args__[0]
94
+ field_default = (
95
+ None if p.default in (inspect.Parameter.empty, []) else p.default
96
+ )
97
+ field_ops = []
98
+ for types_ok, type_ops in list_request_ops_per_field_type.items():
99
+ if field_type in types_ok or isinstance(field_type, types_ok):
100
+ field_ops.extend(type_ops)
101
+ for op in field_ops:
102
+ field_filters.append(
103
+ {
104
+ "name": field_name,
105
+ "arg_type": field_type,
106
+ "default": field_default,
107
+ "op": op,
108
+ }
109
+ )
110
+ create_model_kwargs = {"max_page_size": (int, 20), "page_token": (str, None)}
111
+ for ff in field_filters:
112
+ model_field_name = f'{ff["name"]}_{ff["op"][1:]}'
113
+ model_field_type = ff["arg_type"]
114
+ create_model_kwargs[model_field_name] = (model_field_type, ff["default"])
115
+ return create_model(f"{cls.__name__}ListRequest", **create_model_kwargs)
116
+
117
+
118
+ def list_request_filter_to_mongo_filter(req: dict):
119
+ filter_ = {}
120
+ for k, v in req.items():
121
+ if not v:
122
+ continue
123
+ field, op = k.rsplit("_", maxsplit=1)
124
+ op = f"${op}"
125
+ if field not in filter_:
126
+ filter_[field] = {}
127
+ if op not in filter_[field]:
128
+ if op in list_request_ops_with_many_args:
129
+ filter_[field][op] = v.split(",")
130
+ else:
131
+ filter_[field][op] = v
132
+ return filter_
133
+
134
+
135
+ DataObjectListRequest = create_list_request_model_for(DataObject)
136
+
137
+ SimplifiedDocument = Dict[str, Any]
138
+
139
+ schema_view = nmdc_schema_view()
140
+ SimplifiedNMDCDatabase = create_model(
141
+ "NMDCDatabase",
142
+ **{
143
+ coll_name: Optional[list[SimplifiedDocument]]
144
+ for coll_name in get_collection_names_from_schema(schema_view)
145
+ },
146
+ )
@@ -0,0 +1,180 @@
1
+ import datetime
2
+ import hashlib
3
+ import http
4
+ from enum import Enum
5
+ from typing import Optional, List, Dict
6
+
7
+ from pydantic import (
8
+ field_validator,
9
+ model_validator,
10
+ Field,
11
+ StringConstraints,
12
+ BaseModel,
13
+ AnyUrl,
14
+ HttpUrl,
15
+ field_serializer,
16
+ )
17
+ from typing_extensions import Annotated
18
+
19
+
20
+ class AccessMethodType(str, Enum):
21
+ s3 = "s3"
22
+ gs = "gs"
23
+ ftp = "ftp"
24
+ gsiftp = "gsiftp"
25
+ globus = "globus"
26
+ htsget = "htsget"
27
+ https = "https"
28
+ file = "file"
29
+
30
+
31
+ class AccessURL(BaseModel):
32
+ headers: Optional[Dict[str, str]] = None
33
+ url: AnyUrl
34
+
35
+ @field_serializer("url")
36
+ def serialize_url(self, url: AnyUrl, _info):
37
+ return str(url)
38
+
39
+
40
+ class AccessMethod(BaseModel):
41
+ access_id: Optional[Annotated[str, StringConstraints(min_length=1)]] = None
42
+ access_url: Optional[AccessURL] = None
43
+ region: Optional[str] = None
44
+ type: AccessMethodType = AccessMethodType.https
45
+
46
+ @model_validator(mode="before")
47
+ def at_least_one_of_access_id_and_url(cls, values):
48
+ access_id, access_url = values.get("access_id"), values.get("access_url")
49
+ if access_id is None and access_url is None:
50
+ raise ValueError(
51
+ "At least one of access_url and access_id must be provided."
52
+ )
53
+ return values
54
+
55
+
56
+ ChecksumType = Annotated[
57
+ str,
58
+ StringConstraints(
59
+ pattern=rf"(?P<checksumtype>({'|'.join(sorted(hashlib.algorithms_guaranteed))}))"
60
+ ),
61
+ ]
62
+
63
+
64
+ class Checksum(BaseModel):
65
+ checksum: Annotated[str, StringConstraints(min_length=1)]
66
+ type: ChecksumType
67
+
68
+
69
+ DrsId = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9._~\-]+$")]
70
+ PortableFilename = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9._\-]+$")]
71
+
72
+
73
+ class ContentsObject(BaseModel):
74
+ contents: Optional[List["ContentsObject"]] = None
75
+ drs_uri: Optional[List[AnyUrl]] = None
76
+ id: Optional[DrsId] = None
77
+ name: PortableFilename
78
+
79
+ @model_validator(mode="before")
80
+ def no_contents_means_single_blob(cls, values):
81
+ contents, id_ = values.get("contents"), values.get("id")
82
+ if contents is None and id_ is None:
83
+ raise ValueError("no contents means no further nesting, so id required")
84
+ return values
85
+
86
+ @field_serializer("drs_uri")
87
+ def serialize_url(self, drs_uri: Optional[List[AnyUrl]], _info):
88
+ if drs_uri is not None and len(drs_uri) > 0:
89
+ return [str(u) for u in drs_uri]
90
+ return drs_uri
91
+
92
+
93
+ # Note: Between Pydantic v1 and v2, the `update_forward_refs` method was renamed to `model_rebuild`.
94
+ # Reference: https://docs.pydantic.dev/2.11/migration/#changes-to-pydanticbasemodel
95
+ ContentsObject.model_rebuild()
96
+
97
+ Mimetype = Annotated[str, StringConstraints(pattern=r"^\w+/[-+.\w]+$")]
98
+ SizeInBytes = Annotated[int, Field(strict=True, ge=0)]
99
+
100
+
101
+ class Error(BaseModel):
102
+ msg: Optional[str] = None
103
+ status_code: http.HTTPStatus
104
+
105
+
106
+ class DrsObjectBase(BaseModel):
107
+ aliases: Optional[List[str]] = None
108
+ description: Optional[str] = None
109
+ mime_type: Optional[Mimetype] = None
110
+ name: Optional[PortableFilename] = None
111
+
112
+
113
class DrsObjectIn(DrsObjectBase):
    """Client-supplied DRS object: a single blob or a bundle of contents.

    A blob (no `contents`) must provide `access_methods`; every object must
    provide at least one checksum.
    """

    access_methods: Optional[List[AccessMethod]] = None
    checksums: List[Checksum]
    contents: Optional[List[ContentsObject]] = None
    created_time: datetime.datetime
    size: SizeInBytes
    updated_time: Optional[datetime.datetime] = None
    version: Optional[str] = None

    @model_validator(mode="before")
    def no_contents_means_single_blob(cls, values):
        # A blob (object without nested contents) is only reachable via its
        # access methods, so at least one must be given.
        contents, access_methods = values.get("contents"), values.get("access_methods")
        if contents is None and access_methods is None:
            raise ValueError(
                "no contents means single blob, which requires access_methods"
            )
        return values

    @field_validator("checksums")
    @classmethod
    def at_least_one_checksum(cls, v):
        if not len(v) >= 1:
            # Fix: corrected typo "requried" -> "required" in the error message.
            raise ValueError("At least one checksum required")
        return v
137
+
138
+
139
+ class DrsObject(DrsObjectIn):
140
+ id: DrsId
141
+ self_uri: AnyUrl
142
+
143
+ @field_serializer("self_uri")
144
+ def serialize_url(self, self_uri: AnyUrl, _info):
145
+ return str(self_uri)
146
+
147
+
148
+ Seconds = Annotated[int, Field(strict=True, gt=0)]
149
+
150
+
151
+ class ObjectPresignedUrl(BaseModel):
152
+ url: HttpUrl
153
+ expires_in: Seconds = 300
154
+
155
+ @field_serializer("url")
156
+ def serialize_url(self, url: HttpUrl, _info):
157
+ return str(url)
158
+
159
+
160
+ class DrsObjectOutBase(DrsObjectBase):
161
+ checksums: List[Checksum]
162
+ created_time: datetime.datetime
163
+ id: DrsId
164
+ self_uri: AnyUrl
165
+ size: SizeInBytes
166
+ updated_time: Optional[datetime.datetime] = None
167
+ version: Optional[str] = None
168
+
169
+ @field_serializer("self_uri")
170
+ def serialize_url(self, self_uri: AnyUrl, _info):
171
+ return str(self_uri)
172
+
173
+
174
+ class DrsObjectBlobOut(DrsObjectOutBase):
175
+ access_methods: List[AccessMethod]
176
+
177
+
178
+ class DrsObjectBundleOut(DrsObjectOutBase):
179
+ access_methods: Optional[List[AccessMethod]] = None
180
+ contents: List[ContentsObject]
@@ -0,0 +1,20 @@
1
+ import datetime
2
+ from typing import Optional, List
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from nmdc_runtime.api.models.object import DrsObject
7
+
8
+
9
+ class ObjectTypeBase(BaseModel):
10
+ name: Optional[str] = None
11
+ description: Optional[str] = None
12
+
13
+
14
+ class ObjectType(ObjectTypeBase):
15
+ id: str
16
+ created_at: datetime.datetime
17
+
18
+
19
+ class DrsObjectWithTypes(DrsObject):
20
+ types: Optional[List[str]] = None
@@ -0,0 +1,66 @@
1
+ import datetime
2
+ from typing import Generic, TypeVar, Optional, List, Any, Union
3
+
4
+ from pydantic import StringConstraints, BaseModel, HttpUrl, field_serializer
5
+
6
+ from nmdc_runtime.api.models.util import ResultT
7
+ from typing_extensions import Annotated
8
+
9
+ MetadataT = TypeVar("MetadataT")
10
+
11
+
12
+ PythonImportPath = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9_.]+$")]
13
+
14
+
15
+ class OperationError(BaseModel):
16
+ code: str
17
+ message: str
18
+ details: Any = None
19
+
20
+
21
+ class Operation(BaseModel, Generic[ResultT, MetadataT]):
22
+ id: str
23
+ done: bool = False
24
+ expire_time: datetime.datetime
25
+ result: Optional[Union[ResultT, OperationError]] = None
26
+ metadata: Optional[MetadataT] = None
27
+
28
+
29
+ class UpdateOperationRequest(BaseModel, Generic[ResultT, MetadataT]):
30
+ done: bool = False
31
+ result: Optional[Union[ResultT, OperationError]] = None
32
+ metadata: Optional[MetadataT] = {}
33
+
34
+
35
+ class ListOperationsResponse(BaseModel, Generic[ResultT, MetadataT]):
36
+ resources: List[Operation[ResultT, MetadataT]]
37
+ next_page_token: Optional[str] = None
38
+
39
+
40
+ class Result(BaseModel):
41
+ model: Optional[PythonImportPath] = None
42
+
43
+
44
+ class EmptyResult(Result):
45
+ pass
46
+
47
+
48
+ class Metadata(BaseModel):
49
+ # XXX alternative: set model field using __class__ on __init__()?
50
+ model: Optional[PythonImportPath] = None
51
+ cancelled: Optional[bool] = None
52
+
53
+
54
+ class PausedOrNot(Metadata):
55
+ paused: bool
56
+
57
+
58
+ class ObjectPutMetadata(Metadata):
59
+ object_id: str
60
+ site_id: str
61
+ url: HttpUrl
62
+ expires_in_seconds: int
63
+
64
+ @field_serializer("url")
65
+ def serialize_url(self, url: HttpUrl, _info):
66
+ return str(url)
@@ -0,0 +1,246 @@
1
+ import json
2
+ import logging
3
+ from typing import Optional, Any, Dict, List, Union, TypedDict
4
+
5
+ import bson
6
+ import bson.json_util
7
+ from pydantic import (
8
+ model_validator,
9
+ Field,
10
+ BaseModel,
11
+ PositiveInt,
12
+ NonNegativeInt,
13
+ field_validator,
14
+ WrapSerializer,
15
+ )
16
+ from toolz import assoc, assoc_in
17
+ from typing_extensions import Annotated
18
+
19
+ from nmdc_runtime.api.core.util import pick
20
+
21
+
22
+ def bson_to_json(doc: Any, handler) -> dict:
23
+ """Ensure a dict with e.g. mongo ObjectIds will serialize as JSON."""
24
+ return json.loads(bson.json_util.dumps(doc))
25
+
26
+
27
+ Document = Annotated[Dict[str, Any], WrapSerializer(bson_to_json)]
28
+
29
+ OneOrZero = Annotated[int, Field(ge=0, le=1)]
30
+ One = Annotated[int, Field(ge=1, le=1)]
31
+ MinusOne = Annotated[int, Field(ge=-1, le=-1)]
32
+ OneOrMinusOne = Union[One, MinusOne]
33
+
34
+
35
+ class CommandBase(BaseModel):
36
+ comment: Optional[Any] = None
37
+
38
+
39
+ class CollStatsCommand(CommandBase):
40
+ collStats: str
41
+ scale: Optional[int] = 1
42
+
43
+
44
+ class CountCommand(CommandBase):
45
+ count: str
46
+ query: Optional[Document] = None
47
+
48
+
49
+ class FindCommand(CommandBase):
50
+ find: str
51
+ filter: Optional[Document] = None
52
+ projection: Optional[Dict[str, OneOrZero]] = None
53
+ allowPartialResults: Optional[bool] = True
54
+ batchSize: Optional[PositiveInt] = 101
55
+ sort: Optional[Dict[str, OneOrMinusOne]] = None
56
+ limit: Optional[NonNegativeInt] = None
57
+
58
+
59
class AggregateCommand(CommandBase):
    """Mongo `aggregate` command, with write-stage pipelines rejected."""

    aggregate: str
    pipeline: List[Document]
    allowDiskUse: Optional[bool] = False
    cursor: Optional[Document] = None

    @field_validator("pipeline")
    @classmethod
    def disallow_invalid_pipeline_stages(
        cls, pipeline: List[Document]
    ) -> List[Document]:
        # $out/$merge write aggregation results into collections; a read-only
        # query API must refuse them. Only top-level stage keys are checked.
        deny_list = ["$out", "$merge"]

        if any(
            key in deny_list for pipeline_stage in pipeline for key in pipeline_stage
        ):
            # Fix: message previously said "$Out"; mongo stage names are
            # lowercase ("$out"), matching the deny list above.
            raise ValueError("$out and $merge pipeline stages are not allowed.")

        return pipeline

    @model_validator(mode="before")
    @classmethod
    def ensure_default_value_for_cursor(cls, data: Any) -> Document:
        # Mongo requires a `cursor` document on aggregate; supply a default
        # batch size when the client omits it.
        if isinstance(data, dict) and "cursor" not in data:
            return assoc(data, "cursor", {"batchSize": 25})
        return data
85
+
86
+
87
+ class GetMoreCommand(CommandBase):
88
+ # Note: No `collection` field. See `QueryContinuation` for inter-API-request "sessions" are modeled.
89
+ getMore: str # Note: runtime uses a `str` id, not an `int` like mongo's native session cursors.
90
+ batchSize: Optional[PositiveInt] = None
91
+
92
+
93
+ class CommandResponse(BaseModel):
94
+ ok: OneOrZero
95
+
96
+
97
+ class CollStatsCommandResponse(CommandResponse):
98
+ ns: str
99
+ size: float
100
+ count: float
101
+ avgObjSize: Optional[float] = None
102
+ storageSize: float
103
+ totalIndexSize: float
104
+ totalSize: float
105
+ scaleFactor: float
106
+
107
+
108
+ class CountCommandResponse(CommandResponse):
109
+ n: NonNegativeInt
110
+
111
+
112
+ class CommandResponseCursor(BaseModel):
113
+ # Note: No `ns` field, `id` is a `str`, and `partialResultsReturned` aliased to `queriedShardsUnavailable` to be
114
+ # less confusing to Runtime API clients. See `QueryContinuation` for inter-API-request "sessions" are modeled.
115
+ batch: List[Document]
116
+ partialResultsReturned: Optional[bool] = Field(
117
+ None, alias="queriedShardsUnavailable"
118
+ )
119
+ id: Optional[str] = None
120
+
121
+ @field_validator("id", mode="before")
122
+ @classmethod
123
+ def coerce_int_to_str(cls, value: Any) -> Any:
124
+ if isinstance(value, int):
125
+ return str(value)
126
+ else:
127
+ return value
128
+
129
+
130
+ class CursorYieldingCommandResponse(CommandResponse):
131
+ cursor: CommandResponseCursor
132
+
133
+ @classmethod
134
+ def slimmed(cls, cmd_response) -> Optional["CursorYieldingCommandResponse"]:
135
+ """Create a new response object that retains only the `_id` for each cursor batch document."""
136
+ dump: dict = cmd_response.model_dump(exclude_unset=True)
137
+
138
+ # If any dictionary in this batch lacks an `_id` key, log a warning and return `None`.`
139
+ id_list = [pick(["_id"], batch_doc) for batch_doc in dump["cursor"]["batch"]]
140
+ if any("_id" not in doc for doc in id_list):
141
+ logging.warning("Some documents in the batch lack an `_id` field.")
142
+ return None
143
+
144
+ dump = assoc_in(
145
+ dump,
146
+ ["cursor", "batch"],
147
+ id_list,
148
+ )
149
+ return cls(**dump)
150
+
151
+
152
+ class DeleteStatement(BaseModel):
153
+ q: Document
154
+ # `limit` is required: https://www.mongodb.com/docs/manual/reference/command/delete/#std-label-deletes-array-limit
155
+ limit: OneOrZero
156
+ hint: Optional[Dict[str, OneOrMinusOne]] = None
157
+
158
+
159
+ class DeleteCommand(CommandBase):
160
+ delete: str
161
+ deletes: List[DeleteStatement]
162
+
163
+
164
+ class DeleteCommandResponse(CommandResponse):
165
+ ok: OneOrZero
166
+ n: NonNegativeInt
167
+ writeErrors: Optional[List[Document]] = None
168
+
169
+
170
+ # Custom types for the `delete_specs` derived from `DeleteStatement`s.
171
+ DeleteSpec = TypedDict("DeleteSpec", {"filter": Document, "limit": OneOrZero})
172
+ DeleteSpecs = List[DeleteSpec]
173
+
174
+
175
+ # If `multi==True` all documents that meet the query criteria will be updated.
176
+ # Else only a single document that meets the query criteria will be updated.
177
+ class UpdateStatement(BaseModel):
178
+ q: Document
179
+ u: Document
180
+ upsert: bool = False
181
+ multi: bool = False
182
+ hint: Optional[Dict[str, OneOrMinusOne]] = None
183
+
184
+
185
+ # Custom types for the `update_specs` derived from `UpdateStatement`s.
186
+ UpdateSpec = TypedDict("UpdateSpec", {"filter": Document, "limit": OneOrZero})
187
+ UpdateSpecs = List[UpdateSpec]
188
+
189
+
190
+ class UpdateCommand(CommandBase):
191
+ update: str
192
+ updates: List[UpdateStatement]
193
+
194
+
195
+ class DocumentUpserted(BaseModel):
196
+ index: NonNegativeInt
197
+ _id: bson.ObjectId
198
+
199
+
200
+ class UpdateCommandResponse(CommandResponse):
201
+ ok: OneOrZero
202
+ n: NonNegativeInt
203
+ nModified: NonNegativeInt
204
+ upserted: Optional[List[DocumentUpserted]] = None
205
+ writeErrors: Optional[List[Document]] = None
206
+
207
+
208
+ QueryCmd = Union[FindCommand, AggregateCommand]
209
+
210
+ CursorYieldingCommand = Union[
211
+ QueryCmd,
212
+ GetMoreCommand,
213
+ ]
214
+
215
+
216
+ Cmd = Union[
217
+ CursorYieldingCommand,
218
+ CollStatsCommand,
219
+ CountCommand,
220
+ DeleteCommand,
221
+ UpdateCommand,
222
+ ]
223
+
224
+ CommandResponseOptions = Union[
225
+ CursorYieldingCommandResponse,
226
+ CollStatsCommandResponse,
227
+ CountCommandResponse,
228
+ DeleteCommandResponse,
229
+ UpdateCommandResponse,
230
+ ]
231
+
232
+
233
def command_response_for(type_: type) -> Optional[type]:
    """Return the response model class for a given command class.

    Args:
        type_: A command class (one of the `Cmd` union members).

    Returns:
        The matching `CommandResponse` subclass, or `None` when `type_` is
        not a recognized command class.
    """
    # `typing.Union` supports issubclass checks (Python 3.10+), so any
    # cursor-yielding command maps to the shared cursor response model.
    if issubclass(type_, CursorYieldingCommand):
        return CursorYieldingCommandResponse

    responses_by_command = {
        CollStatsCommand: CollStatsCommandResponse,
        CountCommand: CountCommandResponse,
        DeleteCommand: DeleteCommandResponse,
        UpdateCommand: UpdateCommandResponse,
    }
    return responses_by_command.get(type_)