strawberry-graphql: strawberry_graphql-0.239.1-py3-none-any.whl → strawberry_graphql-0.240.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/channels/testing.py +2 -2
- strawberry/dataloader.py +2 -5
- strawberry/extensions/base_extension.py +10 -3
- strawberry/extensions/runner.py +8 -37
- strawberry/http/__init__.py +2 -1
- strawberry/http/async_base_view.py +17 -1
- strawberry/schema/base.py +3 -3
- strawberry/schema/execute.py +123 -111
- strawberry/schema/schema.py +117 -53
- strawberry/schema/subscribe.py +154 -0
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +50 -74
- strawberry/subscriptions/protocols/graphql_transport_ws/types.py +10 -2
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +45 -52
- strawberry/subscriptions/protocols/graphql_ws/types.py +1 -0
- strawberry/types/execution.py +15 -0
- strawberry/types/union.py +1 -4
- strawberry/utils/__init__.py +5 -0
- strawberry/utils/typing.py +2 -2
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/METADATA +1 -1
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/RECORD +23 -22
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.239.1.dist-info → strawberry_graphql-0.240.0.dist-info}/entry_points.txt +0 -0
strawberry/schema/schema.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import warnings
|
4
|
-
from functools import lru_cache
|
4
|
+
from functools import cached_property, lru_cache
|
5
5
|
from typing import (
|
6
6
|
TYPE_CHECKING,
|
7
7
|
Any,
|
8
|
-
AsyncIterator,
|
9
8
|
Dict,
|
10
9
|
Iterable,
|
11
10
|
List,
|
@@ -22,18 +21,19 @@ from graphql import (
|
|
22
21
|
GraphQLNonNull,
|
23
22
|
GraphQLSchema,
|
24
23
|
get_introspection_query,
|
25
|
-
parse,
|
26
24
|
validate_schema,
|
27
25
|
)
|
28
|
-
from graphql.execution import
|
26
|
+
from graphql.execution.middleware import MiddlewareManager
|
29
27
|
from graphql.type.directives import specified_directives
|
30
28
|
|
31
29
|
from strawberry import relay
|
32
30
|
from strawberry.annotation import StrawberryAnnotation
|
31
|
+
from strawberry.extensions import SchemaExtension
|
33
32
|
from strawberry.extensions.directives import (
|
34
33
|
DirectivesExtension,
|
35
34
|
DirectivesExtensionSync,
|
36
35
|
)
|
36
|
+
from strawberry.extensions.runner import SchemaExtensionsRunner
|
37
37
|
from strawberry.schema.schema_converter import GraphQLCoreConverter
|
38
38
|
from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
|
39
39
|
from strawberry.types import ExecutionContext
|
@@ -41,19 +41,17 @@ from strawberry.types.base import StrawberryObjectDefinition, has_object_definit
|
|
41
41
|
from strawberry.types.graphql import OperationType
|
42
42
|
|
43
43
|
from ..printer import print_schema
|
44
|
-
from ..utils.await_maybe import await_maybe
|
45
44
|
from . import compat
|
46
45
|
from .base import BaseSchema
|
47
46
|
from .config import StrawberryConfig
|
48
47
|
from .execute import execute, execute_sync
|
48
|
+
from .subscribe import SubscriptionResult, subscribe
|
49
49
|
|
50
50
|
if TYPE_CHECKING:
|
51
51
|
from graphql import ExecutionContext as GraphQLExecutionContext
|
52
|
-
from graphql import ExecutionResult as GraphQLExecutionResult
|
53
52
|
|
54
53
|
from strawberry.directive import StrawberryDirective
|
55
|
-
from strawberry.
|
56
|
-
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
|
54
|
+
from strawberry.types import ExecutionResult
|
57
55
|
from strawberry.types.base import StrawberryType
|
58
56
|
from strawberry.types.enum import EnumDefinition
|
59
57
|
from strawberry.types.field import StrawberryField
|
@@ -85,7 +83,7 @@ class Schema(BaseSchema):
|
|
85
83
|
] = None,
|
86
84
|
schema_directives: Iterable[object] = (),
|
87
85
|
) -> None:
|
88
|
-
"""Default Schema to be
|
86
|
+
"""Default Schema to be used in a Strawberry application.
|
89
87
|
|
90
88
|
A GraphQL Schema class used to define the structure and configuration
|
91
89
|
of GraphQL queries, mutations, and subscriptions.
|
@@ -125,6 +123,7 @@ class Schema(BaseSchema):
|
|
125
123
|
self.subscription = subscription
|
126
124
|
|
127
125
|
self.extensions = extensions
|
126
|
+
self._cached_middleware_manager: MiddlewareManager | None = None
|
128
127
|
self.execution_context_class = execution_context_class
|
129
128
|
self.config = config or StrawberryConfig()
|
130
129
|
|
@@ -214,15 +213,63 @@ class Schema(BaseSchema):
|
|
214
213
|
formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors)
|
215
214
|
raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}")
|
216
215
|
|
217
|
-
def get_extensions(
|
218
|
-
|
219
|
-
) -> List[Union[Type[SchemaExtension], SchemaExtension]]:
|
220
|
-
extensions = list(self.extensions)
|
221
|
-
|
216
|
+
def get_extensions(self, sync: bool = False) -> List[SchemaExtension]:
|
217
|
+
extensions = []
|
222
218
|
if self.directives:
|
223
|
-
extensions
|
219
|
+
extensions = [
|
220
|
+
*self.extensions,
|
221
|
+
DirectivesExtensionSync if sync else DirectivesExtension,
|
222
|
+
]
|
223
|
+
extensions.extend(self.extensions)
|
224
|
+
return [
|
225
|
+
ext if isinstance(ext, SchemaExtension) else ext(execution_context=None)
|
226
|
+
for ext in extensions
|
227
|
+
]
|
228
|
+
|
229
|
+
@cached_property
|
230
|
+
def _sync_extensions(self) -> List[SchemaExtension]:
|
231
|
+
return self.get_extensions(sync=True)
|
232
|
+
|
233
|
+
@cached_property
|
234
|
+
def _async_extensions(self) -> List[SchemaExtension]:
|
235
|
+
return self.get_extensions(sync=False)
|
236
|
+
|
237
|
+
def create_extensions_runner(
|
238
|
+
self, execution_context: ExecutionContext, extensions: list[SchemaExtension]
|
239
|
+
) -> SchemaExtensionsRunner:
|
240
|
+
return SchemaExtensionsRunner(
|
241
|
+
execution_context=execution_context,
|
242
|
+
extensions=extensions,
|
243
|
+
)
|
224
244
|
|
225
|
-
|
245
|
+
def _get_middleware_manager(
|
246
|
+
self, extensions: list[SchemaExtension]
|
247
|
+
) -> MiddlewareManager:
|
248
|
+
# create a middleware manager with all the extensions that implement resolve
|
249
|
+
if not self._cached_middleware_manager:
|
250
|
+
self._cached_middleware_manager = MiddlewareManager(
|
251
|
+
*(ext for ext in extensions if ext._implements_resolve())
|
252
|
+
)
|
253
|
+
return self._cached_middleware_manager
|
254
|
+
|
255
|
+
def _create_execution_context(
|
256
|
+
self,
|
257
|
+
query: Optional[str],
|
258
|
+
allowed_operation_types: Iterable[OperationType],
|
259
|
+
variable_values: Optional[Dict[str, Any]] = None,
|
260
|
+
context_value: Optional[Any] = None,
|
261
|
+
root_value: Optional[Any] = None,
|
262
|
+
operation_name: Optional[str] = None,
|
263
|
+
) -> ExecutionContext:
|
264
|
+
return ExecutionContext(
|
265
|
+
query=query,
|
266
|
+
schema=self,
|
267
|
+
allowed_operations=allowed_operation_types,
|
268
|
+
context=context_value,
|
269
|
+
root_value=root_value,
|
270
|
+
variables=variable_values,
|
271
|
+
provided_operation_name=operation_name,
|
272
|
+
)
|
226
273
|
|
227
274
|
@lru_cache
|
228
275
|
def get_type_by_name(
|
@@ -284,31 +331,33 @@ class Schema(BaseSchema):
|
|
284
331
|
root_value: Optional[Any] = None,
|
285
332
|
operation_name: Optional[str] = None,
|
286
333
|
allowed_operation_types: Optional[Iterable[OperationType]] = None,
|
287
|
-
) ->
|
334
|
+
) -> ExecutionResult:
|
288
335
|
if allowed_operation_types is None:
|
289
336
|
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
|
290
337
|
|
291
|
-
|
292
|
-
execution_context = ExecutionContext(
|
338
|
+
execution_context = self._create_execution_context(
|
293
339
|
query=query,
|
294
|
-
|
295
|
-
|
340
|
+
allowed_operation_types=allowed_operation_types,
|
341
|
+
variable_values=variable_values,
|
342
|
+
context_value=context_value,
|
296
343
|
root_value=root_value,
|
297
|
-
|
298
|
-
provided_operation_name=operation_name,
|
344
|
+
operation_name=operation_name,
|
299
345
|
)
|
300
|
-
|
301
|
-
|
346
|
+
extensions = self.get_extensions()
|
347
|
+
# TODO (#3571): remove this when we implement execution context as parameter.
|
348
|
+
for extension in extensions:
|
349
|
+
extension.execution_context = execution_context
|
350
|
+
return await execute(
|
302
351
|
self._schema,
|
303
|
-
extensions=self.get_extensions(),
|
304
|
-
execution_context_class=self.execution_context_class,
|
305
352
|
execution_context=execution_context,
|
306
|
-
|
353
|
+
extensions_runner=self.create_extensions_runner(
|
354
|
+
execution_context, extensions
|
355
|
+
),
|
307
356
|
process_errors=self._process_errors,
|
357
|
+
middleware_manager=self._get_middleware_manager(extensions),
|
358
|
+
execution_context_class=self.execution_context_class,
|
308
359
|
)
|
309
360
|
|
310
|
-
return result
|
311
|
-
|
312
361
|
def execute_sync(
|
313
362
|
self,
|
314
363
|
query: Optional[str],
|
@@ -321,44 +370,59 @@ class Schema(BaseSchema):
|
|
321
370
|
if allowed_operation_types is None:
|
322
371
|
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
|
323
372
|
|
324
|
-
execution_context =
|
373
|
+
execution_context = self._create_execution_context(
|
325
374
|
query=query,
|
326
|
-
|
327
|
-
|
375
|
+
allowed_operation_types=allowed_operation_types,
|
376
|
+
variable_values=variable_values,
|
377
|
+
context_value=context_value,
|
328
378
|
root_value=root_value,
|
329
|
-
|
330
|
-
provided_operation_name=operation_name,
|
379
|
+
operation_name=operation_name,
|
331
380
|
)
|
332
|
-
|
333
|
-
|
381
|
+
extensions = self._sync_extensions
|
382
|
+
# TODO (#3571): remove this when we implement execution context as parameter.
|
383
|
+
for extension in extensions:
|
384
|
+
extension.execution_context = execution_context
|
385
|
+
return execute_sync(
|
334
386
|
self._schema,
|
335
|
-
extensions=self.get_extensions(sync=True),
|
336
|
-
execution_context_class=self.execution_context_class,
|
337
387
|
execution_context=execution_context,
|
388
|
+
extensions_runner=self.create_extensions_runner(
|
389
|
+
execution_context, extensions
|
390
|
+
),
|
391
|
+
execution_context_class=self.execution_context_class,
|
338
392
|
allowed_operation_types=allowed_operation_types,
|
339
393
|
process_errors=self._process_errors,
|
394
|
+
middleware_manager=self._get_middleware_manager(extensions),
|
340
395
|
)
|
341
396
|
|
342
|
-
return result
|
343
|
-
|
344
397
|
async def subscribe(
|
345
398
|
self,
|
346
|
-
|
347
|
-
query: str,
|
399
|
+
query: Optional[str],
|
348
400
|
variable_values: Optional[Dict[str, Any]] = None,
|
349
401
|
context_value: Optional[Any] = None,
|
350
402
|
root_value: Optional[Any] = None,
|
351
403
|
operation_name: Optional[str] = None,
|
352
|
-
) ->
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
404
|
+
) -> SubscriptionResult:
|
405
|
+
execution_context = self._create_execution_context(
|
406
|
+
query=query,
|
407
|
+
allowed_operation_types=(OperationType.SUBSCRIPTION,),
|
408
|
+
variable_values=variable_values,
|
409
|
+
context_value=context_value,
|
410
|
+
root_value=root_value,
|
411
|
+
operation_name=operation_name,
|
412
|
+
)
|
413
|
+
extensions = self._async_extensions
|
414
|
+
# TODO (#3571): remove this when we implement execution context as parameter.
|
415
|
+
for extension in extensions:
|
416
|
+
extension.execution_context = execution_context
|
417
|
+
return await subscribe(
|
418
|
+
self._schema,
|
419
|
+
execution_context=execution_context,
|
420
|
+
extensions_runner=self.create_extensions_runner(
|
421
|
+
execution_context, extensions
|
422
|
+
),
|
423
|
+
process_errors=self._process_errors,
|
424
|
+
middleware_manager=self._get_middleware_manager(extensions),
|
425
|
+
execution_context_class=self.execution_context_class,
|
362
426
|
)
|
363
427
|
|
364
428
|
def _resolve_node_ids(self) -> None:
|
strawberry/schema/subscribe.py
ADDED
@@ -0,0 +1,154 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Optional, Type, Union
|
4
|
+
|
5
|
+
from graphql import (
|
6
|
+
ExecutionResult as OriginalExecutionResult,
|
7
|
+
)
|
8
|
+
from graphql.execution import ExecutionContext as GraphQLExecutionContext
|
9
|
+
from graphql.execution import subscribe as original_subscribe
|
10
|
+
|
11
|
+
from strawberry.types import ExecutionResult
|
12
|
+
from strawberry.types.execution import ExecutionContext, PreExecutionError
|
13
|
+
from strawberry.utils import IS_GQL_32
|
14
|
+
from strawberry.utils.await_maybe import await_maybe
|
15
|
+
|
16
|
+
from .execute import (
|
17
|
+
ProcessErrors,
|
18
|
+
_coerce_error,
|
19
|
+
_handle_execution_result,
|
20
|
+
_parse_and_validate_async,
|
21
|
+
)
|
22
|
+
|
23
|
+
if TYPE_CHECKING:
|
24
|
+
from typing_extensions import TypeAlias
|
25
|
+
|
26
|
+
from graphql.execution.middleware import MiddlewareManager
|
27
|
+
from graphql.type.schema import GraphQLSchema
|
28
|
+
|
29
|
+
from ..extensions.runner import SchemaExtensionsRunner
|
30
|
+
|
31
|
+
SubscriptionResult: TypeAlias = Union[
|
32
|
+
PreExecutionError, AsyncGenerator[ExecutionResult, None]
|
33
|
+
]
|
34
|
+
|
35
|
+
OriginSubscriptionResult = Union[
|
36
|
+
OriginalExecutionResult,
|
37
|
+
AsyncIterator[OriginalExecutionResult],
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
async def _subscribe(
|
42
|
+
schema: GraphQLSchema,
|
43
|
+
execution_context: ExecutionContext,
|
44
|
+
extensions_runner: SchemaExtensionsRunner,
|
45
|
+
process_errors: ProcessErrors,
|
46
|
+
middleware_manager: MiddlewareManager,
|
47
|
+
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
|
48
|
+
) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]:
|
49
|
+
async with extensions_runner.operation():
|
50
|
+
if initial_error := await _parse_and_validate_async(
|
51
|
+
context=execution_context,
|
52
|
+
extensions_runner=extensions_runner,
|
53
|
+
):
|
54
|
+
initial_error.extensions = await extensions_runner.get_extensions_results(
|
55
|
+
execution_context
|
56
|
+
)
|
57
|
+
yield await _handle_execution_result(
|
58
|
+
execution_context, initial_error, extensions_runner, process_errors
|
59
|
+
)
|
60
|
+
try:
|
61
|
+
async with extensions_runner.executing():
|
62
|
+
assert execution_context.graphql_document is not None
|
63
|
+
gql_33_kwargs = {
|
64
|
+
"middleware": middleware_manager,
|
65
|
+
"execution_context_class": execution_context_class,
|
66
|
+
}
|
67
|
+
try:
|
68
|
+
# Might not be awaitable for pre-execution errors.
|
69
|
+
aiter_or_result: OriginSubscriptionResult = await await_maybe(
|
70
|
+
original_subscribe(
|
71
|
+
schema,
|
72
|
+
execution_context.graphql_document,
|
73
|
+
root_value=execution_context.root_value,
|
74
|
+
variable_values=execution_context.variables,
|
75
|
+
operation_name=execution_context.operation_name,
|
76
|
+
context_value=execution_context.context,
|
77
|
+
**{} if IS_GQL_32 else gql_33_kwargs, # type: ignore[arg-type]
|
78
|
+
)
|
79
|
+
)
|
80
|
+
# graphql-core 3.2 doesn't handle some of the pre-execution errors.
|
81
|
+
# see `test_subscription_immediate_error`
|
82
|
+
except Exception as exc:
|
83
|
+
aiter_or_result = OriginalExecutionResult(
|
84
|
+
data=None, errors=[_coerce_error(exc)]
|
85
|
+
)
|
86
|
+
|
87
|
+
# Handle pre-execution errors.
|
88
|
+
if isinstance(aiter_or_result, OriginalExecutionResult):
|
89
|
+
yield await _handle_execution_result(
|
90
|
+
execution_context,
|
91
|
+
PreExecutionError(data=None, errors=aiter_or_result.errors),
|
92
|
+
extensions_runner,
|
93
|
+
process_errors,
|
94
|
+
)
|
95
|
+
else:
|
96
|
+
try:
|
97
|
+
async for result in aiter_or_result:
|
98
|
+
yield await _handle_execution_result(
|
99
|
+
execution_context,
|
100
|
+
result,
|
101
|
+
extensions_runner,
|
102
|
+
process_errors,
|
103
|
+
)
|
104
|
+
# graphql-core doesn't handle exceptions raised while executing.
|
105
|
+
except Exception as exc:
|
106
|
+
yield await _handle_execution_result(
|
107
|
+
execution_context,
|
108
|
+
OriginalExecutionResult(data=None, errors=[_coerce_error(exc)]),
|
109
|
+
extensions_runner,
|
110
|
+
process_errors,
|
111
|
+
)
|
112
|
+
# catch exceptions raised in `on_execute` hook.
|
113
|
+
except Exception as exc:
|
114
|
+
origin_result = OriginalExecutionResult(
|
115
|
+
data=None, errors=[_coerce_error(exc)]
|
116
|
+
)
|
117
|
+
yield await _handle_execution_result(
|
118
|
+
execution_context,
|
119
|
+
origin_result,
|
120
|
+
extensions_runner,
|
121
|
+
process_errors,
|
122
|
+
)
|
123
|
+
|
124
|
+
|
125
|
+
async def subscribe(
|
126
|
+
schema: GraphQLSchema,
|
127
|
+
execution_context: ExecutionContext,
|
128
|
+
extensions_runner: SchemaExtensionsRunner,
|
129
|
+
process_errors: ProcessErrors,
|
130
|
+
middleware_manager: MiddlewareManager,
|
131
|
+
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
|
132
|
+
) -> SubscriptionResult:
|
133
|
+
asyncgen = _subscribe(
|
134
|
+
schema,
|
135
|
+
execution_context,
|
136
|
+
extensions_runner,
|
137
|
+
process_errors,
|
138
|
+
middleware_manager,
|
139
|
+
execution_context_class,
|
140
|
+
)
|
141
|
+
# GrapQL-core might return an initial error result instead of an async iterator.
|
142
|
+
# This happens when "there was an immediate error" i.e resolver is not an async iterator.
|
143
|
+
# To overcome this while maintaining the extension contexts we do this trick.
|
144
|
+
first = await asyncgen.__anext__()
|
145
|
+
if isinstance(first, PreExecutionError):
|
146
|
+
await asyncgen.aclose()
|
147
|
+
return first
|
148
|
+
|
149
|
+
async def _wrapper() -> AsyncGenerator[ExecutionResult, None]:
|
150
|
+
yield first
|
151
|
+
async for result in asyncgen:
|
152
|
+
yield result
|
153
|
+
|
154
|
+
return _wrapper()
|
strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
CHANGED
@@ -7,15 +7,13 @@ from contextlib import suppress
|
|
7
7
|
from typing import (
|
8
8
|
TYPE_CHECKING,
|
9
9
|
Any,
|
10
|
-
|
11
|
-
AsyncIterator,
|
10
|
+
Awaitable,
|
12
11
|
Callable,
|
13
12
|
Dict,
|
14
13
|
List,
|
15
14
|
Optional,
|
16
15
|
)
|
17
16
|
|
18
|
-
from graphql import ExecutionResult as GraphQLExecutionResult
|
19
17
|
from graphql import GraphQLError, GraphQLSyntaxError, parse
|
20
18
|
|
21
19
|
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
|
@@ -24,11 +22,14 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
|
|
24
22
|
ConnectionInitMessage,
|
25
23
|
ErrorMessage,
|
26
24
|
NextMessage,
|
25
|
+
NextPayload,
|
27
26
|
PingMessage,
|
28
27
|
PongMessage,
|
29
28
|
SubscribeMessage,
|
30
29
|
SubscribeMessagePayload,
|
31
30
|
)
|
31
|
+
from strawberry.types import ExecutionResult
|
32
|
+
from strawberry.types.execution import PreExecutionError
|
32
33
|
from strawberry.types.graphql import OperationType
|
33
34
|
from strawberry.types.unset import UNSET
|
34
35
|
from strawberry.utils.debug import pretty_print_graphql_operation
|
@@ -38,10 +39,10 @@ if TYPE_CHECKING:
|
|
38
39
|
from datetime import timedelta
|
39
40
|
|
40
41
|
from strawberry.schema import BaseSchema
|
42
|
+
from strawberry.schema.subscribe import SubscriptionResult
|
41
43
|
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
|
42
44
|
GraphQLTransportMessage,
|
43
45
|
)
|
44
|
-
from strawberry.types import ExecutionResult
|
45
46
|
|
46
47
|
|
47
48
|
class BaseGraphQLTransportWSHandler(ABC):
|
@@ -243,10 +244,10 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
243
244
|
if isinstance(context, dict):
|
244
245
|
context["connection_params"] = self.connection_params
|
245
246
|
root_value = await self.get_root_value()
|
246
|
-
|
247
|
+
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]
|
247
248
|
# Get an AsyncGenerator yielding the results
|
248
249
|
if operation_type == OperationType.SUBSCRIPTION:
|
249
|
-
result_source =
|
250
|
+
result_source = self.schema.subscribe(
|
250
251
|
query=message.payload.query,
|
251
252
|
variable_values=message.payload.variables,
|
252
253
|
operation_name=message.payload.operationName,
|
@@ -254,29 +255,16 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
254
255
|
root_value=root_value,
|
255
256
|
)
|
256
257
|
else:
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
operation_name=message.payload.operationName,
|
265
|
-
)
|
266
|
-
|
267
|
-
result_source = get_result_source()
|
258
|
+
result_source = self.schema.execute(
|
259
|
+
query=message.payload.query,
|
260
|
+
variable_values=message.payload.variables,
|
261
|
+
context_value=context,
|
262
|
+
root_value=root_value,
|
263
|
+
operation_name=message.payload.operationName,
|
264
|
+
)
|
268
265
|
|
269
266
|
operation = Operation(self, message.id, operation_type)
|
270
267
|
|
271
|
-
# Handle initial validation errors
|
272
|
-
if isinstance(result_source, GraphQLExecutionResult):
|
273
|
-
assert operation_type == OperationType.SUBSCRIPTION
|
274
|
-
assert result_source.errors
|
275
|
-
payload = [err.formatted for err in result_source.errors]
|
276
|
-
await self.send_message(ErrorMessage(id=message.id, payload=payload))
|
277
|
-
self.schema.process_errors(result_source.errors)
|
278
|
-
return
|
279
|
-
|
280
268
|
# Create task to handle this subscription, reserve the operation ID
|
281
269
|
operation.task = asyncio.create_task(
|
282
270
|
self.operation_task(result_source, operation)
|
@@ -284,65 +272,37 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
284
272
|
self.operations[message.id] = operation
|
285
273
|
|
286
274
|
async def operation_task(
|
287
|
-
self,
|
275
|
+
self,
|
276
|
+
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult],
|
277
|
+
operation: Operation,
|
288
278
|
) -> None:
|
289
279
|
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
|
290
280
|
# TODO: Handle errors in this method using self.handle_task_exception()
|
291
281
|
try:
|
292
|
-
await
|
293
|
-
|
294
|
-
#
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
282
|
+
first_res_or_agen = await result_source
|
283
|
+
# that's an immediate error we should end the operation
|
284
|
+
# without a COMPLETE message
|
285
|
+
if isinstance(first_res_or_agen, PreExecutionError):
|
286
|
+
assert first_res_or_agen.errors
|
287
|
+
await operation.send_initial_errors(first_res_or_agen.errors)
|
288
|
+
# that's a mutation / query result
|
289
|
+
elif isinstance(first_res_or_agen, ExecutionResult):
|
290
|
+
await operation.send_next(first_res_or_agen)
|
291
|
+
await operation.send_message(CompleteMessage(id=operation.id))
|
292
|
+
else:
|
293
|
+
async for result in first_res_or_agen:
|
294
|
+
await operation.send_next(result)
|
295
|
+
await operation.send_message(CompleteMessage(id=operation.id))
|
296
|
+
|
297
|
+
except BaseException as e: # pragma: no cover
|
298
|
+
self.operations.pop(operation.id, None)
|
301
299
|
raise
|
302
|
-
else:
|
303
|
-
await operation.send_message(CompleteMessage(id=operation.id))
|
304
300
|
finally:
|
305
301
|
# add this task to a list to be reaped later
|
306
302
|
task = asyncio.current_task()
|
307
303
|
assert task is not None
|
308
304
|
self.completed_tasks.append(task)
|
309
305
|
|
310
|
-
async def handle_async_results(
|
311
|
-
self,
|
312
|
-
result_source: AsyncGenerator,
|
313
|
-
operation: Operation,
|
314
|
-
) -> None:
|
315
|
-
try:
|
316
|
-
async for result in result_source:
|
317
|
-
if (
|
318
|
-
result.errors
|
319
|
-
and operation.operation_type != OperationType.SUBSCRIPTION
|
320
|
-
):
|
321
|
-
error_payload = [err.formatted for err in result.errors]
|
322
|
-
error_message = ErrorMessage(id=operation.id, payload=error_payload)
|
323
|
-
await operation.send_message(error_message)
|
324
|
-
# don't need to call schema.process_errors() here because
|
325
|
-
# it was already done by schema.execute()
|
326
|
-
return
|
327
|
-
else:
|
328
|
-
next_payload = {"data": result.data}
|
329
|
-
if result.errors:
|
330
|
-
self.schema.process_errors(result.errors)
|
331
|
-
next_payload["errors"] = [
|
332
|
-
err.formatted for err in result.errors
|
333
|
-
]
|
334
|
-
next_message = NextMessage(id=operation.id, payload=next_payload)
|
335
|
-
await operation.send_message(next_message)
|
336
|
-
except Exception as error:
|
337
|
-
# GraphQLErrors are handled by graphql-core and included in the
|
338
|
-
# ExecutionResult
|
339
|
-
error = GraphQLError(str(error), original_error=error)
|
340
|
-
error_payload = [error.formatted]
|
341
|
-
error_message = ErrorMessage(id=operation.id, payload=error_payload)
|
342
|
-
await operation.send_message(error_message)
|
343
|
-
self.schema.process_errors([error])
|
344
|
-
return
|
345
|
-
|
346
306
|
def forget_id(self, id: str) -> None:
|
347
307
|
# de-register the operation id making it immediately available
|
348
308
|
# for re-use
|
@@ -401,5 +361,21 @@ class Operation:
|
|
401
361
|
self.handler.forget_id(self.id)
|
402
362
|
await self.handler.send_message(message)
|
403
363
|
|
364
|
+
async def send_initial_errors(self, errors: list[GraphQLError]) -> None:
|
365
|
+
# Initial errors see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error
|
366
|
+
# "This can occur before execution starts,
|
367
|
+
# usually due to validation errors, or during the execution of the request"
|
368
|
+
await self.send_message(
|
369
|
+
ErrorMessage(id=self.id, payload=[err.formatted for err in errors])
|
370
|
+
)
|
371
|
+
|
372
|
+
async def send_next(self, execution_result: ExecutionResult) -> None:
|
373
|
+
next_payload: NextPayload = {"data": execution_result.data}
|
374
|
+
if execution_result.errors:
|
375
|
+
next_payload["errors"] = [err.formatted for err in execution_result.errors]
|
376
|
+
if execution_result.extensions:
|
377
|
+
next_payload["extensions"] = execution_result.extensions
|
378
|
+
await self.send_message(NextMessage(id=self.id, payload=next_payload))
|
379
|
+
|
404
380
|
|
405
381
|
__all__ = ["BaseGraphQLTransportWSHandler", "Operation"]
|
strawberry/subscriptions/protocols/graphql_transport_ws/types.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from dataclasses import asdict, dataclass
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict
|
5
5
|
|
6
6
|
from strawberry.types.unset import UNSET
|
7
7
|
|
@@ -68,12 +68,20 @@ class SubscribeMessage(GraphQLTransportMessage):
|
|
68
68
|
type: str = "subscribe"
|
69
69
|
|
70
70
|
|
71
|
+
class NextPayload(TypedDict, total=False):
|
72
|
+
data: Any
|
73
|
+
|
74
|
+
# Optional list of formatted graphql.GraphQLError objects
|
75
|
+
errors: Optional[List[GraphQLFormattedError]]
|
76
|
+
extensions: Optional[Dict[str, Any]]
|
77
|
+
|
78
|
+
|
71
79
|
@dataclass
|
72
80
|
class NextMessage(GraphQLTransportMessage):
|
73
81
|
"""Direction: Server -> Client."""
|
74
82
|
|
75
83
|
id: str
|
76
|
-
payload:
|
84
|
+
payload: NextPayload
|
77
85
|
type: str = "next"
|
78
86
|
|
79
87
|
def as_dict(self) -> dict:
|