starmallow 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
starmallow/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.3.1"
1
+ __version__ = "0.3.3"
2
2
 
3
3
  from .applications import StarMallow
4
4
  from .exceptions import RequestValidationError
@@ -135,10 +135,13 @@ class StarMallow(Starlette):
135
135
  ] = {}
136
136
 
137
137
  for key, value in self.exception_handlers.items():
138
+ # Ensure we handle any middleware exceptions using the Exception handler
139
+ # But also ensure we don't fall through all middlewares if the route itself threw an Exception
140
+ # As this would result in an incredibly long stacktrace
138
141
  if key in (500, Exception):
139
142
  error_handler = value
140
- else:
141
- exception_handlers[key] = value
143
+
144
+ exception_handlers[key] = value
142
145
 
143
146
  middleware = (
144
147
  [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
starmallow/endpoint.py CHANGED
@@ -45,6 +45,8 @@ from starmallow.utils import (
45
45
  get_args,
46
46
  get_model_field,
47
47
  get_path_param_names,
48
+ get_typed_return_annotation,
49
+ get_typed_signature,
48
50
  is_marshmallow_dataclass,
49
51
  is_marshmallow_field,
50
52
  is_marshmallow_schema,
@@ -255,7 +257,7 @@ class EndpointMixin:
255
257
  try:
256
258
  return get_model_field(model, **kwargs)
257
259
  except Exception as e:
258
- raise Exception(f'Unknown model type for parameter {parameter_name}, model is {model}') from e
260
+ raise Exception(f'Unknown model type for parameter {parameter_name}, model is {model}, {type(model)}') from e
259
261
 
260
262
  def get_resolved_param(self, resolved_param: ResolvedParam, annotation: Any, path: str) -> ResolvedParam:
261
263
  # Supports `field = ResolvedParam(resolver_callable)
@@ -274,7 +276,7 @@ class EndpointMixin:
274
276
  ) -> Dict[ParamType, Dict[str, Param]]:
275
277
  path_param_names = get_path_param_names(path)
276
278
  params = {param_type: {} for param_type in ParamType}
277
- for name, parameter in inspect.signature(func).parameters.items():
279
+ for name, parameter in get_typed_signature(func).parameters.items():
278
280
  default_value = parameter.default
279
281
 
280
282
  # The type annotation. i.e.: 'str' in these `value: str`. Or `value: [str, Query(gt=3)]`
@@ -316,15 +318,18 @@ class EndpointMixin:
316
318
  or isinstance(starmallow_param, NoParam)
317
319
  ):
318
320
  continue
321
+ elif isinstance(starmallow_param, Security):
322
+ security_param: Security = self.get_resolved_param(starmallow_param, type_annotation, path=path)
323
+ params[ParamType.security][name] = security_param
324
+ continue
319
325
  elif isinstance(starmallow_param, ResolvedParam):
320
326
  resolved_param: ResolvedParam = self.get_resolved_param(starmallow_param, type_annotation, path=path)
321
- params[ParamType.resolved][name] = resolved_param
322
327
 
323
- if isinstance(starmallow_param, Security):
324
- params[ParamType.security][name] = resolved_param
325
- # Allow `ResolvedParam(HTTPBearer())`
326
- elif isinstance(resolved_param.resolver, SecurityBaseResolver):
328
+ # Allow `ResolvedParam(HTTPBearer())` - treat as security param
329
+ if isinstance(resolved_param.resolver, SecurityBaseResolver):
327
330
  params[ParamType.security][name] = resolved_param
331
+ else:
332
+ params[ParamType.resolved][name] = resolved_param
328
333
 
329
334
  continue
330
335
  elif lenient_issubclass(
@@ -394,7 +399,7 @@ class EndpointMixin:
394
399
  ) -> EndpointModel:
395
400
  params = self._get_params(endpoint, path)
396
401
 
397
- response_model = create_response_model(response_model or inspect.signature(endpoint).return_annotation)
402
+ response_model = create_response_model(response_model or get_typed_return_annotation(endpoint))
398
403
 
399
404
  return EndpointModel(
400
405
  path=path,
starmallow/params.py CHANGED
@@ -215,10 +215,12 @@ class NoParam:
215
215
 
216
216
 
217
217
  class ResolvedParam:
218
- def __init__(self, resolver: Callable[[Any], Any] = None):
218
+ def __init__(self, resolver: Callable[[Any], Any] = None, use_cache: bool = True):
219
219
  self.resolver = resolver
220
220
  # Set when we resolve the routes in the EndpointMixin
221
221
  self.resolver_params: Dict[ParamType, Dict[str, Param]] = {}
222
+ self.use_cache = use_cache
223
+ self.cache_key = (self.resolver, None)
222
224
 
223
225
 
224
226
  class Security(ResolvedParam):
@@ -227,9 +229,12 @@ class Security(ResolvedParam):
227
229
  self,
228
230
  resolver: SecurityBaseResolver = None,
229
231
  scopes: Optional[Sequence[str]] = None,
232
+ use_cache: bool = True,
230
233
  ):
231
234
  # Not calling super so that the resolver typehinting actually works in VSCode
232
235
  self.resolver = resolver
233
236
  # Set when we resolve the routes in the EndpointMixin
234
237
  self.resolver_params: Dict[ParamType, Dict[str, Param]] = {}
235
238
  self.scopes = scopes or []
239
+ self.use_cache = use_cache
240
+ self.cache_key = (self.resolver, tuple(sorted(set(self.scopes or []))))
@@ -0,0 +1,292 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ from contextlib import AsyncExitStack
5
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
6
+
7
+ import marshmallow as ma
8
+ import marshmallow.fields as mf
9
+ from marshmallow.error_store import ErrorStore
10
+ from marshmallow.utils import missing as missing_
11
+ from starlette.background import BackgroundTasks
12
+ from starlette.datastructures import FormData, Headers, QueryParams
13
+ from starlette.exceptions import HTTPException
14
+ from starlette.requests import HTTPConnection, Request
15
+ from starlette.responses import Response
16
+ from starlette.websockets import WebSocket
17
+
18
+ from starmallow.params import Param, ParamType, ResolvedParam
19
+ from starmallow.utils import (
20
+ is_async_gen_callable,
21
+ is_gen_callable,
22
+ lenient_issubclass,
23
+ solve_generator,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
async def get_body(
    request: Request,
    form_params: Dict[str, Param],
    body_params: Dict[str, Param],
) -> Union[FormData, bytes, Dict[str, Any]]:
    """Read and decode the request payload for the declared form/body params.

    Nothing is read from the request unless at least one form or body
    parameter is declared. Form parameters take precedence: when any are
    declared the payload is parsed with ``request.form()`` and closing the
    form is registered on the ``starmallow_astack`` exit stack stored in the
    request scope. Otherwise the raw bytes are read and decoded as JSON when
    no ``Content-Type`` header is present, or when its media type is
    ``application/json`` or ``application/<anything>+json``. Any other media
    type yields the raw bytes.

    Args:
        request: The incoming request.
        form_params: Declared form parameters, keyed by argument name.
        body_params: Declared body parameters, keyed by argument name.

    Returns:
        The parsed form, the decoded JSON value, the raw bytes, or ``None``
        when no payload was expected or the payload was empty.

    Raises:
        HTTPException: With status 400 if anything fails while reading or
            parsing the payload.
    """
    try:
        if form_params:
            form = await request.form()
            stack = request.scope.get("starmallow_astack")
            assert isinstance(stack, AsyncExitStack)
            # Make sure uploaded files are released once the request is done.
            stack.push_async_callback(form.close)
            return form

        if not body_params:
            return None

        body_bytes = await request.body()
        if not body_bytes:
            return None

        content_type_value = request.headers.get("content-type")
        if not content_type_value:
            # No declared media type: assume the payload is JSON.
            return await request.json()

        # Only the media type matters here. Drop parameters such as
        # "; charset=utf-8" and compare case-insensitively so that
        # "application/json; charset=utf-8" is still recognised as JSON.
        # `partition` also tolerates values without exactly one "/", which
        # previously failed to unpack and surfaced as a 400.
        media_type = content_type_value.split(";", 1)[0].strip().lower()
        main_type, _, sub_type = media_type.partition("/")
        if main_type == "application" and (sub_type == "json" or sub_type.endswith("+json")):
            return await request.json()

        return body_bytes
    except Exception as e:
        raise HTTPException(
            status_code=400, detail="There was an error parsing the body",
        ) from e
66
+
67
+
68
def request_params_to_args(
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
    endpoint_params: Dict[str, Param],
    ignore_namespace: bool = True,
) -> Tuple[Dict[str, Any], ErrorStore]:
    """Deserialize received values into endpoint arguments.

    Each declared parameter is loaded through its marshmallow model:

    * a ``Field`` deserializes the single value stored under the parameter
      name (``ma.missing`` when absent);
    * a ``Schema`` loads either the whole mapping (``ignore_namespace=True``)
      or only the value stored under the parameter name, ignoring unknown
      keys in both cases.

    Validation failures are collected rather than raised.

    Args:
        received_params: Values taken from the request.
        endpoint_params: Declared parameters, keyed by argument name.
        ignore_namespace: Whether schema parameters read from the whole
            mapping instead of from their own key.

    Returns:
        The successfully loaded values and the store of validation errors.

    Raises:
        Exception: If a parameter's model is neither a marshmallow Schema
            nor a Field.
    """
    loaded: Dict[str, Any] = {}
    failures = ErrorStore()

    for name, param in endpoint_params.items():
        model = param.model

        if isinstance(model, mf.Field):
            try:
                # A field only ever looks at its own key.
                loaded[name] = model.deserialize(
                    received_params.get(name, ma.missing),
                    name,
                    received_params,
                )
            except ma.ValidationError as exc:
                failures.store_error(exc.messages, name)
            continue

        if not isinstance(model, ma.Schema):
            raise Exception(f'Invalid model type {type(param.model)}, expected marshmallow Schema or Field')

        try:
            if ignore_namespace:
                payload = received_params
            else:
                payload = received_params.get(name, ma.missing)
            loaded[name] = model.load(payload, unknown=ma.EXCLUDE)
        except ma.ValidationError as exc:
            failures.store_error(exc.messages)

    return loaded, failures
102
+
103
+
104
async def resolve_basic_args(
    request: Request | WebSocket,
    response: Response,
    background_tasks: BackgroundTasks,
    params: Dict[ParamType, Dict[str, Param]],
):
    """Resolve every argument that comes straight from the request.

    Path, query, header and cookie parameters are loaded first, then the
    payload is read once and used for form and body parameters. Finally the
    parameters annotated with a connection, response or background-task type
    receive the corresponding object.

    Args:
        request: The incoming request or websocket.
        response: Response object handed to parameters annotated with it.
        background_tasks: Task collection handed to parameters annotated
            with it.
        params: Declared parameters grouped by parameter type.

    Returns:
        A ``(values, errors)`` pair: the resolved arguments keyed by name, and
        the validation errors keyed by location (``path``, ``query``,
        ``header``, ``cookie``, ``form``, ``json``).
    """
    values: Dict[str, Any] = {}
    errors: Dict[str, Any] = {}

    def collect(location: str, source: Any, declared: Dict[str, Param], **options: Any) -> None:
        # Load one group of parameters and merge its outcome into the totals.
        loaded, error_store = request_params_to_args(source, declared, **options)
        values.update(loaded)
        if error_store.errors:
            errors[location] = error_store.errors

    # Request attributes are looked up lazily, one group at a time.
    for location, attribute, param_type in (
        ('path', 'path_params', ParamType.path),
        ('query', 'query_params', ParamType.query),
        ('header', 'headers', ParamType.header),
        ('cookie', 'cookies', ParamType.cookie),
    ):
        collect(location, getattr(request, attribute), params.get(param_type))

    form_params = params.get(ParamType.form)
    body_params = params.get(ParamType.body)
    body = await get_body(request, form_params, body_params)

    # If there is only one parameter defined, then don't namespace by the parameter name
    # Otherwise we honor the namespace: https://fastapi.tiangolo.com/tutorial/body-multiple-params/
    if form_params:
        collect(
            'form',
            body if isinstance(body, FormData) else {},
            form_params,
            ignore_namespace=len(form_params) == 1,
        )
    if body_params:
        collect(
            'json',
            body if isinstance(body, Mapping) else {},
            body_params,
            ignore_namespace=len(body_params) == 1,
        )

    # Handle non-field params: the first matching type family wins.
    injectables = (
        ((HTTPConnection, Request, WebSocket), request),
        (Response, response),
        (BackgroundTasks, background_tasks),
    )
    for param_name, param_type in params.get(ParamType.noparam).items():
        for accepted_types, injected in injectables:
            if lenient_issubclass(param_type, accepted_types):
                values[param_name] = injected
                break

    return values, errors
181
+
182
+
183
async def call_resolver(
    request: Request | WebSocket,
    param_name: str,
    resolved_param: ResolvedParam,
    resolver_kwargs: Dict[str, Any],
):
    """Invoke the resolver of a resolved parameter and return its value.

    Generator and async-generator resolvers are driven through
    ``solve_generator`` on the ``starmallow_astack`` exit stack stored in the
    request scope. Coroutine functions are awaited and everything else is
    called synchronously.

    Args:
        request: The incoming request or websocket.
        param_name: Name of the argument being resolved (used in errors).
        resolved_param: The parameter whose ``resolver`` is called.
        resolver_kwargs: Keyword arguments passed to the resolver.

    Returns:
        Whatever the resolver produces.

    Raises:
        TypeError: If the resolver is not callable.
    """
    resolver = resolved_param.resolver
    if not callable(resolver):
        raise TypeError(f'{param_name} = {resolved_param} resolver is not a function or callable')

    # Functions and bound methods can be inspected as they are. Swapping a
    # bound method for its `__call__` method-wrapper would hide the fact that
    # it is a coroutine function, so an async method would be called without
    # being awaited. Only other callable objects (e.g. instances of a class
    # defining `__call__`) are replaced by their `__call__`.
    if not (inspect.isfunction(resolver) or inspect.ismethod(resolver)):
        resolver = resolver.__call__

    if is_gen_callable(resolver) or is_async_gen_callable(resolver):
        stack = request.scope.get("starmallow_astack")
        assert isinstance(stack, AsyncExitStack)
        return await solve_generator(
            call=resolver, stack=stack, gen_kwargs=resolver_kwargs,
        )

    if asyncio.iscoroutinefunction(resolver):
        return await resolver(**resolver_kwargs)

    return resolver(**resolver_kwargs)
206
+
207
+
208
async def resolve_subparams(
    request: Request | WebSocket,
    response: Response,
    background_tasks: BackgroundTasks,
    params: Dict[str, ResolvedParam],
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Optional[Tuple[str, ...]]], Any]],
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Union[Any, List, Dict]]]:
    """Resolve a group of resolver-backed parameters.

    For each parameter the resolver's own parameters are resolved first
    (recursively, through ``resolve_params``), then the resolver is called.
    Values are reused from, and stored into, ``dependency_cache`` under the
    parameter's ``cache_key`` unless the parameter disabled caching.

    Args:
        request: The incoming request or websocket.
        response: Response object made available to resolvers.
        background_tasks: Task collection made available to resolvers.
        params: Resolved parameters, keyed by argument name.
        dependency_cache: Cache shared across the resolution of one request.
            ``None`` is treated as an empty cache.

    Returns:
        ``(values, {})`` on success, or ``(None, errors)`` as soon as the
        arguments of one resolver fail validation.
    """
    if dependency_cache is None:
        dependency_cache = {}

    values: Dict[str, Any] = {}
    for param_name, resolved_param in params.items():
        if resolved_param.use_cache and resolved_param.cache_key in dependency_cache:
            values[param_name] = dependency_cache[resolved_param.cache_key]
            continue

        resolver_kwargs, resolver_errors = await resolve_params(
            request=request,
            background_tasks=background_tasks,
            response=response,
            params=resolved_param.resolver_params,
            dependency_cache=dependency_cache,
        )

        # Exit early since other resolvers may rely on this one, which could raise argument exceptions
        if resolver_errors:
            return None, resolver_errors

        resolved_value = await call_resolver(request, param_name, resolved_param, resolver_kwargs)
        values[param_name] = resolved_value
        if resolved_param.use_cache:
            dependency_cache[resolved_param.cache_key] = resolved_value

    return values, {}
239
+
240
+
241
async def resolve_params(
    request: Request | WebSocket,
    params: Dict[ParamType, Dict[str, Param]],
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Optional[Tuple[str, ...]]], Any]] = None,
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Union[Any, List, Dict]]]:
    """Resolve all arguments of an endpoint or resolver from the request.

    Resolution happens in three stages, stopping at the first stage that
    reports errors: security parameters, then request-derived arguments
    (path, query, header, cookie, form, body and non-field parameters), then
    the remaining resolved parameters.

    Args:
        request: The incoming request or websocket.
        params: Declared parameters grouped by parameter type.
        background_tasks: Task collection to inject; created when omitted.
        response: Response object to inject. When omitted a placeholder is
            created with its ``content-length`` header removed and its status
            code unset.
        dependency_cache: Cache of already resolved values. A fresh cache is
            created only when ``None`` is given.

    Returns:
        ``(values, {})`` on success, or ``(None, errors)`` when a stage fails
        validation.
    """
    # Compare against None explicitly: `dependency_cache or {}` would replace
    # a caller-supplied cache that is still empty with a new dict, so nested
    # resolutions would not share cached values with their parent.
    if dependency_cache is None:
        dependency_cache = {}

    if response is None:
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore

    if background_tasks is None:
        background_tasks = BackgroundTasks()

    # Process security params first so we can raise permission issues
    security_values, errors = await resolve_subparams(
        request,
        response,
        background_tasks,
        params.get(ParamType.security),
        dependency_cache=dependency_cache,
    )
    if errors:
        return None, errors

    arg_values, errors = await resolve_basic_args(
        request,
        response,
        background_tasks,
        params,
    )
    if errors:
        return None, errors

    resolved_values, errors = await resolve_subparams(
        request,
        response,
        background_tasks,
        params.get(ParamType.resolved),
        dependency_cache=dependency_cache,
    )
    if errors:
        return None, errors

    return {
        **security_values,
        **arg_values,
        **resolved_values,
    }, {}
starmallow/routing.py CHANGED
@@ -2,20 +2,13 @@ import asyncio
2
2
  import functools
3
3
  import inspect
4
4
  import logging
5
- from contextlib import AsyncExitStack
6
5
  from enum import Enum, IntEnum
7
- from typing import Any, Callable, Coroutine, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
6
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple, Type, Union
8
7
 
9
8
  import marshmallow as ma
10
- import marshmallow.fields as mf
11
- from marshmallow.error_store import ErrorStore
12
- from marshmallow.utils import missing as missing_
13
9
  from starlette import routing
14
- from starlette.background import BackgroundTasks
15
10
  from starlette.concurrency import run_in_threadpool
16
- from starlette.datastructures import FormData, Headers, QueryParams
17
- from starlette.exceptions import HTTPException
18
- from starlette.requests import HTTPConnection, Request
11
+ from starlette.requests import Request
19
12
  from starlette.responses import Response
20
13
  from starlette.routing import BaseRoute, Match, compile_path, request_response
21
14
  from starlette.status import WS_1008_POLICY_VIOLATION
@@ -28,225 +21,24 @@ from starmallow.decorators import EndpointOptions
28
21
  from starmallow.endpoint import EndpointMixin, EndpointModel
29
22
  from starmallow.endpoints import APIHTTPEndpoint
30
23
  from starmallow.exceptions import RequestValidationError, WebSocketRequestValidationError
31
- from starmallow.params import Param, ParamType
24
+ from starmallow.request_resolver import resolve_params
32
25
  from starmallow.responses import JSONResponse
33
26
  from starmallow.types import DecoratedCallable
34
27
  from starmallow.utils import (
35
28
  create_response_model,
36
29
  generate_unique_id,
37
30
  get_name,
31
+ get_typed_signature,
38
32
  get_value_or_default,
39
- is_async_gen_callable,
40
33
  is_body_allowed_for_status_code,
41
- is_gen_callable,
42
34
  is_marshmallow_field,
43
35
  is_marshmallow_schema,
44
- lenient_issubclass,
45
- solve_generator,
46
36
  )
47
37
  from starmallow.websockets import APIWebSocket
48
38
 
49
39
  logger = logging.getLogger(__name__)
50
40
 
51
41
 
52
- async def get_body(
53
- request: Request,
54
- endpoint_model: "EndpointModel",
55
- ) -> Union[FormData, bytes, Dict[str, Any]]:
56
- is_body_form = bool(endpoint_model.flat_params[ParamType.form])
57
- should_process_body = is_body_form or endpoint_model.flat_params[ParamType.body]
58
- try:
59
- body: Any = None
60
- if should_process_body:
61
- if is_body_form:
62
- body = await request.form()
63
- stack = request.scope.get("starmallow_astack")
64
- assert isinstance(stack, AsyncExitStack)
65
- stack.push_async_callback(body.close)
66
- else:
67
- body_bytes = await request.body()
68
- if body_bytes:
69
- json_body: Any = missing_
70
- content_type_value: str = request.headers.get("content-type")
71
- if not content_type_value:
72
- json_body = await request.json()
73
- else:
74
- main_type, sub_type = content_type_value.split('/')
75
- if main_type == "application":
76
- if sub_type == "json" or sub_type.endswith("+json"):
77
- json_body = await request.json()
78
- if json_body != missing_:
79
- body = json_body
80
- else:
81
- body = body_bytes
82
-
83
- return body
84
- except Exception as e:
85
- raise HTTPException(
86
- status_code=400, detail="There was an error parsing the body",
87
- ) from e
88
-
89
-
90
- def request_params_to_args(
91
- received_params: Union[Mapping[str, Any], QueryParams, Headers],
92
- endpoint_params: Dict[str, Param],
93
- ignore_namespace: bool = True,
94
- ) -> Tuple[Dict[str, Any], ErrorStore]:
95
- values = {}
96
- error_store = ErrorStore()
97
- for field_name, param in endpoint_params.items():
98
- if isinstance(param.model, mf.Field):
99
- try:
100
- # Load model from specific param
101
- values[field_name] = param.model.deserialize(
102
- received_params.get(field_name, ma.missing),
103
- field_name,
104
- received_params,
105
- )
106
- except ma.ValidationError as error:
107
- error_store.store_error(error.messages, field_name)
108
- elif isinstance(param.model, ma.Schema):
109
- try:
110
- if ignore_namespace:
111
- # Load model from entire params
112
- values[field_name] = param.model.load(received_params, unknown=ma.EXCLUDE)
113
- else:
114
- values[field_name] = param.model.load(
115
- received_params.get(field_name, ma.missing),
116
- unknown=ma.EXCLUDE,
117
- )
118
- except ma.ValidationError as error:
119
- error_store.store_error(error.messages)
120
- else:
121
- raise Exception(f'Invalid model type {type(param.model)}, expected marshmallow Schema or Field')
122
-
123
- return values, error_store
124
-
125
-
126
- async def get_request_args(
127
- request: Request | WebSocket,
128
- endpoint_model: EndpointModel,
129
- background_tasks: Optional[BackgroundTasks] = None,
130
- response: Optional[Response] = None,
131
- ) -> Tuple[Dict[str, Any], Dict[str, Union[Any, List, Dict]]]:
132
- path_values, path_errors = request_params_to_args(
133
- request.path_params,
134
- endpoint_model.path_params,
135
- )
136
- query_values, query_errors = request_params_to_args(
137
- request.query_params,
138
- endpoint_model.query_params,
139
- )
140
- header_values, header_errors = request_params_to_args(
141
- request.headers,
142
- endpoint_model.header_params,
143
- )
144
- cookie_values, cookie_errors = request_params_to_args(
145
- request.cookies,
146
- endpoint_model.cookie_params,
147
- )
148
-
149
- body = await get_body(request, endpoint_model)
150
- form_values, form_errors = {}, None
151
- json_values, json_errors = {}, None
152
- form_params = endpoint_model.form_params
153
- if form_params:
154
- form_values, form_errors = request_params_to_args(
155
- body if body is not None and isinstance(body, FormData) else {},
156
- form_params,
157
- # If there is only one parameter defined, then don't namespace by the parameter name
158
- # Otherwise we honor the namespace: https://fastapi.tiangolo.com/tutorial/body-multiple-params/
159
- ignore_namespace=len(form_params) == 1,
160
- )
161
- body_params = endpoint_model.body_params
162
- if body_params:
163
- json_values, json_errors = request_params_to_args(
164
- body if body is not None and isinstance(body, Mapping) else {},
165
- body_params,
166
- # If there is only one parameter defined, then don't namespace by the parameter name
167
- # Otherwise we honor the namespace: https://fastapi.tiangolo.com/tutorial/body-multiple-params/
168
- ignore_namespace=len(body_params) == 1,
169
- )
170
-
171
- values = {
172
- **path_values,
173
- **query_values,
174
- **header_values,
175
- **cookie_values,
176
- **form_values,
177
- **json_values,
178
- }
179
- errors = {}
180
- if path_errors.errors:
181
- errors['path'] = path_errors.errors
182
- if query_errors.errors:
183
- errors['query'] = query_errors.errors
184
- if header_errors.errors:
185
- errors['header'] = header_errors.errors
186
- if cookie_errors.errors:
187
- errors['cookie'] = cookie_errors.errors
188
- if form_errors and form_errors.errors:
189
- errors['form'] = form_errors.errors
190
- if json_errors and json_errors.errors:
191
- errors['json'] = json_errors.errors
192
-
193
- # Exit before resolving ResolvedParams as that could cause function call exceptions
194
- if errors:
195
- return None, errors
196
-
197
- if response is None:
198
- response = Response()
199
- del response.headers["content-length"]
200
- response.status_code = None # type: ignore
201
-
202
- # Handle non-field params
203
- for param_name, param_type in endpoint_model.non_field_params.items():
204
- if lenient_issubclass(param_type, (HTTPConnection, Request, WebSocket)):
205
- values[param_name] = request
206
- elif lenient_issubclass(param_type, Response):
207
- values[param_name] = response
208
- elif lenient_issubclass(param_type, BackgroundTasks):
209
- if background_tasks is None:
210
- background_tasks = BackgroundTasks()
211
- values[param_name] = background_tasks
212
-
213
- # Handle resolved params
214
- for param_name, resolved_param in endpoint_model.resolved_params.items():
215
- # Get all known arguments for the resolver.
216
- resolver_kwargs = {}
217
- for name, parameter in inspect.signature(resolved_param.resolver).parameters.items():
218
- if lenient_issubclass(parameter.annotation, (HTTPConnection, Request, WebSocket)):
219
- resolver_kwargs[name] = request
220
- elif lenient_issubclass(parameter.annotation, Response):
221
- resolver_kwargs[name] = response
222
- elif lenient_issubclass(parameter.annotation, BackgroundTasks):
223
- if background_tasks is None:
224
- background_tasks = BackgroundTasks()
225
- resolver_kwargs[name] = background_tasks
226
- elif name in values:
227
- resolver_kwargs[name] = values[name]
228
-
229
- # Resolver can be a class with __call__ function
230
- resolver = resolved_param.resolver
231
- if not inspect.isfunction(resolver) and callable(resolver):
232
- resolver = resolver.__call__
233
- elif not inspect.isfunction(resolver):
234
- raise TypeError(f'{param_name} = {resolved_param} resolver is not a function or callable')
235
-
236
- if is_gen_callable(resolver) or is_async_gen_callable(resolver):
237
- stack = request.scope.get("starmallow_astack")
238
- assert isinstance(stack, AsyncExitStack)
239
- values[param_name] = await solve_generator(
240
- call=resolver, stack=stack, gen_kwargs=resolver_kwargs,
241
- )
242
- elif asyncio.iscoroutinefunction(resolver):
243
- values[param_name] = await resolver(**resolver_kwargs)
244
- else:
245
- values[param_name] = resolver(**resolver_kwargs)
246
-
247
- return values, errors
248
-
249
-
250
42
  async def run_endpoint_function(
251
43
  endpoint_model: EndpointModel,
252
44
  values: Dict[str, Any],
@@ -255,7 +47,7 @@ async def run_endpoint_function(
255
47
 
256
48
  kwargs = {
257
49
  name: values[name]
258
- for name in inspect.signature(endpoint_model.call).parameters
50
+ for name in get_typed_signature(endpoint_model.call).parameters
259
51
  if name in values
260
52
  }
261
53
 
@@ -283,7 +75,7 @@ def get_request_handler(
283
75
  assert endpoint_model.call is not None, "dependant.call must be a function"
284
76
 
285
77
  async def app(request: Request) -> Response:
286
- values, errors = await get_request_args(request, endpoint_model)
78
+ values, errors = await resolve_params(request, endpoint_model.params)
287
79
 
288
80
  if errors:
289
81
  raise RequestValidationError(errors)
@@ -318,7 +110,7 @@ def get_websocker_hander(
318
110
  assert endpoint_model.call is not None, "dependant.call must be a function"
319
111
 
320
112
  async def app(websocket: WebSocket) -> None:
321
- values, errors = await get_request_args(websocket, endpoint_model)
113
+ values, errors = await resolve_params(websocket, endpoint_model.params)
322
114
 
323
115
  if errors:
324
116
  await websocket.close(code=WS_1008_POLICY_VIOLATION)
@@ -16,7 +16,7 @@ class APIKeyIn(Enum):
16
16
  cookie = "cookie"
17
17
 
18
18
 
19
- @ma_dataclass
19
+ @ma_dataclass(frozen=True)
20
20
  class APIKeyModel(SecurityBase):
21
21
  type: ClassVar[SecurityTypes] = SecurityTypes.apiKey
22
22
  in_: APIKeyIn = required_field(data_key='in')
@@ -14,7 +14,7 @@ class SecurityTypes(Enum):
14
14
  openIdConnect = "openIdConnect"
15
15
 
16
16
 
17
- @ma_dataclass
17
+ @ma_dataclass(frozen=True)
18
18
  class SecurityBase:
19
19
  type: ClassVar[SecurityTypes]
20
20
  description: Optional[str] = None
@@ -11,13 +11,13 @@ from starmallow.security.base import SecurityBase, SecurityBaseResolver, Securit
11
11
  from starmallow.security.utils import get_authorization_scheme_param
12
12
 
13
13
 
14
- @ma_dataclass
14
+ @ma_dataclass(frozen=True)
15
15
  class HTTPBasicCredentials:
16
16
  username: str
17
17
  password: str
18
18
 
19
19
 
20
- @ma_dataclass
20
+ @ma_dataclass(frozen=True)
21
21
  class HTTPAuthorizationCredentials:
22
22
  '''
23
23
  Will hold the parsed HTTP Authorization Creditials like bearer tokens.
@@ -30,14 +30,14 @@ class HTTPAuthorizationCredentials:
30
30
  credentials: str
31
31
 
32
32
 
33
- @ma_dataclass
33
+ @ma_dataclass(frozen=True)
34
34
  class HTTPBaseModel(SecurityBase):
35
35
  type: ClassVar[SecurityTypes] = SecurityTypes.http
36
36
  description: Optional[str] = None
37
37
  scheme: str = None
38
38
 
39
39
 
40
- @ma_dataclass
40
+ @ma_dataclass(frozen=True)
41
41
  class HTTPBearerModel(HTTPBaseModel):
42
42
  scheme: str = "bearer"
43
43
  bearerFormat: Optional[str] = None
@@ -72,7 +72,6 @@ class HTTPBase(SecurityBaseResolver):
72
72
  return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
73
73
 
74
74
 
75
- @ma_dataclass
76
75
  class HTTPBasic(HTTPBase):
77
76
 
78
77
  def __init__(
@@ -125,7 +124,6 @@ class HTTPBasic(HTTPBase):
125
124
  return HTTPBasicCredentials(username=username, password=password)
126
125
 
127
126
 
128
- @ma_dataclass
129
127
  class HTTPBearer(HTTPBase):
130
128
 
131
129
  def __init__(
@@ -163,7 +161,6 @@ class HTTPBearer(HTTPBase):
163
161
  return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
164
162
 
165
163
 
166
- @ma_dataclass
167
164
  class HTTPDigest(HTTPBase):
168
165
 
169
166
  def __init__(
@@ -13,7 +13,7 @@ from starmallow.security.utils import get_authorization_scheme_param
13
13
 
14
14
 
15
15
  # region - Models
16
- @ma_dataclass
16
+ @ma_dataclass(frozen=True)
17
17
  class OAuthFlow:
18
18
  refreshUrl: Optional[str] = optional_field()
19
19
  scopes: Dict[str, str] = optional_field(default_factory=dict)
@@ -31,28 +31,28 @@ class OAuthFlow:
31
31
  }
32
32
 
33
33
 
34
- @ma_dataclass
34
+ @ma_dataclass(frozen=True)
35
35
  class OAuthFlowImplicit(OAuthFlow):
36
36
  authorizationUrl: str = required_field()
37
37
 
38
38
 
39
- @ma_dataclass
39
+ @ma_dataclass(frozen=True)
40
40
  class OAuthFlowPassword(OAuthFlow):
41
41
  tokenUrl: str = required_field()
42
42
 
43
43
 
44
- @ma_dataclass
44
+ @ma_dataclass(frozen=True)
45
45
  class OAuthFlowClientCredentials(OAuthFlow):
46
46
  tokenUrl: str = required_field()
47
47
 
48
48
 
49
- @ma_dataclass
49
+ @ma_dataclass(frozen=True)
50
50
  class OAuthFlowAuthorizationCode(OAuthFlow):
51
51
  authorizationUrl: str = required_field()
52
52
  tokenUrl: str = required_field()
53
53
 
54
54
 
55
- @ma_dataclass
55
+ @ma_dataclass(frozen=True)
56
56
  class OAuthFlowsModel:
57
57
  implicit: Optional[OAuthFlowImplicit] = optional_field()
58
58
  password: Optional[OAuthFlowPassword] = optional_field()
@@ -72,7 +72,7 @@ class OAuthFlowsModel:
72
72
  }
73
73
 
74
74
 
75
- @ma_dataclass
75
+ @ma_dataclass(frozen=True)
76
76
  class OAuth2Model(SecurityBase):
77
77
  type: SecurityTypes = SecurityTypes.oauth2
78
78
  flows: OAuthFlowsModel = required_field()
@@ -9,7 +9,7 @@ from starmallow.dataclasses import required_field
9
9
  from starmallow.security.base import SecurityBase, SecurityBaseResolver, SecurityTypes
10
10
 
11
11
 
12
- @ma_dataclass
12
+ @ma_dataclass(frozen=True)
13
13
  class OpenIdConnectModel(SecurityBase):
14
14
  type: SecurityTypes = SecurityTypes.openIdConnect
15
15
  openIdConnectUrl: str = required_field()
starmallow/utils.py CHANGED
@@ -9,11 +9,13 @@ from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
9
9
  from dataclasses import is_dataclass
10
10
  from decimal import Decimal
11
11
  from enum import Enum
12
+ from types import NoneType
12
13
  from typing import (
13
14
  TYPE_CHECKING,
14
15
  Any,
15
16
  Callable,
16
17
  Dict,
18
+ ForwardRef,
17
19
  FrozenSet,
18
20
  List,
19
21
  Mapping,
@@ -22,6 +24,7 @@ from typing import (
22
24
  Tuple,
23
25
  Type,
24
26
  Union,
27
+ _eval_type,
25
28
  _GenericAlias,
26
29
  get_args,
27
30
  get_origin,
@@ -303,7 +306,7 @@ def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) ->
303
306
 
304
307
 
305
308
  def create_response_model(type_: Type[Any]) -> ma.Schema | mf.Field | None:
306
- if type_ in [inspect._empty, None] or issubclass(type_, Response):
309
+ if type_ in [inspect._empty, None] or (inspect.isclass(type_) and issubclass(type_, Response)):
307
310
  return None
308
311
 
309
312
  field = get_model_field(type_)
@@ -337,3 +340,55 @@ def get_name(endpoint: Callable) -> str:
337
340
  if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
338
341
  return endpoint.__qualname__
339
342
  return endpoint.__class__.__name__
343
+
344
+
345
+ # Functions that help resolve forward references like
346
+ # def foo(a: 'str') -> 'UnresolvedClass': pass
347
+ # inspect.signature returns the literal string instead of the actual type.
348
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return a signature for ``call`` with string annotations resolved.

    Every parameter of ``call`` is copied (name, kind and default unchanged)
    and its annotation is passed through ``get_typed_annotation`` so that
    forward references written as strings become real objects where possible.

    The namespace used for resolution is ``call.__globals__``; callables
    without that attribute are resolved against an empty namespace.

    Note that only the parameters are carried over: the returned signature
    is built without a return annotation.
    """
    original = inspect.signature(call)
    namespace = getattr(call, "__globals__", {})

    resolved_params = []
    for declared in original.parameters.values():
        resolved_params.append(
            inspect.Parameter(
                name=declared.name,
                kind=declared.kind,
                default=declared.default,
                annotation=get_typed_annotation(declared.annotation, namespace),
            ),
        )

    return inspect.Signature(resolved_params)
362
+
363
+
364
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Resolve a single annotation, evaluating it if it is a string.

    Args:
        annotation: The raw annotation as reported by ``inspect.signature``.
            Anything that is not a ``str`` is returned untouched.
        globalns: Namespace used to look up names in a string annotation.
            It is supplied as both the global and the local namespace.

    Returns:
        The evaluated annotation. A string annotation that cannot be
        resolved comes back as whatever ``evaluate_forwardref`` returns
        for it. A string annotation spelling ``None`` yields ``None``.
    """
    if not isinstance(annotation, str):
        return annotation

    resolved = evaluate_forwardref(ForwardRef(annotation), globalns, globalns)

    # Evaluating the string 'None' produces NoneType, whereas a non-string
    # None annotation is returned above as plain None. Normalise so both
    # spellings give callers the same value to compare against.
    if resolved is type(None):
        return None
    return resolved
369
+
370
+
371
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return the resolved return annotation of ``call``.

    Yields ``None`` when ``call`` declares no return annotation. Otherwise
    the annotation is passed through ``get_typed_annotation`` using
    ``call.__globals__`` (or an empty namespace when the callable has no
    such attribute) so that string annotations are evaluated.
    """
    declared = inspect.signature(call).return_annotation

    if declared is not inspect.Signature.empty:
        namespace = getattr(call, "__globals__", {})
        return get_typed_annotation(declared, namespace)

    return None
380
+
381
+
382
+
383
+ def evaluate_forwardref(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
384
+ """Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
385
+ if value is None:
386
+ value = NoneType
387
+ elif isinstance(value, str):
388
+ value = ForwardRef(value, is_argument=False, is_class=True)
389
+
390
+ try:
391
+ return _eval_type(value, globalns, localns) # type: ignore
392
+ except NameError:
393
+ # the point of this function is to be tolerant to this case
394
+ return value
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: starmallow
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: StarMallow framework
5
5
  Project-URL: Homepage, https://github.com/mvanderlee/starmallow
6
6
  Author-email: Michiel Vanderlee <jmt.vanderlee@gmail.com>
@@ -1,5 +1,5 @@
1
- starmallow/__init__.py,sha256=slSraQMA-1h7BqB09favkpD1LL7yREP4RLaDIVzs6QI,322
2
- starmallow/applications.py,sha256=aoaHg_MgmRqRdaM8MFzeqsHqXuOQDWk6G70UQbdm_Fk,29469
1
+ starmallow/__init__.py,sha256=-ln4p19cwSrpztpgmO3aEauXyq2lSCuXvZH4ece_AlA,322
2
+ starmallow/applications.py,sha256=oZxxLof82QdfK44-q6wUHq9z8_sRGSEVzSQkNvfGKIg,29708
3
3
  starmallow/concurrency.py,sha256=MVRjo4Vqss_yqhaoeVt3xb7rLaSuAq_q9uYgTwbsojE,1375
4
4
  starmallow/constants.py,sha256=u0h8cJKhJY0oIZqzr7wpEZG2bPLrw5FroMnn3d8KBNQ,129
5
5
  starmallow/dataclasses.py,sha256=ap9DInvQjH2AyI4MAAnbDEuNnbPb94PigaNmEb7AQU8,2658
@@ -7,18 +7,19 @@ starmallow/datastructures.py,sha256=iH_KJuJ6kBCWEsnHFLdA3iyb6ZxhfdMHYrJlhiEZtDU,
7
7
  starmallow/decorators.py,sha256=SBrzmKxzF2q7hNCW_V7j0UV461QERSh9OTtTdTFi6Kg,3597
8
8
  starmallow/delimited_field.py,sha256=gonWgYg6G5xH2yXAyfDgkePmQ8dUaRSp2hdJ3mCfOBw,3466
9
9
  starmallow/docs.py,sha256=eA39LunVMEoPU5ge4qxm2eiJbrFTUSUu5EhG1L_LKxk,6268
10
- starmallow/endpoint.py,sha256=5oHoWu2dSG_nnctRM60ZPkB1R9Mh4cJv5yxRbYEZHec,15593
10
+ starmallow/endpoint.py,sha256=T4VH7CSM-M157ijHUVihoru8BKmAzDBMRMBpeww5JR8,15841
11
11
  starmallow/endpoints.py,sha256=UrwVZCxbmWI20iNtJ0oXxo4d3-y12TjsOGs_jnStTiU,939
12
12
  starmallow/exception_handlers.py,sha256=gr2qLYWEtsIEH28n7OreEiiLVz6Y7b6osRyS9esJbBk,891
13
13
  starmallow/exceptions.py,sha256=vabtPJkTmtCdC8_2OPBE8Osz0v0KxaSOX6IWf1jgNkc,872
14
14
  starmallow/fields.py,sha256=arrTabCYoJFZcoY69EZTBH3YUg7CUSr3-zYLiAjYLTM,1238
15
- starmallow/params.py,sha256=XRWIFLm2H5jQUIo4gm5Oi4xVqGNosaQSSi7QYqjJyxQ,7000
15
+ starmallow/params.py,sha256=bb5hHps7zabhYKpQjqqPqPSxtG4gvFCZtL9CptgH4sY,7253
16
+ starmallow/request_resolver.py,sha256=6NikzrNpQE82j-iasTipeBFblIx-WfX3l6EvdIXMFkY,10450
16
17
  starmallow/responses.py,sha256=k2pf_m21ykf_FECdODUz400pMucMJJf_Zm8TXFujvaU,2012
17
- starmallow/routing.py,sha256=bq7RThRbpPzMUKgkBhqUimeKx5KAXZt-5ydYmBMLYgg,47312
18
+ starmallow/routing.py,sha256=QPBl4X8M2bjlj5EGt21KLyYuroLRg60ZvNdIm74IbQQ,38743
18
19
  starmallow/schema_generator.py,sha256=BKtXVQoNFWoAIEtiRNylWls_7nyFIshy3_myooogjoI,17806
19
20
  starmallow/serializers.py,sha256=rBEKMNgONgz_bai12uDvAEMCI_aEFGsqMSeIoWtlrOI,12514
20
21
  starmallow/types.py,sha256=8GXWjvzXQhF5NMHf14fbid6uErxVd1Xk_w2I4FoUgZ4,717
21
- starmallow/utils.py,sha256=MS44NCYDpKA3JRCvJ7lRhrBK57wT5T8QlylZxlcZLEU,9484
22
+ starmallow/utils.py,sha256=AkVesTbYnshaXZqdgpc09TpZF07LIUeCOwOlvCaQrAI,11412
22
23
  starmallow/websockets.py,sha256=yIz3LzTBMNclpEoG7oTMbQwxbcdKNU6M8XcqZMyBTuA,2223
23
24
  starmallow/ext/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
25
  starmallow/ext/marshmallow/__init__.py,sha256=33jENGdfPq4-CDG0LOmN3KOGW1pXTy7a2oMwy4hrYzM,208
@@ -26,13 +27,13 @@ starmallow/ext/marshmallow/openapi.py,sha256=5aGvbwLGVucsVhXExpYeyt8n5dQTzazrf-n
26
27
  starmallow/middleware/__init__.py,sha256=vtNm85Z9pUPjJd-9giJGg3YL1wO7Jm5ooXBm31pDOK8,53
27
28
  starmallow/middleware/asyncexitstack.py,sha256=0GPhQSxqSVmAiVIqBIN5slueWYZ8bwh9f2bBPy7AbP0,1191
28
29
  starmallow/security/__init__.py,sha256=1rQFBIGnEbE51XDZSSi9NgPjXLScFq3RoLu4vk0KVYw,191
29
- starmallow/security/api_key.py,sha256=v2a3FHv1c--F2guiJ3wxKQi5k0nIcl40d4tqMPFyb44,3131
30
- starmallow/security/base.py,sha256=6ybCCf22t8GNR4RZXIzOfFEGws28S-KVqri-gHHXVCU,1131
31
- starmallow/security/http.py,sha256=rMwBYQQRil5iVjM87b0gsCENSFQXiqsdAfy0g6Qmvt8,6597
32
- starmallow/security/oauth2.py,sha256=PWdrgqUeijxzRAQilXMXRb9DnA-U2-xMQ5LKL4S66t8,9914
33
- starmallow/security/open_id_connect_url.py,sha256=ykokB7mJYu4pFsHW4Ro1y71h-5H11mt90jyv64EIQBM,1386
30
+ starmallow/security/api_key.py,sha256=E326Sxb_qhWbfN70vHuq4KEJcToW1Fxw0qGL0pHmQjc,3144
31
+ starmallow/security/base.py,sha256=_7PR7tepr0CHJxg6uTc_cBAeY90jBS5gu8z5598yEM0,1144
32
+ starmallow/security/http.py,sha256=cpGjM1kFDq3i_bOY96kMkf4cspBUxFkkET9lTK3NA-0,6607
33
+ starmallow/security/oauth2.py,sha256=1nv1580PY4cwgu5gzpQCf2MfMNv2Cfv05753AUHPOhQ,10005
34
+ starmallow/security/open_id_connect_url.py,sha256=IPsL2YzWc2mPwJbrUn6oFRTi7uRAG6mR62CGwmzBs1k,1399
34
35
  starmallow/security/utils.py,sha256=bd8T0YM7UQD5ATKucr1bNtAvz_Y3__dVNAv5UebiPvc,293
35
- starmallow-0.3.1.dist-info/METADATA,sha256=6RbpVd1OozE5ggQjC9wDR5rk6_d_VLnFBj4tpTju0bk,5651
36
- starmallow-0.3.1.dist-info/WHEEL,sha256=y1bSCq4r5i4nMmpXeUJMqs3ipKvkZObrIXSvJHm1qCI,87
37
- starmallow-0.3.1.dist-info/licenses/LICENSE.md,sha256=QelyGgOzch8CXzy6HrYwHh7nmj0rlWkDA0YzmZ3CPaY,1084
38
- starmallow-0.3.1.dist-info/RECORD,,
36
+ starmallow-0.3.3.dist-info/METADATA,sha256=Xtq-b_MKBANM7Frnbvx1u3CN9P_b4Po2STs_3Jjfitk,5651
37
+ starmallow-0.3.3.dist-info/WHEEL,sha256=y1bSCq4r5i4nMmpXeUJMqs3ipKvkZObrIXSvJHm1qCI,87
38
+ starmallow-0.3.3.dist-info/licenses/LICENSE.md,sha256=QelyGgOzch8CXzy6HrYwHh7nmj0rlWkDA0YzmZ3CPaY,1084
39
+ starmallow-0.3.3.dist-info/RECORD,,