strawberry-graphql 0.239.1__py3-none-any.whl → 0.240.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/channels/testing.py +2 -2
- strawberry/dataloader.py +2 -5
- strawberry/extensions/base_extension.py +10 -3
- strawberry/extensions/runner.py +8 -37
- strawberry/http/__init__.py +2 -1
- strawberry/http/async_base_view.py +17 -1
- strawberry/schema/base.py +3 -3
- strawberry/schema/execute.py +123 -111
- strawberry/schema/schema.py +117 -53
- strawberry/schema/subscribe.py +154 -0
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +50 -74
- strawberry/subscriptions/protocols/graphql_transport_ws/types.py +10 -2
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +45 -52
- strawberry/subscriptions/protocols/graphql_ws/types.py +1 -0
- strawberry/types/execution.py +15 -0
- strawberry/types/union.py +1 -4
- strawberry/utils/__init__.py +5 -0
- strawberry/utils/typing.py +2 -2
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/METADATA +1 -1
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/RECORD +23 -22
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/entry_points.txt +0 -0
strawberry/channels/testing.py
CHANGED
@@ -144,9 +144,9 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
|
|
144
144
|
message_type = response["type"]
|
145
145
|
if message_type == NextMessage.type:
|
146
146
|
payload = NextMessage(**response).payload
|
147
|
-
ret = ExecutionResult(payload
|
147
|
+
ret = ExecutionResult(payload.get("data"), None)
|
148
148
|
if "errors" in payload:
|
149
|
-
ret.errors = self.process_errors(payload
|
149
|
+
ret.errors = self.process_errors(payload.get("errors") or [])
|
150
150
|
ret.extensions = payload.get("extensions", None)
|
151
151
|
yield ret
|
152
152
|
elif message_type == ErrorMessage.type:
|
strawberry/dataloader.py
CHANGED
@@ -208,14 +208,11 @@ class DataLoader(Generic[K, T]):
|
|
208
208
|
|
209
209
|
|
210
210
|
def should_create_new_batch(loader: DataLoader, batch: Batch) -> bool:
|
211
|
-
|
211
|
+
return bool(
|
212
212
|
batch.dispatched
|
213
213
|
or loader.max_batch_size
|
214
214
|
and len(batch) >= loader.max_batch_size
|
215
|
-
)
|
216
|
-
return True
|
217
|
-
|
218
|
-
return False
|
215
|
+
)
|
219
216
|
|
220
217
|
|
221
218
|
def get_current_batch(loader: DataLoader) -> Batch:
|
@@ -21,9 +21,11 @@ class LifecycleStep(Enum):
|
|
21
21
|
class SchemaExtension:
|
22
22
|
execution_context: ExecutionContext
|
23
23
|
|
24
|
-
|
25
|
-
|
26
|
-
|
24
|
+
# to support extensions that still use the old signature
|
25
|
+
# we have an optional argument here for ease of initialization.
|
26
|
+
def __init__(
|
27
|
+
self, *, execution_context: ExecutionContext | None = None
|
28
|
+
) -> None: ...
|
27
29
|
def on_operation( # type: ignore
|
28
30
|
self,
|
29
31
|
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
|
@@ -61,6 +63,11 @@ class SchemaExtension:
|
|
61
63
|
def get_results(self) -> AwaitableOrValue[Dict[str, Any]]:
|
62
64
|
return {}
|
63
65
|
|
66
|
+
@classmethod
|
67
|
+
def _implements_resolve(cls) -> bool:
|
68
|
+
"""Whether the extension implements the resolve method."""
|
69
|
+
return cls.resolve is not SchemaExtension.resolve
|
70
|
+
|
64
71
|
|
65
72
|
Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]]
|
66
73
|
|
strawberry/extensions/runner.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import inspect
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
5
|
-
|
6
|
-
from graphql import MiddlewareManager
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
7
5
|
|
8
6
|
from strawberry.extensions.context import (
|
9
7
|
ExecutingContextManager,
|
@@ -13,11 +11,11 @@ from strawberry.extensions.context import (
|
|
13
11
|
)
|
14
12
|
from strawberry.utils.await_maybe import await_maybe
|
15
13
|
|
16
|
-
from . import SchemaExtension
|
17
|
-
|
18
14
|
if TYPE_CHECKING:
|
19
15
|
from strawberry.types import ExecutionContext
|
20
16
|
|
17
|
+
from . import SchemaExtension
|
18
|
+
|
21
19
|
|
22
20
|
class SchemaExtensionsRunner:
|
23
21
|
extensions: List[SchemaExtension]
|
@@ -25,27 +23,10 @@ class SchemaExtensionsRunner:
|
|
25
23
|
def __init__(
|
26
24
|
self,
|
27
25
|
execution_context: ExecutionContext,
|
28
|
-
extensions: Optional[
|
29
|
-
List[Union[Type[SchemaExtension], SchemaExtension]]
|
30
|
-
] = None,
|
26
|
+
extensions: Optional[List[SchemaExtension]] = None,
|
31
27
|
) -> None:
|
32
28
|
self.execution_context = execution_context
|
33
|
-
|
34
|
-
if not extensions:
|
35
|
-
extensions = []
|
36
|
-
|
37
|
-
init_extensions: List[SchemaExtension] = []
|
38
|
-
|
39
|
-
for extension in extensions:
|
40
|
-
# If the extension has already been instantiated then set the
|
41
|
-
# `execution_context` attribute
|
42
|
-
if isinstance(extension, SchemaExtension):
|
43
|
-
extension.execution_context = execution_context
|
44
|
-
init_extensions.append(extension)
|
45
|
-
else:
|
46
|
-
init_extensions.append(extension(execution_context=execution_context))
|
47
|
-
|
48
|
-
self.extensions = init_extensions
|
29
|
+
self.extensions = extensions or []
|
49
30
|
|
50
31
|
def operation(self) -> OperationContextManager:
|
51
32
|
return OperationContextManager(self.extensions)
|
@@ -61,29 +42,19 @@ class SchemaExtensionsRunner:
|
|
61
42
|
|
62
43
|
def get_extensions_results_sync(self) -> Dict[str, Any]:
|
63
44
|
data: Dict[str, Any] = {}
|
64
|
-
|
65
45
|
for extension in self.extensions:
|
66
46
|
if inspect.iscoroutinefunction(extension.get_results):
|
67
47
|
msg = "Cannot use async extension hook during sync execution"
|
68
48
|
raise RuntimeError(msg)
|
69
|
-
|
70
49
|
data.update(extension.get_results()) # type: ignore
|
71
50
|
|
72
51
|
return data
|
73
52
|
|
74
|
-
async def get_extensions_results(self) -> Dict[str, Any]:
|
53
|
+
async def get_extensions_results(self, ctx: ExecutionContext) -> Dict[str, Any]:
|
75
54
|
data: Dict[str, Any] = {}
|
76
55
|
|
77
56
|
for extension in self.extensions:
|
78
|
-
|
79
|
-
data.update(results)
|
57
|
+
data.update(await await_maybe(extension.get_results()))
|
80
58
|
|
59
|
+
data.update(ctx.extensions_results)
|
81
60
|
return data
|
82
|
-
|
83
|
-
def as_middleware_manager(self, *additional_middlewares: Any) -> MiddlewareManager:
|
84
|
-
middlewares = tuple(self.extensions) + additional_middlewares
|
85
|
-
|
86
|
-
return MiddlewareManager(*middlewares)
|
87
|
-
|
88
|
-
|
89
|
-
__all__ = ["SchemaExtensionsRunner"]
|
strawberry/http/__init__.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import json
|
4
4
|
from dataclasses import dataclass
|
5
5
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
6
|
-
from typing_extensions import TypedDict
|
6
|
+
from typing_extensions import Literal, TypedDict
|
7
7
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from strawberry.types import ExecutionResult
|
@@ -33,6 +33,7 @@ class GraphQLRequestData:
|
|
33
33
|
query: Optional[str]
|
34
34
|
variables: Optional[Dict[str, Any]]
|
35
35
|
operation_name: Optional[str]
|
36
|
+
protocol: Literal["http", "multipart-subscription"] = "http"
|
36
37
|
|
37
38
|
|
38
39
|
def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]:
|
@@ -14,6 +14,7 @@ from typing import (
|
|
14
14
|
Tuple,
|
15
15
|
Union,
|
16
16
|
)
|
17
|
+
from typing_extensions import Literal
|
17
18
|
|
18
19
|
from graphql import GraphQLError
|
19
20
|
|
@@ -121,6 +122,15 @@ class AsyncBaseHTTPView(
|
|
121
122
|
|
122
123
|
assert self.schema
|
123
124
|
|
125
|
+
if request_data.protocol == "multipart-subscription":
|
126
|
+
return await self.schema.subscribe(
|
127
|
+
request_data.query, # type: ignore
|
128
|
+
variable_values=request_data.variables,
|
129
|
+
context_value=context,
|
130
|
+
root_value=root_value,
|
131
|
+
operation_name=request_data.operation_name,
|
132
|
+
)
|
133
|
+
|
124
134
|
return await self.schema.execute(
|
125
135
|
request_data.query,
|
126
136
|
root_value=root_value,
|
@@ -233,7 +243,7 @@ class AsyncBaseHTTPView(
|
|
233
243
|
self, stream: Callable[[], AsyncGenerator[str, None]]
|
234
244
|
) -> Callable[[], AsyncGenerator[str, None]]:
|
235
245
|
"""Adds a heartbeat to the stream, to prevent the connection from closing when there are no messages being sent."""
|
236
|
-
queue
|
246
|
+
queue: asyncio.Queue[Tuple[bool, Any]] = asyncio.Queue(1)
|
237
247
|
|
238
248
|
cancelling = False
|
239
249
|
|
@@ -312,14 +322,19 @@ class AsyncBaseHTTPView(
|
|
312
322
|
) -> GraphQLRequestData:
|
313
323
|
content_type, params = parse_content_type(request.content_type or "")
|
314
324
|
|
325
|
+
protocol: Literal["http", "multipart-subscription"] = "http"
|
326
|
+
|
315
327
|
if request.method == "GET":
|
316
328
|
data = self.parse_query_params(request.query_params)
|
329
|
+
if self._is_multipart_subscriptions(content_type, params):
|
330
|
+
protocol = "multipart-subscription"
|
317
331
|
elif "application/json" in content_type:
|
318
332
|
data = self.parse_json(await request.get_body())
|
319
333
|
elif content_type == "multipart/form-data":
|
320
334
|
data = await self.parse_multipart(request)
|
321
335
|
elif self._is_multipart_subscriptions(content_type, params):
|
322
336
|
data = await self.parse_multipart_subscriptions(request)
|
337
|
+
protocol = "multipart-subscription"
|
323
338
|
else:
|
324
339
|
raise HTTPException(400, "Unsupported content type")
|
325
340
|
|
@@ -327,6 +342,7 @@ class AsyncBaseHTTPView(
|
|
327
342
|
query=data.get("query"),
|
328
343
|
variables=data.get("variables"),
|
329
344
|
operation_name=data.get("operationName"),
|
345
|
+
protocol=protocol,
|
330
346
|
)
|
331
347
|
|
332
348
|
async def process_result(
|
strawberry/schema/base.py
CHANGED
@@ -15,7 +15,6 @@ if TYPE_CHECKING:
|
|
15
15
|
from strawberry.types import (
|
16
16
|
ExecutionContext,
|
17
17
|
ExecutionResult,
|
18
|
-
SubscriptionExecutionResult,
|
19
18
|
)
|
20
19
|
from strawberry.types.base import StrawberryObjectDefinition
|
21
20
|
from strawberry.types.enum import EnumDefinition
|
@@ -24,6 +23,7 @@ if TYPE_CHECKING:
|
|
24
23
|
from strawberry.types.union import StrawberryUnion
|
25
24
|
|
26
25
|
from .config import StrawberryConfig
|
26
|
+
from .subscribe import SubscriptionResult
|
27
27
|
|
28
28
|
|
29
29
|
class BaseSchema(Protocol):
|
@@ -43,7 +43,7 @@ class BaseSchema(Protocol):
|
|
43
43
|
root_value: Optional[Any] = None,
|
44
44
|
operation_name: Optional[str] = None,
|
45
45
|
allowed_operation_types: Optional[Iterable[OperationType]] = None,
|
46
|
-
) ->
|
46
|
+
) -> ExecutionResult:
|
47
47
|
raise NotImplementedError
|
48
48
|
|
49
49
|
@abstractmethod
|
@@ -66,7 +66,7 @@ class BaseSchema(Protocol):
|
|
66
66
|
context_value: Optional[Any] = None,
|
67
67
|
root_value: Optional[Any] = None,
|
68
68
|
operation_name: Optional[str] = None,
|
69
|
-
) ->
|
69
|
+
) -> SubscriptionResult:
|
70
70
|
raise NotImplementedError
|
71
71
|
|
72
72
|
@abstractmethod
|
strawberry/schema/execute.py
CHANGED
@@ -4,40 +4,43 @@ from asyncio import ensure_future
|
|
4
4
|
from inspect import isawaitable
|
5
5
|
from typing import (
|
6
6
|
TYPE_CHECKING,
|
7
|
+
Awaitable,
|
7
8
|
Callable,
|
8
9
|
Iterable,
|
9
10
|
List,
|
10
11
|
Optional,
|
11
|
-
Sequence,
|
12
12
|
Tuple,
|
13
13
|
Type,
|
14
14
|
TypedDict,
|
15
15
|
Union,
|
16
|
+
cast,
|
16
17
|
)
|
17
18
|
|
18
|
-
from graphql import
|
19
|
+
from graphql import ExecutionResult as GraphQLExecutionResult
|
20
|
+
from graphql import GraphQLError, parse
|
19
21
|
from graphql import execute as original_execute
|
20
22
|
from graphql.validation import validate
|
21
23
|
|
22
24
|
from strawberry.exceptions import MissingQueryError
|
23
|
-
from strawberry.extensions.runner import SchemaExtensionsRunner
|
24
25
|
from strawberry.schema.validation_rules.one_of import OneOfInputValidationRule
|
25
26
|
from strawberry.types import ExecutionResult
|
26
|
-
from strawberry.types.
|
27
|
+
from strawberry.types.execution import PreExecutionError
|
28
|
+
from strawberry.utils.await_maybe import await_maybe
|
27
29
|
|
28
30
|
from .exceptions import InvalidOperationTypeError
|
29
31
|
|
30
32
|
if TYPE_CHECKING:
|
31
|
-
from typing_extensions import NotRequired, Unpack
|
33
|
+
from typing_extensions import NotRequired, TypeAlias, Unpack
|
32
34
|
|
33
35
|
from graphql import ExecutionContext as GraphQLExecutionContext
|
34
36
|
from graphql import GraphQLSchema
|
37
|
+
from graphql.execution.middleware import MiddlewareManager
|
35
38
|
from graphql.language import DocumentNode
|
36
39
|
from graphql.validation import ASTValidationRule
|
37
40
|
|
38
|
-
from strawberry.extensions import
|
41
|
+
from strawberry.extensions.runner import SchemaExtensionsRunner
|
39
42
|
from strawberry.types import ExecutionContext
|
40
|
-
from strawberry.types.
|
43
|
+
from strawberry.types.graphql import OperationType
|
41
44
|
|
42
45
|
|
43
46
|
# duplicated because of https://github.com/mkdocstrings/griffe-typingdoc/issues/7
|
@@ -45,6 +48,11 @@ class ParseOptions(TypedDict):
|
|
45
48
|
max_tokens: NotRequired[int]
|
46
49
|
|
47
50
|
|
51
|
+
ProcessErrors: TypeAlias = (
|
52
|
+
"Callable[[List[GraphQLError], Optional[ExecutionContext]], None]"
|
53
|
+
)
|
54
|
+
|
55
|
+
|
48
56
|
def parse_document(query: str, **kwargs: Unpack[ParseOptions]) -> DocumentNode:
|
49
57
|
return parse(query, **kwargs)
|
50
58
|
|
@@ -77,112 +85,120 @@ def _run_validation(execution_context: ExecutionContext) -> None:
|
|
77
85
|
)
|
78
86
|
|
79
87
|
|
88
|
+
async def _parse_and_validate_async(
|
89
|
+
context: ExecutionContext, extensions_runner: SchemaExtensionsRunner
|
90
|
+
) -> Optional[PreExecutionError]:
|
91
|
+
if not context.query:
|
92
|
+
raise MissingQueryError()
|
93
|
+
|
94
|
+
async with extensions_runner.parsing():
|
95
|
+
try:
|
96
|
+
if not context.graphql_document:
|
97
|
+
context.graphql_document = parse_document(context.query)
|
98
|
+
|
99
|
+
except GraphQLError as error:
|
100
|
+
context.errors = [error]
|
101
|
+
return PreExecutionError(data=None, errors=[error])
|
102
|
+
|
103
|
+
except Exception as error:
|
104
|
+
error = GraphQLError(str(error), original_error=error)
|
105
|
+
context.errors = [error]
|
106
|
+
return PreExecutionError(data=None, errors=[error])
|
107
|
+
|
108
|
+
if context.operation_type not in context.allowed_operations:
|
109
|
+
raise InvalidOperationTypeError(context.operation_type)
|
110
|
+
|
111
|
+
async with extensions_runner.validation():
|
112
|
+
_run_validation(context)
|
113
|
+
if context.errors:
|
114
|
+
return PreExecutionError(
|
115
|
+
data=None,
|
116
|
+
errors=context.errors,
|
117
|
+
)
|
118
|
+
|
119
|
+
return None
|
120
|
+
|
121
|
+
|
122
|
+
async def _handle_execution_result(
|
123
|
+
context: ExecutionContext,
|
124
|
+
result: Union[GraphQLExecutionResult, ExecutionResult],
|
125
|
+
extensions_runner: SchemaExtensionsRunner,
|
126
|
+
process_errors: ProcessErrors,
|
127
|
+
) -> ExecutionResult:
|
128
|
+
# Set errors on the context so that it's easier
|
129
|
+
# to access in extensions
|
130
|
+
if result.errors:
|
131
|
+
context.errors = result.errors
|
132
|
+
|
133
|
+
# Run the `Schema.process_errors` function here before
|
134
|
+
# extensions have a chance to modify them (see the MaskErrors
|
135
|
+
# extension). That way we can log the original errors but
|
136
|
+
# only return a sanitised version to the client.
|
137
|
+
process_errors(result.errors, context)
|
138
|
+
if isinstance(result, GraphQLExecutionResult):
|
139
|
+
result = ExecutionResult(data=result.data, errors=result.errors)
|
140
|
+
result.extensions = await extensions_runner.get_extensions_results(context)
|
141
|
+
context.result = result # type: ignore # mypy failed to deduce correct type.
|
142
|
+
return result
|
143
|
+
|
144
|
+
|
145
|
+
def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError:
|
146
|
+
if isinstance(error, GraphQLError):
|
147
|
+
return error
|
148
|
+
return GraphQLError(str(error), original_error=error)
|
149
|
+
|
150
|
+
|
80
151
|
async def execute(
|
81
152
|
schema: GraphQLSchema,
|
82
|
-
*,
|
83
|
-
allowed_operation_types: Iterable[OperationType],
|
84
|
-
extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]],
|
85
153
|
execution_context: ExecutionContext,
|
154
|
+
extensions_runner: SchemaExtensionsRunner,
|
155
|
+
process_errors: ProcessErrors,
|
156
|
+
middleware_manager: MiddlewareManager,
|
86
157
|
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
|
87
|
-
|
88
|
-
) -> Union[ExecutionResult, SubscriptionExecutionResult]:
|
89
|
-
extensions_runner = SchemaExtensionsRunner(
|
90
|
-
execution_context=execution_context,
|
91
|
-
extensions=list(extensions),
|
92
|
-
)
|
93
|
-
|
158
|
+
) -> ExecutionResult | PreExecutionError:
|
94
159
|
try:
|
95
160
|
async with extensions_runner.operation():
|
96
161
|
# Note: In graphql-core the schema would be validated here but in
|
97
162
|
# Strawberry we are validating it at initialisation time instead
|
98
|
-
if not execution_context.query:
|
99
|
-
raise MissingQueryError()
|
100
163
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
except GraphQLError as exc:
|
109
|
-
execution_context.errors = [exc]
|
110
|
-
process_errors([exc], execution_context)
|
111
|
-
return ExecutionResult(
|
112
|
-
data=None,
|
113
|
-
errors=[exc],
|
114
|
-
extensions=await extensions_runner.get_extensions_results(),
|
115
|
-
)
|
116
|
-
|
117
|
-
if execution_context.operation_type not in allowed_operation_types:
|
118
|
-
raise InvalidOperationTypeError(execution_context.operation_type)
|
119
|
-
|
120
|
-
async with extensions_runner.validation():
|
121
|
-
_run_validation(execution_context)
|
122
|
-
if execution_context.errors:
|
123
|
-
process_errors(execution_context.errors, execution_context)
|
124
|
-
return ExecutionResult(data=None, errors=execution_context.errors)
|
164
|
+
if errors := await _parse_and_validate_async(
|
165
|
+
execution_context, extensions_runner
|
166
|
+
):
|
167
|
+
return await _handle_execution_result(
|
168
|
+
execution_context, errors, extensions_runner, process_errors
|
169
|
+
)
|
125
170
|
|
171
|
+
assert execution_context.graphql_document
|
126
172
|
async with extensions_runner.executing():
|
127
173
|
if not execution_context.result:
|
128
|
-
|
129
|
-
|
130
|
-
# TODO: make our own wrapper?
|
131
|
-
return await subscribe( # type: ignore
|
132
|
-
schema,
|
133
|
-
execution_context.graphql_document,
|
134
|
-
root_value=execution_context.root_value,
|
135
|
-
context_value=execution_context.context,
|
136
|
-
variable_values=execution_context.variables,
|
137
|
-
operation_name=execution_context.operation_name,
|
138
|
-
)
|
139
|
-
else:
|
140
|
-
result = original_execute(
|
174
|
+
res = await await_maybe(
|
175
|
+
original_execute(
|
141
176
|
schema,
|
142
177
|
execution_context.graphql_document,
|
143
178
|
root_value=execution_context.root_value,
|
144
|
-
middleware=
|
179
|
+
middleware=middleware_manager,
|
145
180
|
variable_values=execution_context.variables,
|
146
181
|
operation_name=execution_context.operation_name,
|
147
182
|
context_value=execution_context.context,
|
148
183
|
execution_context_class=execution_context_class,
|
149
184
|
)
|
185
|
+
)
|
150
186
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
execution_context.result = result
|
155
|
-
# Also set errors on the execution_context so that it's easier
|
156
|
-
# to access in extensions
|
157
|
-
if result.errors:
|
158
|
-
execution_context.errors = result.errors
|
159
|
-
|
160
|
-
# Run the `Schema.process_errors` function here before
|
161
|
-
# extensions have a chance to modify them (see the MaskErrors
|
162
|
-
# extension). That way we can log the original errors but
|
163
|
-
# only return a sanitised version to the client.
|
164
|
-
process_errors(result.errors, execution_context)
|
165
|
-
|
187
|
+
else:
|
188
|
+
res = execution_context.result
|
166
189
|
except (MissingQueryError, InvalidOperationTypeError) as e:
|
167
190
|
raise e
|
168
191
|
except Exception as exc:
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
execution_context.errors = [error]
|
175
|
-
process_errors([error], execution_context)
|
176
|
-
return ExecutionResult(
|
177
|
-
data=None,
|
178
|
-
errors=[error],
|
179
|
-
extensions=await extensions_runner.get_extensions_results(),
|
192
|
+
return await _handle_execution_result(
|
193
|
+
execution_context,
|
194
|
+
PreExecutionError(data=None, errors=[_coerce_error(exc)]),
|
195
|
+
extensions_runner,
|
196
|
+
process_errors,
|
180
197
|
)
|
181
198
|
|
182
|
-
return
|
183
|
-
|
184
|
-
|
185
|
-
extensions=await extensions_runner.get_extensions_results(),
|
199
|
+
# return results after all the operation completed.
|
200
|
+
return await _handle_execution_result(
|
201
|
+
execution_context, res, extensions_runner, process_errors
|
186
202
|
)
|
187
203
|
|
188
204
|
|
@@ -190,16 +206,12 @@ def execute_sync(
|
|
190
206
|
schema: GraphQLSchema,
|
191
207
|
*,
|
192
208
|
allowed_operation_types: Iterable[OperationType],
|
193
|
-
|
209
|
+
extensions_runner: SchemaExtensionsRunner,
|
194
210
|
execution_context: ExecutionContext,
|
195
211
|
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
|
196
|
-
process_errors:
|
212
|
+
process_errors: ProcessErrors,
|
213
|
+
middleware_manager: MiddlewareManager,
|
197
214
|
) -> ExecutionResult:
|
198
|
-
extensions_runner = SchemaExtensionsRunner(
|
199
|
-
execution_context=execution_context,
|
200
|
-
extensions=list(extensions),
|
201
|
-
)
|
202
|
-
|
203
215
|
try:
|
204
216
|
with extensions_runner.operation():
|
205
217
|
# Note: In graphql-core the schema would be validated here but in
|
@@ -214,12 +226,12 @@ def execute_sync(
|
|
214
226
|
execution_context.query, **execution_context.parse_options
|
215
227
|
)
|
216
228
|
|
217
|
-
except GraphQLError as
|
218
|
-
execution_context.errors = [
|
219
|
-
process_errors([
|
229
|
+
except GraphQLError as error:
|
230
|
+
execution_context.errors = [error]
|
231
|
+
process_errors([error], execution_context)
|
220
232
|
return ExecutionResult(
|
221
233
|
data=None,
|
222
|
-
errors=[
|
234
|
+
errors=[error],
|
223
235
|
extensions=extensions_runner.get_extensions_results_sync(),
|
224
236
|
)
|
225
237
|
|
@@ -230,7 +242,11 @@ def execute_sync(
|
|
230
242
|
_run_validation(execution_context)
|
231
243
|
if execution_context.errors:
|
232
244
|
process_errors(execution_context.errors, execution_context)
|
233
|
-
return ExecutionResult(
|
245
|
+
return ExecutionResult(
|
246
|
+
data=None,
|
247
|
+
errors=execution_context.errors,
|
248
|
+
extensions=extensions_runner.get_extensions_results_sync(),
|
249
|
+
)
|
234
250
|
|
235
251
|
with extensions_runner.executing():
|
236
252
|
if not execution_context.result:
|
@@ -238,7 +254,7 @@ def execute_sync(
|
|
238
254
|
schema,
|
239
255
|
execution_context.graphql_document,
|
240
256
|
root_value=execution_context.root_value,
|
241
|
-
middleware=
|
257
|
+
middleware=middleware_manager,
|
242
258
|
variable_values=execution_context.variables,
|
243
259
|
operation_name=execution_context.operation_name,
|
244
260
|
context_value=execution_context.context,
|
@@ -246,13 +262,15 @@ def execute_sync(
|
|
246
262
|
)
|
247
263
|
|
248
264
|
if isawaitable(result):
|
265
|
+
result = cast(Awaitable[GraphQLExecutionResult], result) # type: ignore[redundant-cast]
|
249
266
|
ensure_future(result).cancel()
|
250
267
|
raise RuntimeError(
|
251
268
|
"GraphQL execution failed to complete synchronously."
|
252
269
|
)
|
253
270
|
|
271
|
+
result = cast(GraphQLExecutionResult, result) # type: ignore[redundant-cast]
|
254
272
|
execution_context.result = result
|
255
|
-
# Also set errors on the
|
273
|
+
# Also set errors on the context so that it's easier
|
256
274
|
# to access in extensions
|
257
275
|
if result.errors:
|
258
276
|
execution_context.errors = result.errors
|
@@ -262,23 +280,17 @@ def execute_sync(
|
|
262
280
|
# extension). That way we can log the original errors but
|
263
281
|
# only return a sanitised version to the client.
|
264
282
|
process_errors(result.errors, execution_context)
|
265
|
-
|
266
283
|
except (MissingQueryError, InvalidOperationTypeError) as e:
|
267
284
|
raise e
|
268
285
|
except Exception as exc:
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
else GraphQLError(str(exc), original_error=exc)
|
273
|
-
)
|
274
|
-
execution_context.errors = [error]
|
275
|
-
process_errors([error], execution_context)
|
286
|
+
errors = [_coerce_error(exc)]
|
287
|
+
execution_context.errors = errors
|
288
|
+
process_errors(errors, execution_context)
|
276
289
|
return ExecutionResult(
|
277
290
|
data=None,
|
278
|
-
errors=
|
291
|
+
errors=errors,
|
279
292
|
extensions=extensions_runner.get_extensions_results_sync(),
|
280
293
|
)
|
281
|
-
|
282
294
|
return ExecutionResult(
|
283
295
|
data=execution_context.result.data,
|
284
296
|
errors=execution_context.result.errors,
|