wellapi 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wellapi/__init__.py +5 -0
- wellapi/__main__.py +3 -0
- wellapi/applications.py +389 -0
- wellapi/awsmodel.py +17 -0
- wellapi/build/__init__.py +0 -0
- wellapi/build/cdk.py +141 -0
- wellapi/build/packager.py +82 -0
- wellapi/build/sam_openapi.py +10 -0
- wellapi/cli/__init__.py +0 -0
- wellapi/cli/main.py +67 -0
- wellapi/convertors.py +89 -0
- wellapi/datastructures.py +383 -0
- wellapi/dependencies/__init__.py +0 -0
- wellapi/dependencies/models.py +138 -0
- wellapi/dependencies/utils.py +923 -0
- wellapi/exceptions.py +53 -0
- wellapi/local/__init__.py +0 -0
- wellapi/local/reloader.py +94 -0
- wellapi/local/router.py +116 -0
- wellapi/local/server.py +154 -0
- wellapi/middleware/__init__.py +0 -0
- wellapi/middleware/base.py +18 -0
- wellapi/middleware/error.py +239 -0
- wellapi/middleware/exceptions.py +74 -0
- wellapi/middleware/main.py +26 -0
- wellapi/models.py +150 -0
- wellapi/openapi/__init__.py +0 -0
- wellapi/openapi/docs.py +344 -0
- wellapi/openapi/models.py +404 -0
- wellapi/openapi/utils.py +535 -0
- wellapi/params.py +481 -0
- wellapi/routing.py +248 -0
- wellapi/security.py +82 -0
- wellapi/utils.py +37 -0
- wellapi-0.2.1.dist-info/METADATA +32 -0
- wellapi-0.2.1.dist-info/RECORD +38 -0
- wellapi-0.2.1.dist-info/WHEEL +4 -0
- wellapi-0.2.1.dist-info/entry_points.txt +2 -0
wellapi/openapi/utils.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
import http.client
|
|
2
|
+
import inspect
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import Any, Literal, cast
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from pydantic._internal._utils import lenient_issubclass
|
|
9
|
+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
|
|
10
|
+
|
|
11
|
+
from wellapi.applications import Lambda
|
|
12
|
+
from wellapi.datastructures import DefaultPlaceholder
|
|
13
|
+
from wellapi.dependencies.models import Dependant, ModelField
|
|
14
|
+
from wellapi.dependencies.utils import (
|
|
15
|
+
_get_flat_fields_from_params,
|
|
16
|
+
get_flat_dependant,
|
|
17
|
+
get_flat_params,
|
|
18
|
+
)
|
|
19
|
+
from wellapi.models import ResponseAPIGateway
|
|
20
|
+
from wellapi.openapi.models import (
|
|
21
|
+
METHODS_WITH_BODY,
|
|
22
|
+
REF_PREFIX,
|
|
23
|
+
REF_TEMPLATE,
|
|
24
|
+
OpenAPI,
|
|
25
|
+
ParameterInType,
|
|
26
|
+
RequestValidators,
|
|
27
|
+
)
|
|
28
|
+
from wellapi.params import Body, ParamTypes
|
|
29
|
+
from wellapi.routing import is_body_allowed_for_status_code
|
|
30
|
+
|
|
31
|
+
# JSON schema of a single request-validation error item: the location of the
# error (``loc``), a human-readable message (``msg``) and an error ``type``.
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            # Each location segment may be either a string or an integer.
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}
|
|
45
|
+
|
|
46
|
+
# JSON schema of the validation-error response body: an object whose
# ``detail`` array holds items referencing the ``ValidationError`` component.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": f"{REF_PREFIX}ValidationError"},
        },
    },
}
|
|
57
|
+
|
|
58
|
+
# Fallback descriptions for wildcard response keys ("1XX" .. "5XX") and the
# "DEFAULT" key; looked up with the upper-cased status key.
status_code_ranges: dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}
|
|
66
|
+
|
|
67
|
+
# API Gateway request-validator definitions, keyed by RequestValidators member;
# emitted under "x-amazon-apigateway-request-validators" in the spec.
# Each row is (validator, validate body?, validate parameters?).
request_validators = {
    validator: {
        "validateRequestBody": validates_body,
        "validateRequestParameters": validates_params,
    }
    for validator, validates_body, validates_params in (
        (RequestValidators.basic, True, True),
        (RequestValidators.paramsOnly, False, True),
        (RequestValidators.bodyOnly, True, False),
    )
}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.1",
    description: str | None = None,
    lambdas: Sequence[Lambda],
    tags: list[dict[str, Any]] | None = None,
    servers: list[dict[str, str | Any]] | None = None,
    terms_of_service: str | None = None,
    contact: dict[str, str | Any] | None = None,
    license_info: dict[str, str | Any] | None = None,
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    """Build the OpenAPI document (with API Gateway extensions) for ``lambdas``.

    Only lambdas whose ``type_`` is ``"endpoint"`` contribute paths. The
    document is validated through the ``OpenAPI`` model and returned as a
    JSON-compatible dict (aliases applied, ``None`` values dropped).

    Args:
        title: ``info.title`` of the document.
        version: ``info.version`` of the document.
        openapi_version: value of the top-level ``openapi`` field.
        description: optional ``info.description``.
        lambdas: the routes to document; ``None`` is treated as empty.
        tags: optional top-level ``tags`` list.
        servers: optional top-level ``servers`` list.
        terms_of_service: optional ``info.termsOfService``.
        contact: optional ``info.contact`` object.
        license_info: optional ``info.license`` object.
        separate_input_output_schemas: when False, every schema is generated
            in validation mode.

    Returns:
        The serialized OpenAPI document.
    """
    info: dict[str, Any] = {"title": title, "version": version}
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info

    output: dict[str, Any] = {
        "openapi": openapi_version,
        "info": info,
        "x-amazon-apigateway-request-validators": request_validators,
    }
    if servers:
        output["servers"] = servers

    # Materialize once: the routes are walked twice (schema collection, then
    # path generation), which would silently skip every path if a one-shot
    # iterable were passed in.
    routes = list(lambdas or [])

    components: dict[str, dict[str, Any]] = {}
    paths: dict[str, dict[str, Any]] = {}
    # Shared across routes so duplicate operation IDs can be detected.
    operation_ids: set[str] = set()

    # Generate all model schemas in a single pass so shared models get one
    # definition.
    all_fields = get_fields_from_routes(routes)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        separate_input_output_schemas=separate_input_output_schemas,
    )

    for lambda_ in routes:
        if lambda_.type_ != "endpoint":
            continue
        path, security_schemes, path_definitions = get_openapi_path(
            route=lambda_,
            operation_ids=operation_ids,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        if path:
            paths.setdefault(lambda_.path_format, {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)

    if definitions:
        # Sorted for a stable, diff-friendly output.
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return OpenAPI(**output).model_dump(by_alias=True, exclude_none=True, mode="json")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_fields_from_routes(
    routes: Sequence[Lambda],
) -> list[ModelField]:
    """Collect the model fields of every endpoint route for schema generation.

    Non-endpoint routes are skipped. The result lists all request-body
    fields first, then all response fields (``response_field`` followed by
    the values of ``response_fields``), then all flattened dependency
    parameters.
    """
    bodies: list[ModelField] = []
    responses: list[ModelField] = []
    request_params: list[ModelField] = []

    endpoints = (route for route in routes if route.type_ == "endpoint")
    for endpoint in endpoints:
        body = endpoint.body_field
        if body:
            assert isinstance(body, ModelField), (
                "A request body must be a Pydantic Field"
            )
            bodies.append(body)
        if endpoint.response_field:
            responses.append(endpoint.response_field)
        if endpoint.response_fields:
            responses.extend(endpoint.response_fields.values())
        request_params.extend(get_flat_params(endpoint.dependant))

    return [*bodies, *responses, *request_params]
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def get_definitions(
    *,
    fields: list[ModelField],
    separate_input_output_schemas: bool = True,
) -> tuple[
    dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    dict[str, dict[str, Any]],
]:
    """Generate JSON schemas for ``fields`` in one ``GenerateJsonSchema`` pass.

    When ``separate_input_output_schemas`` is False every field is rendered in
    ``"validation"`` mode; otherwise each field uses its own ``mode``.
    References use ``REF_TEMPLATE``.

    Returns:
        ``(field_mapping, definitions)`` as produced by
        ``GenerateJsonSchema.generate_definitions``.
    """
    forced_mode: Literal["validation"] | None = (
        "validation" if not separate_input_output_schemas else None
    )
    schema_inputs = [
        (model_field, forced_mode or model_field.mode, model_field._type_adapter.core_schema)
        for model_field in fields
    ]
    generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    mapping, defs = generator.generate_definitions(inputs=schema_inputs)
    return mapping, defs
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _select_request_validator(operation: dict[str, Any]) -> RequestValidators | None:
    """Pick the API Gateway request validator matching what ``operation`` declares.

    ``basic`` when both parameters and a request body are present,
    ``paramsOnly`` / ``bodyOnly`` when only one of them is, else ``None``.
    """
    has_params = "parameters" in operation
    has_body = "requestBody" in operation
    if has_params and has_body:
        return RequestValidators.basic
    if has_params:
        return RequestValidators.paramsOnly
    if has_body:
        return RequestValidators.bodyOnly
    return None


def _resolve_status_code(
    route: Lambda, response_class: type[ResponseAPIGateway]
) -> str:
    """Return the main response status code of ``route`` as a string.

    Uses ``route.status_code`` when set, otherwise the int default of the
    ``status_code`` parameter of ``response_class.__init__``.
    """
    if route.status_code is not None:
        return str(route.status_code)
    status_code_param = inspect.signature(response_class.__init__).parameters.get(
        "status_code"
    )
    if status_code_param is not None and isinstance(status_code_param.default, int):
        return str(status_code_param.default)
    # Neither the route nor the response class provides an int status code.
    # Falling back to "200" instead of leaving the code undefined (which used
    # to raise UnboundLocalError). NOTE(review): confirm this default.
    return "200"


def get_openapi_path(
    *,
    route: Lambda,
    operation_ids: set[str],
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Build the OpenAPI path item for a single endpoint ``route``.

    Args:
        route: the endpoint to document.
        operation_ids: operation IDs seen so far; updated in place.
        field_mapping: schemas produced by ``get_definitions``.
        separate_input_output_schemas: forwarded to schema lookups.

    Returns:
        ``(path, security_schemes, definitions)``:

        - ``path`` maps the lower-cased HTTP method to the operation object.
        - ``security_schemes`` holds the security scheme definitions used.
        - ``definitions`` holds the extra component schemas this route needs
          (the validation-error schemas).
    """
    path: dict[str, Any] = {}
    security_schemes: dict[str, Any] = {}
    definitions: dict[str, Any] = {}

    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: type[ResponseAPIGateway] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type = "application/json"

    operation = get_openapi_operation_metadata(route=route, operation_ids=operation_ids)

    # Security requirements come from the flattened dependency tree.
    flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
    security_definitions, operation_security = get_openapi_security_definitions(
        flat_dependant=flat_dependant
    )
    if operation_security:
        operation.setdefault("security", []).extend(operation_security)
    if security_definitions:
        security_schemes.update(security_definitions)

    parameters = _get_openapi_operation_parameters(
        dependant=route.dependant,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    if parameters:
        all_parameters = {(param["in"], param["name"]): param for param in parameters}
        required_parameters = {
            (param["in"], param["name"]): param
            for param in parameters
            if param.get("required")
        }
        # Make sure required definitions of the same parameter take precedence
        # over non-required definitions
        all_parameters.update(required_parameters)
        operation["parameters"] = list(all_parameters.values())

    if route.method in METHODS_WITH_BODY:
        request_body_oai = get_openapi_operation_request_body(
            body_field=route.body_field,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        if request_body_oai:
            operation["requestBody"] = request_body_oai

    operation["x-amazon-apigateway-request-validator"] = _select_request_validator(
        operation
    )

    # Main response.
    status_code = _resolve_status_code(route, current_response_class)
    operation.setdefault("responses", {}).setdefault(status_code, {})["description"] = (
        route.response_description
    )
    if route_response_media_type and is_body_allowed_for_status_code(route.status_code):
        if route.response_field:
            response_schema = get_schema_from_model_field(
                field=route.response_field,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
        else:
            response_schema = {}
        operation.setdefault("responses", {}).setdefault(status_code, {}).setdefault(
            "content", {}
        ).setdefault(route_response_media_type, {})["schema"] = response_schema

    # Additional responses declared on the route.
    if route.responses:
        operation_responses = operation.setdefault("responses", {})
        response_fields = route.response_fields or {}
        for additional_status_code, additional_response in route.responses.items():
            assert isinstance(additional_response, dict), (
                "An additional response must be a dict"
            )
            # Top-level copy so popping "model" does not modify route.responses.
            process_response = additional_response.copy()
            process_response.pop("model", None)
            status_code_key = str(additional_status_code).upper()
            if status_code_key == "DEFAULT":
                status_code_key = "default"
            openapi_response = operation_responses.setdefault(status_code_key, {})

            field = response_fields.get(additional_status_code)
            if field:
                # Only merge a model schema when the response declares a model;
                # deep_dict_update cannot merge a missing (None) schema.
                additional_field_schema = get_schema_from_model_field(
                    field=field,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                additional_schema = (
                    process_response.setdefault("content", {})
                    .setdefault(route_response_media_type, {})
                    .setdefault("schema", {})
                )
                deep_dict_update(additional_schema, additional_field_schema)

            status_text: str | None = status_code_ranges.get(
                str(additional_status_code).upper()
            ) or http.client.responses.get(int(additional_status_code))
            description = (
                process_response.get("description")
                or openapi_response.get("description")
                or status_text
                or "Additional Response"
            )
            deep_dict_update(openapi_response, process_response)
            openapi_response["description"] = description

    # Document the validation-error response unless the route already covers
    # 422 / 4XX / default itself.
    http422 = "422"
    all_route_params = get_flat_params(route.dependant)
    if (all_route_params or route.body_field) and not any(
        status in operation["responses"] for status in (http422, "4XX", "default")
    ):
        operation["responses"][http422] = {
            "description": "Validation Error",
            "content": {
                "application/json": {
                    "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                }
            },
        }
        if "ValidationError" not in definitions:
            definitions.update(
                {
                    "ValidationError": validation_error_definition,
                    "HTTPValidationError": validation_error_response_definition,
                }
            )

    path[route.method.lower()] = operation
    return path, security_schemes, definitions
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def get_openapi_operation_metadata(
    *, route: Lambda, operation_ids: set[str]
) -> dict[str, Any]:
    """Build the descriptive part of an operation plus its Lambda integration.

    Sets tags, summary (derived from ``route.name`` when no summary is given),
    description, ``operationId`` and ``deprecated``. It then adds an
    ``x-amazon-apigateway-integration`` of type ``aws_proxy`` that points at
    the ``<route.arn>Function`` resource.

    A duplicate operation ID only triggers a warning. ``operation_ids`` is
    updated in place.
    """
    operation: dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = route.summary or route.name.replace("_", " ").title()
    if route.description:
        operation["description"] = route.description

    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        endpoint = route.endpoint
        message = f"Duplicate Operation ID {operation_id} for function {endpoint.__name__}"
        source_file = getattr(endpoint, "__globals__", {}).get("__file__")
        if source_file:
            message = f"{message} at {source_file}"
        warnings.warn(message, stacklevel=1)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id

    if route.deprecated:
        operation["deprecated"] = route.deprecated

    function_arn = f"${{{route.arn}Function.Arn}}"
    integration_uri = (
        f"arn:aws:apigateway:${{AWS::Region}}:lambda:path/2015-03-31/functions/"
        f"{function_arn}/invocations"
    )
    operation["x-amazon-apigateway-integration"] = {
        "uri": {"Fn::Sub": integration_uri},
        "httpMethod": "POST",
        "type": "aws_proxy",
    }
    return operation
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Translate the security requirements of ``flat_dependant`` into OpenAPI.

    Returns:
        ``(definitions, requirements)``:

        - ``definitions`` maps each scheme name to its serialized model
          (aliases applied, ``None`` values dropped).
        - ``requirements`` holds one ``{scheme_name: scopes}`` entry per
          requirement, in declaration order.
    """
    definitions: dict[str, Any] = {}
    requirements: list[dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        serialized = scheme.model.model_dump(by_alias=True, exclude_none=True)
        name = scheme.scheme_name
        definitions[name] = serialized
        requirements.append({name: requirement.scopes})
    return definitions, requirements
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def get_schema_from_model_field(
    *,
    field: ModelField,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    """Return the JSON schema already generated for ``field``.

    The schema is looked up in ``field_mapping`` (built by ``get_definitions``)
    under the field's own mode, or under ``"validation"`` when input and
    output schemas are not separated. A missing entry raises ``KeyError``.
    """
    mode = field.mode if separate_input_output_schemas else "validation"
    return field_mapping[(field, mode)]
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _get_openapi_operation_parameters(
    *,
    dependant: Dependant,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
    """Build OpenAPI parameter objects for the path/query/header/cookie params.

    Parameters are emitted in the order path, query, header, cookie. Params
    whose field info sets ``include_in_schema`` to a falsy value are skipped.

    A header param is published with underscores replaced by dashes when its
    alias equals its name and ``convert_underscores`` is truthy. When the only
    header param is a Pydantic model, that param's ``convert_underscores``
    becomes the default for all header params.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    grouped = [
        (ParamTypes.path, _get_flat_fields_from_params(flat.path_params)),
        (ParamTypes.query, _get_flat_fields_from_params(flat.query_params)),
        (ParamTypes.header, _get_flat_fields_from_params(flat.header_params)),
        (ParamTypes.cookie, _get_flat_fields_from_params(flat.cookie_params)),
    ]

    convert_default = True
    if len(flat.header_params) == 1:
        only_header = flat.header_params[0]
        if lenient_issubclass(only_header.type_, BaseModel):
            convert_default = getattr(
                only_header.field_info, "convert_underscores", True
            )

    parameters: list[dict[str, Any]] = []
    for location, fields in grouped:
        for param in fields:
            info = param.field_info
            if not getattr(info, "include_in_schema", True):
                continue
            schema = get_schema_from_model_field(
                field=param,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            is_unaliased_header = (
                location == ParamTypes.header and param.alias == param.name
            )
            if is_unaliased_header and getattr(
                info, "convert_underscores", convert_default
            ):
                public_name = param.name.replace("_", "-")
            else:
                public_name = param.alias

            entry = {
                "name": public_name,
                "in": ParameterInType(location.value),
                "required": param.required,
                "schema": schema,
            }
            if info.description:
                entry["description"] = info.description
            if getattr(info, "deprecated", None):
                entry["deprecated"] = True
            parameters.append(entry)
    return parameters
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def get_openapi_operation_request_body(
    *,
    body_field: ModelField | None,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any] | None:
    """Build the OpenAPI ``requestBody`` object for ``body_field``.

    Returns ``None`` when there is no body field. Otherwise returns a dict
    with ``content`` keyed by the body's ``media_type``. ``required`` is
    included only when the field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    media_type = cast(Body, body_field.field_info).media_type
    is_required = body_field.required

    request_body: dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required
    request_body["content"] = {media_type: {"schema": schema}}
    return request_body
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def deep_dict_update(main_dict: dict[Any, Any], update_dict: dict[Any, Any]) -> None:
|
|
521
|
+
for key, value in update_dict.items():
|
|
522
|
+
if (
|
|
523
|
+
key in main_dict
|
|
524
|
+
and isinstance(main_dict[key], dict)
|
|
525
|
+
and isinstance(value, dict)
|
|
526
|
+
):
|
|
527
|
+
deep_dict_update(main_dict[key], value)
|
|
528
|
+
elif (
|
|
529
|
+
key in main_dict
|
|
530
|
+
and isinstance(main_dict[key], list)
|
|
531
|
+
and isinstance(update_dict[key], list)
|
|
532
|
+
):
|
|
533
|
+
main_dict[key] = main_dict[key] + update_dict[key]
|
|
534
|
+
else:
|
|
535
|
+
main_dict[key] = value
|