wellapi 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,535 @@
1
+ import http.client
2
+ import inspect
3
+ import warnings
4
+ from collections.abc import Sequence
5
+ from typing import Any, Literal, cast
6
+
7
+ from pydantic import BaseModel
8
+ from pydantic._internal._utils import lenient_issubclass
9
+ from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
10
+
11
+ from wellapi.applications import Lambda
12
+ from wellapi.datastructures import DefaultPlaceholder
13
+ from wellapi.dependencies.models import Dependant, ModelField
14
+ from wellapi.dependencies.utils import (
15
+ _get_flat_fields_from_params,
16
+ get_flat_dependant,
17
+ get_flat_params,
18
+ )
19
+ from wellapi.models import ResponseAPIGateway
20
+ from wellapi.openapi.models import (
21
+ METHODS_WITH_BODY,
22
+ REF_PREFIX,
23
+ REF_TEMPLATE,
24
+ OpenAPI,
25
+ ParameterInType,
26
+ RequestValidators,
27
+ )
28
+ from wellapi.params import Body, ParamTypes
29
+ from wellapi.routing import is_body_allowed_for_status_code
30
+
31
# JSON schema published under the "ValidationError" component name.
# "loc" is an array whose entries may be strings or integers.
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# JSON schema published under the "HTTPValidationError" component name: an
# object with a "detail" array whose items reference ValidationError.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        },
    },
}

# Fallback descriptions for status-code range keys ("1XX".."5XX") and the
# default response, used when an additional response has no description.
status_code_ranges: dict[str, str] = dict(
    (
        ("1XX", "Information"),
        ("2XX", "Success"),
        ("3XX", "Redirection"),
        ("4XX", "Client Error"),
        ("5XX", "Server Error"),
        ("DEFAULT", "Default Response"),
    )
)

# Request validators declared at the document level under
# "x-amazon-apigateway-request-validators"; each operation selects one of
# them by name via "x-amazon-apigateway-request-validator".
# Table columns: (validator, validate body, validate parameters).
request_validators = {
    validator: {
        "validateRequestBody": validates_body,
        "validateRequestParameters": validates_params,
    }
    for validator, validates_body, validates_params in (
        (RequestValidators.basic, True, True),
        (RequestValidators.paramsOnly, False, True),
        (RequestValidators.bodyOnly, True, False),
    )
}
81
+
82
+
83
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.1",
    description: str | None = None,
    lambdas: Sequence[Lambda],
    tags: list[dict[str, Any]] | None = None,
    servers: list[dict[str, str | Any]] | None = None,
    terms_of_service: str | None = None,
    contact: dict[str, str | Any] | None = None,
    license_info: dict[str, str | Any] | None = None,
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    """Build an OpenAPI document, with API Gateway extensions, for ``lambdas``.

    Only lambdas whose ``type_`` is ``"endpoint"`` contribute paths. The
    document also declares the API Gateway request validators
    (``x-amazon-apigateway-request-validators``).

    Args:
        title: ``info.title`` of the document.
        version: ``info.version`` of the document.
        openapi_version: Value of the top-level ``openapi`` field.
        description: Optional ``info.description``; omitted when falsy.
        lambdas: Lambdas to document. They are materialised into a list once,
            so any iterable is accepted.
        tags: Optional top-level ``tags``; omitted when falsy.
        servers: Optional top-level ``servers``; omitted when falsy.
        terms_of_service: Optional ``info.termsOfService``; omitted when falsy.
        contact: Optional ``info.contact``; omitted when falsy.
        license_info: Optional ``info.license``; omitted when falsy.
        separate_input_output_schemas: When False, every schema is generated
            in ``"validation"`` mode.

    Returns:
        The document validated through the ``OpenAPI`` model and dumped as a
        JSON-compatible dict (by alias, ``None`` values excluded).
    """
    info: dict[str, Any] = {"title": title, "version": version}
    optional_info = {
        "description": description,
        "termsOfService": terms_of_service,
        "contact": contact,
        "license": license_info,
    }
    info.update({key: value for key, value in optional_info.items() if value})

    output: dict[str, Any] = {
        "openapi": openapi_version,
        "info": info,
        "x-amazon-apigateway-request-validators": request_validators,
    }
    if servers:
        output["servers"] = servers

    # Materialise once: the lambdas are walked twice (schema generation, then
    # path generation). Iterating the argument twice would silently yield no
    # paths for a one-shot iterator.
    routes = list(lambdas or [])
    field_mapping, definitions = get_definitions(
        fields=get_fields_from_routes(routes),
        separate_input_output_schemas=separate_input_output_schemas,
    )

    components: dict[str, dict[str, Any]] = {}
    paths: dict[str, dict[str, Any]] = {}
    # Shared across routes so duplicate operation IDs can be detected.
    operation_ids: set[str] = set()
    for route in routes:
        if route.type_ != "endpoint":
            continue

        result = get_openapi_path(
            route=route,
            operation_ids=operation_ids,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        if not result:
            continue
        path, security_schemes, path_definitions = result
        if path:
            paths.setdefault(route.path_format, {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)

    if definitions:
        # Sorted for a stable, diff-friendly document.
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return OpenAPI(**output).model_dump(by_alias=True, exclude_none=True, mode="json")
151
+
152
+
153
def get_fields_from_routes(
    routes: Sequence[Lambda],
) -> list[ModelField]:
    """Collect every model field that needs a JSON schema.

    Only routes whose ``type_`` is ``"endpoint"`` are considered. The result
    lists all request-body fields first, then all response fields (the main
    ``response_field`` followed by ``response_fields`` values, per route),
    then all flat request parameters.
    """
    bodies: list[ModelField] = []
    responses: list[ModelField] = []
    request_params: list[ModelField] = []
    endpoints = (route for route in routes if route.type_ == "endpoint")
    for endpoint in endpoints:
        body = endpoint.body_field
        if body:
            assert isinstance(body, ModelField), (
                "A request body must be a Pydantic Field"
            )
            bodies.append(body)
        if endpoint.response_field:
            responses.append(endpoint.response_field)
        if endpoint.response_fields:
            responses.extend(endpoint.response_fields.values())
        request_params.extend(get_flat_params(endpoint.dependant))
    return [*bodies, *responses, *request_params]
181
+
182
+
183
def get_definitions(
    *,
    fields: list[ModelField],
    separate_input_output_schemas: bool = True,
) -> tuple[
    dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    dict[str, dict[str, Any]],
]:
    """Generate JSON schemas for ``fields`` in one pydantic pass.

    Each field is submitted with its own ``mode``, or with ``"validation"``
    for all fields when ``separate_input_output_schemas`` is False.

    Returns:
        ``(field_mapping, definitions)`` as produced by
        ``GenerateJsonSchema.generate_definitions``. ``field_mapping`` is keyed
        by ``(field, mode)``; ``definitions`` holds the shared component
        schemas that the mapped schemas may reference.
    """
    force_validation = not separate_input_output_schemas
    schema_inputs = []
    for model_field in fields:
        mode = "validation" if force_validation else model_field.mode
        schema_inputs.append(
            (model_field, mode, model_field._type_adapter.core_schema)
        )
    generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    mapping, schemas = generator.generate_definitions(inputs=schema_inputs)
    return mapping, schemas
201
+
202
+
203
def get_openapi_path(
    *,
    route: Lambda,
    operation_ids: set[str],
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Build the OpenAPI path item for one endpoint lambda.

    Args:
        route: The endpoint to document.
        operation_ids: Operation IDs emitted so far. This route's ID is added
            to it, and a warning is issued on duplicates.
        field_mapping: Schemas from ``get_definitions``, keyed by
            ``(field, mode)``.
        separate_input_output_schemas: When False, every field schema is
            looked up in ``"validation"`` mode.

    Returns:
        A tuple ``(path, security_schemes, definitions)``:

        * ``path`` maps the lower-cased HTTP method to the operation object.
        * ``security_schemes`` holds the schemes declared by the route's
          dependencies.
        * ``definitions`` holds extra component schemas the operation refers
          to (the validation-error models when a 422 response is added).
    """
    path: dict[str, Any] = {}
    security_schemes: dict[str, Any] = {}
    definitions: dict[str, Any] = {}
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: type[ResponseAPIGateway] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type = "application/json"
    operation = get_openapi_operation_metadata(route=route, operation_ids=operation_ids)

    # Security schemes and requirements contributed by the dependency tree.
    flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
    security_definitions, operation_security = get_openapi_security_definitions(
        flat_dependant=flat_dependant
    )
    if operation_security:
        operation.setdefault("security", []).extend(operation_security)
    if security_definitions:
        security_schemes.update(security_definitions)

    parameters: list[dict[str, Any]] = []
    parameters.extend(
        _get_openapi_operation_parameters(
            dependant=route.dependant,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
    )
    if parameters:
        # De-duplicate on (location, name).
        all_parameters = {(param["in"], param["name"]): param for param in parameters}
        required_parameters = {
            (param["in"], param["name"]): param
            for param in parameters
            if param.get("required")
        }
        # Make sure required definitions of the same parameter take precedence
        # over non-required definitions.
        all_parameters.update(required_parameters)
        operation["parameters"] = list(all_parameters.values())
    if route.method in METHODS_WITH_BODY:
        request_body_oai = get_openapi_operation_request_body(
            body_field=route.body_field,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        if request_body_oai:
            operation["requestBody"] = request_body_oai

    # Pick the API Gateway request validator that matches what the operation
    # actually declares (parameters, body, both, or neither -> None).
    has_parameters = "parameters" in operation
    has_body = "requestBody" in operation
    request_validator: RequestValidators | None = None
    if has_parameters and has_body:
        request_validator = RequestValidators.basic
    elif has_parameters:
        request_validator = RequestValidators.paramsOnly
    elif has_body:
        request_validator = RequestValidators.bodyOnly
    operation["x-amazon-apigateway-request-validator"] = request_validator

    # Fallback used when neither the route nor the response class provides a
    # status code. Without it, ``status_code`` was left unbound and an
    # UnboundLocalError was raised below.
    status_code = "200"
    if route.status_code is not None:
        status_code = str(route.status_code)
    else:
        # Use the response class's ``__init__`` default for ``status_code``
        # when it declares an int default.
        response_signature = inspect.signature(current_response_class.__init__)
        status_code_param = response_signature.parameters.get("status_code")
        if status_code_param is not None and isinstance(
            status_code_param.default, int
        ):
            status_code = str(status_code_param.default)

    responses = operation.setdefault("responses", {})
    responses.setdefault(status_code, {})["description"] = route.response_description
    if route_response_media_type and is_body_allowed_for_status_code(route.status_code):
        if route.response_field:
            response_schema = get_schema_from_model_field(
                field=route.response_field,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
        else:
            response_schema = {}
        responses.setdefault(status_code, {}).setdefault("content", {}).setdefault(
            route_response_media_type, {}
        )["schema"] = response_schema

    if route.responses:
        # Tolerate a missing mapping, as get_fields_from_routes does.
        response_fields = route.response_fields or {}
        for additional_status_code, additional_response in route.responses.items():
            process_response = additional_response.copy()
            process_response.pop("model", None)
            status_code_key = str(additional_status_code).upper()
            if status_code_key == "DEFAULT":
                status_code_key = "default"
            openapi_response = responses.setdefault(status_code_key, {})
            assert isinstance(process_response, dict), (
                "An additional response must be a dict"
            )
            field = response_fields.get(additional_status_code)
            if field:
                # Only merge a generated schema when the response declares a
                # model; merging ``None`` would raise inside deep_dict_update.
                additional_field_schema = get_schema_from_model_field(
                    field=field,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                additional_schema = (
                    process_response.setdefault("content", {})
                    .setdefault(route_response_media_type, {})
                    .setdefault("schema", {})
                )
                deep_dict_update(additional_schema, additional_field_schema)
            status_text: str | None = status_code_ranges.get(
                str(additional_status_code).upper()
            ) or http.client.responses.get(int(additional_status_code))
            description = (
                process_response.get("description")
                or openapi_response.get("description")
                or status_text
                or "Additional Response"
            )
            deep_dict_update(openapi_response, process_response)
            openapi_response["description"] = description

    # Add a 422 validation-error response when the route takes any input and
    # the user has not already described 422, 4XX or default responses.
    http422 = "422"
    all_route_params = get_flat_params(route.dependant)
    if (all_route_params or route.body_field) and not any(
        status in responses for status in (http422, "4XX", "default")
    ):
        responses[http422] = {
            "description": "Validation Error",
            "content": {
                "application/json": {
                    "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                }
            },
        }
        if "ValidationError" not in definitions:
            definitions.update(
                {
                    "ValidationError": validation_error_definition,
                    "HTTPValidationError": validation_error_response_definition,
                }
            )
    path[route.method.lower()] = operation
    return path, security_schemes, definitions
357
+
358
+
359
def get_openapi_operation_metadata(
    *, route: Lambda, operation_ids: set[str]
) -> dict[str, Any]:
    """Return the operation-level metadata for ``route``.

    Sets the following keys:

    * ``tags`` and ``description``, when present on the route.
    * ``summary``: the route summary, or else the route name with underscores
      replaced by spaces, title-cased.
    * ``operationId``: ``route.operation_id`` or else ``route.unique_id``.
    * ``deprecated``, when set on the route.
    * ``x-amazon-apigateway-integration``: an ``aws_proxy`` POST integration
      whose URI is an ``Fn::Sub`` expression built from ``route.arn``.

    ``operation_ids`` is updated in place; a warning is emitted when this
    route's ID was already present.
    """
    operation: dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags

    summary = route.summary
    if not summary:
        summary = route.name.replace("_", " ").title()
    operation["summary"] = summary

    if route.description:
        operation["description"] = route.description

    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        warning_text = (
            f"Duplicate Operation ID {operation_id} for function "
            f"{route.endpoint.__name__}"
        )
        source_file = getattr(route.endpoint, "__globals__", {}).get("__file__")
        if source_file:
            warning_text = f"{warning_text} at {source_file}"
        warnings.warn(warning_text, stacklevel=1)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id

    if route.deprecated:
        operation["deprecated"] = route.deprecated

    invocation_uri = f"arn:aws:apigateway:${{AWS::Region}}:lambda:path/2015-03-31/functions/${{{route.arn}Function.Arn}}/invocations"
    operation["x-amazon-apigateway-integration"] = {
        "uri": {"Fn::Sub": invocation_uri},
        "httpMethod": "POST",
        "type": "aws_proxy",
    }
    return operation
390
+
391
+
392
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Split a dependant's security requirements into OpenAPI pieces.

    Returns:
        A tuple ``(definitions, requirements)``:

        * ``definitions`` maps each scheme name to its dumped scheme model
          (by alias, ``None`` values excluded).
        * ``requirements`` is a list of one-entry ``{scheme_name: scopes}``
          dicts, in the order the requirements appear.
    """
    definitions: dict[str, Any] = {}
    requirements: list[dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        dumped = scheme.model.model_dump(by_alias=True, exclude_none=True)
        scheme_name = scheme.scheme_name
        definitions[scheme_name] = dumped
        requirements.append({scheme_name: requirement.scopes})
    return definitions, requirements
407
+
408
+
409
def get_schema_from_model_field(
    *,
    field: ModelField,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    """Look up the pre-generated JSON schema of ``field``.

    The lookup uses the field's own mode, or ``"validation"`` when
    ``separate_input_output_schemas`` is False. ``field_mapping`` must come
    from ``get_definitions`` (``GenerateJsonSchema.generate_definitions``);
    a missing key raises ``KeyError``.
    """
    mode = field.mode if separate_input_output_schemas else "validation"
    return field_mapping[(field, mode)]
424
+
425
+
426
def _get_openapi_operation_parameters(
    *,
    dependant: Dependant,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
    """Build OpenAPI parameter objects for a dependant's non-body inputs.

    Parameters are emitted grouped by location, in the order path, query,
    header, cookie. Fields whose ``field_info.include_in_schema`` is falsy are
    skipped.

    A header parameter with no explicit alias (alias equals name) has its
    underscores turned into hyphens, unless ``convert_underscores`` is falsy.
    When the dependant has exactly one header param and its type is a
    ``BaseModel``, that param's ``convert_underscores`` becomes the default
    for every header.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    grouped_fields = [
        (location, _get_flat_fields_from_params(raw_params))
        for location, raw_params in (
            (ParamTypes.path, flat.path_params),
            (ParamTypes.query, flat.query_params),
            (ParamTypes.header, flat.header_params),
            (ParamTypes.cookie, flat.cookie_params),
        )
    ]

    header_default_convert = True
    raw_headers = flat.header_params
    if len(raw_headers) == 1 and lenient_issubclass(raw_headers[0].type_, BaseModel):
        header_default_convert = getattr(
            raw_headers[0].field_info, "convert_underscores", True
        )

    parameters: list[dict[str, Any]] = []
    for location, fields in grouped_fields:
        for field in fields:
            info = field.field_info
            if not getattr(info, "include_in_schema", True):
                continue
            schema = get_schema_from_model_field(
                field=field,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            convert = getattr(info, "convert_underscores", header_default_convert)
            is_unaliased_header = (
                location == ParamTypes.header and field.alias == field.name
            )
            if is_unaliased_header and convert:
                name = field.name.replace("_", "-")
            else:
                name = field.alias

            parameter: dict[str, Any] = {
                "name": name,
                "in": ParameterInType(location.value),
                "required": field.required,
                "schema": schema,
            }
            if info.description:
                parameter["description"] = info.description
            if getattr(info, "deprecated", None):
                parameter["deprecated"] = True
            parameters.append(parameter)
    return parameters
490
+
491
+
492
def get_openapi_operation_request_body(
    *,
    body_field: ModelField | None,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any] | None:
    """Build the OpenAPI ``requestBody`` object for ``body_field``.

    Returns ``None`` when there is no body field. Otherwise returns a dict
    with ``content`` keyed by the field's ``media_type``. The ``required``
    key is included only when the field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    media_type = cast(Body, body_field.field_info).media_type
    is_required = body_field.required

    request_body: dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required
    request_body["content"] = {media_type: {"schema": schema}}
    return request_body
518
+
519
+
520
def deep_dict_update(main_dict: dict[Any, Any], update_dict: dict[Any, Any]) -> None:
    """Recursively merge ``update_dict`` into ``main_dict`` in place.

    For each key of ``update_dict``:

    * If both values are dicts, they are merged recursively.
    * If both values are lists, the key is rebound to a new list
      ``old + new``; the original list object is not mutated.
    * In every other case, including a key absent from ``main_dict``, the
      value from ``update_dict`` replaces the existing one.
    """
    for key, incoming in update_dict.items():
        if key not in main_dict:
            main_dict[key] = incoming
            continue
        existing = main_dict[key]
        if isinstance(existing, dict) and isinstance(incoming, dict):
            deep_dict_update(existing, incoming)
        elif isinstance(existing, list) and isinstance(incoming, list):
            main_dict[key] = existing + incoming
        else:
            main_dict[key] = incoming