strawberry-graphql 0.239.1__py3-none-any.whl → 0.240.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import warnings
4
- from functools import lru_cache
4
+ from functools import cached_property, lru_cache
5
5
  from typing import (
6
6
  TYPE_CHECKING,
7
7
  Any,
8
- AsyncIterator,
9
8
  Dict,
10
9
  Iterable,
11
10
  List,
@@ -22,18 +21,19 @@ from graphql import (
22
21
  GraphQLNonNull,
23
22
  GraphQLSchema,
24
23
  get_introspection_query,
25
- parse,
26
24
  validate_schema,
27
25
  )
28
- from graphql.execution import subscribe
26
+ from graphql.execution.middleware import MiddlewareManager
29
27
  from graphql.type.directives import specified_directives
30
28
 
31
29
  from strawberry import relay
32
30
  from strawberry.annotation import StrawberryAnnotation
31
+ from strawberry.extensions import SchemaExtension
33
32
  from strawberry.extensions.directives import (
34
33
  DirectivesExtension,
35
34
  DirectivesExtensionSync,
36
35
  )
36
+ from strawberry.extensions.runner import SchemaExtensionsRunner
37
37
  from strawberry.schema.schema_converter import GraphQLCoreConverter
38
38
  from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
39
39
  from strawberry.types import ExecutionContext
@@ -41,19 +41,17 @@ from strawberry.types.base import StrawberryObjectDefinition, has_object_definit
41
41
  from strawberry.types.graphql import OperationType
42
42
 
43
43
  from ..printer import print_schema
44
- from ..utils.await_maybe import await_maybe
45
44
  from . import compat
46
45
  from .base import BaseSchema
47
46
  from .config import StrawberryConfig
48
47
  from .execute import execute, execute_sync
48
+ from .subscribe import SubscriptionResult, subscribe
49
49
 
50
50
  if TYPE_CHECKING:
51
51
  from graphql import ExecutionContext as GraphQLExecutionContext
52
- from graphql import ExecutionResult as GraphQLExecutionResult
53
52
 
54
53
  from strawberry.directive import StrawberryDirective
55
- from strawberry.extensions import SchemaExtension
56
- from strawberry.types import ExecutionResult, SubscriptionExecutionResult
54
+ from strawberry.types import ExecutionResult
57
55
  from strawberry.types.base import StrawberryType
58
56
  from strawberry.types.enum import EnumDefinition
59
57
  from strawberry.types.field import StrawberryField
@@ -85,7 +83,7 @@ class Schema(BaseSchema):
85
83
  ] = None,
86
84
  schema_directives: Iterable[object] = (),
87
85
  ) -> None:
88
- """Default Schema to be to be used in a Strawberry application.
86
+ """Default Schema to be used in a Strawberry application.
89
87
 
90
88
  A GraphQL Schema class used to define the structure and configuration
91
89
  of GraphQL queries, mutations, and subscriptions.
@@ -125,6 +123,7 @@ class Schema(BaseSchema):
125
123
  self.subscription = subscription
126
124
 
127
125
  self.extensions = extensions
126
+ self._cached_middleware_manager: MiddlewareManager | None = None
128
127
  self.execution_context_class = execution_context_class
129
128
  self.config = config or StrawberryConfig()
130
129
 
@@ -214,15 +213,63 @@ class Schema(BaseSchema):
214
213
  formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors)
215
214
  raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}")
216
215
 
217
- def get_extensions(
218
- self, sync: bool = False
219
- ) -> List[Union[Type[SchemaExtension], SchemaExtension]]:
220
- extensions = list(self.extensions)
221
-
216
+ def get_extensions(self, sync: bool = False) -> List[SchemaExtension]:
217
+ extensions = []
222
218
  if self.directives:
223
- extensions.append(DirectivesExtensionSync if sync else DirectivesExtension)
219
+ extensions = [
220
+ *self.extensions,
221
+ DirectivesExtensionSync if sync else DirectivesExtension,
222
+ ]
223
+ extensions.extend(self.extensions)
224
+ return [
225
+ ext if isinstance(ext, SchemaExtension) else ext(execution_context=None)
226
+ for ext in extensions
227
+ ]
228
+
229
+ @cached_property
230
+ def _sync_extensions(self) -> List[SchemaExtension]:
231
+ return self.get_extensions(sync=True)
232
+
233
+ @cached_property
234
+ def _async_extensions(self) -> List[SchemaExtension]:
235
+ return self.get_extensions(sync=False)
236
+
237
+ def create_extensions_runner(
238
+ self, execution_context: ExecutionContext, extensions: list[SchemaExtension]
239
+ ) -> SchemaExtensionsRunner:
240
+ return SchemaExtensionsRunner(
241
+ execution_context=execution_context,
242
+ extensions=extensions,
243
+ )
224
244
 
225
- return extensions
245
+ def _get_middleware_manager(
246
+ self, extensions: list[SchemaExtension]
247
+ ) -> MiddlewareManager:
248
+ # create a middleware manager with all the extensions that implement resolve
249
+ if not self._cached_middleware_manager:
250
+ self._cached_middleware_manager = MiddlewareManager(
251
+ *(ext for ext in extensions if ext._implements_resolve())
252
+ )
253
+ return self._cached_middleware_manager
254
+
255
+ def _create_execution_context(
256
+ self,
257
+ query: Optional[str],
258
+ allowed_operation_types: Iterable[OperationType],
259
+ variable_values: Optional[Dict[str, Any]] = None,
260
+ context_value: Optional[Any] = None,
261
+ root_value: Optional[Any] = None,
262
+ operation_name: Optional[str] = None,
263
+ ) -> ExecutionContext:
264
+ return ExecutionContext(
265
+ query=query,
266
+ schema=self,
267
+ allowed_operations=allowed_operation_types,
268
+ context=context_value,
269
+ root_value=root_value,
270
+ variables=variable_values,
271
+ provided_operation_name=operation_name,
272
+ )
226
273
 
227
274
  @lru_cache
228
275
  def get_type_by_name(
@@ -284,31 +331,33 @@ class Schema(BaseSchema):
284
331
  root_value: Optional[Any] = None,
285
332
  operation_name: Optional[str] = None,
286
333
  allowed_operation_types: Optional[Iterable[OperationType]] = None,
287
- ) -> Union[ExecutionResult, SubscriptionExecutionResult]:
334
+ ) -> ExecutionResult:
288
335
  if allowed_operation_types is None:
289
336
  allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
290
337
 
291
- # Create execution context
292
- execution_context = ExecutionContext(
338
+ execution_context = self._create_execution_context(
293
339
  query=query,
294
- schema=self,
295
- context=context_value,
340
+ allowed_operation_types=allowed_operation_types,
341
+ variable_values=variable_values,
342
+ context_value=context_value,
296
343
  root_value=root_value,
297
- variables=variable_values,
298
- provided_operation_name=operation_name,
344
+ operation_name=operation_name,
299
345
  )
300
-
301
- result = await execute(
346
+ extensions = self.get_extensions()
347
+ # TODO (#3571): remove this when we implement execution context as parameter.
348
+ for extension in extensions:
349
+ extension.execution_context = execution_context
350
+ return await execute(
302
351
  self._schema,
303
- extensions=self.get_extensions(),
304
- execution_context_class=self.execution_context_class,
305
352
  execution_context=execution_context,
306
- allowed_operation_types=allowed_operation_types,
353
+ extensions_runner=self.create_extensions_runner(
354
+ execution_context, extensions
355
+ ),
307
356
  process_errors=self._process_errors,
357
+ middleware_manager=self._get_middleware_manager(extensions),
358
+ execution_context_class=self.execution_context_class,
308
359
  )
309
360
 
310
- return result
311
-
312
361
  def execute_sync(
313
362
  self,
314
363
  query: Optional[str],
@@ -321,44 +370,59 @@ class Schema(BaseSchema):
321
370
  if allowed_operation_types is None:
322
371
  allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
323
372
 
324
- execution_context = ExecutionContext(
373
+ execution_context = self._create_execution_context(
325
374
  query=query,
326
- schema=self,
327
- context=context_value,
375
+ allowed_operation_types=allowed_operation_types,
376
+ variable_values=variable_values,
377
+ context_value=context_value,
328
378
  root_value=root_value,
329
- variables=variable_values,
330
- provided_operation_name=operation_name,
379
+ operation_name=operation_name,
331
380
  )
332
-
333
- result = execute_sync(
381
+ extensions = self._sync_extensions
382
+ # TODO (#3571): remove this when we implement execution context as parameter.
383
+ for extension in extensions:
384
+ extension.execution_context = execution_context
385
+ return execute_sync(
334
386
  self._schema,
335
- extensions=self.get_extensions(sync=True),
336
- execution_context_class=self.execution_context_class,
337
387
  execution_context=execution_context,
388
+ extensions_runner=self.create_extensions_runner(
389
+ execution_context, extensions
390
+ ),
391
+ execution_context_class=self.execution_context_class,
338
392
  allowed_operation_types=allowed_operation_types,
339
393
  process_errors=self._process_errors,
394
+ middleware_manager=self._get_middleware_manager(extensions),
340
395
  )
341
396
 
342
- return result
343
-
344
397
  async def subscribe(
345
398
  self,
346
- # TODO: make this optional when we support extensions
347
- query: str,
399
+ query: Optional[str],
348
400
  variable_values: Optional[Dict[str, Any]] = None,
349
401
  context_value: Optional[Any] = None,
350
402
  root_value: Optional[Any] = None,
351
403
  operation_name: Optional[str] = None,
352
- ) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]:
353
- return await await_maybe(
354
- subscribe(
355
- self._schema,
356
- parse(query),
357
- root_value=root_value,
358
- context_value=context_value,
359
- variable_values=variable_values,
360
- operation_name=operation_name,
361
- )
404
+ ) -> SubscriptionResult:
405
+ execution_context = self._create_execution_context(
406
+ query=query,
407
+ allowed_operation_types=(OperationType.SUBSCRIPTION,),
408
+ variable_values=variable_values,
409
+ context_value=context_value,
410
+ root_value=root_value,
411
+ operation_name=operation_name,
412
+ )
413
+ extensions = self._async_extensions
414
+ # TODO (#3571): remove this when we implement execution context as parameter.
415
+ for extension in extensions:
416
+ extension.execution_context = execution_context
417
+ return await subscribe(
418
+ self._schema,
419
+ execution_context=execution_context,
420
+ extensions_runner=self.create_extensions_runner(
421
+ execution_context, extensions
422
+ ),
423
+ process_errors=self._process_errors,
424
+ middleware_manager=self._get_middleware_manager(extensions),
425
+ execution_context_class=self.execution_context_class,
362
426
  )
363
427
 
364
428
  def _resolve_node_ids(self) -> None:
@@ -0,0 +1,154 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Optional, Type, Union
4
+
5
+ from graphql import (
6
+ ExecutionResult as OriginalExecutionResult,
7
+ )
8
+ from graphql.execution import ExecutionContext as GraphQLExecutionContext
9
+ from graphql.execution import subscribe as original_subscribe
10
+
11
+ from strawberry.types import ExecutionResult
12
+ from strawberry.types.execution import ExecutionContext, PreExecutionError
13
+ from strawberry.utils import IS_GQL_32
14
+ from strawberry.utils.await_maybe import await_maybe
15
+
16
+ from .execute import (
17
+ ProcessErrors,
18
+ _coerce_error,
19
+ _handle_execution_result,
20
+ _parse_and_validate_async,
21
+ )
22
+
23
+ if TYPE_CHECKING:
24
+ from typing_extensions import TypeAlias
25
+
26
+ from graphql.execution.middleware import MiddlewareManager
27
+ from graphql.type.schema import GraphQLSchema
28
+
29
+ from ..extensions.runner import SchemaExtensionsRunner
30
+
31
+ SubscriptionResult: TypeAlias = Union[
32
+ PreExecutionError, AsyncGenerator[ExecutionResult, None]
33
+ ]
34
+
35
+ OriginSubscriptionResult = Union[
36
+ OriginalExecutionResult,
37
+ AsyncIterator[OriginalExecutionResult],
38
+ ]
39
+
40
+
41
+ async def _subscribe(
42
+ schema: GraphQLSchema,
43
+ execution_context: ExecutionContext,
44
+ extensions_runner: SchemaExtensionsRunner,
45
+ process_errors: ProcessErrors,
46
+ middleware_manager: MiddlewareManager,
47
+ execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
48
+ ) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]:
49
+ async with extensions_runner.operation():
50
+ if initial_error := await _parse_and_validate_async(
51
+ context=execution_context,
52
+ extensions_runner=extensions_runner,
53
+ ):
54
+ initial_error.extensions = await extensions_runner.get_extensions_results(
55
+ execution_context
56
+ )
57
+ yield await _handle_execution_result(
58
+ execution_context, initial_error, extensions_runner, process_errors
59
+ )
60
+ try:
61
+ async with extensions_runner.executing():
62
+ assert execution_context.graphql_document is not None
63
+ gql_33_kwargs = {
64
+ "middleware": middleware_manager,
65
+ "execution_context_class": execution_context_class,
66
+ }
67
+ try:
68
+ # Might not be awaitable for pre-execution errors.
69
+ aiter_or_result: OriginSubscriptionResult = await await_maybe(
70
+ original_subscribe(
71
+ schema,
72
+ execution_context.graphql_document,
73
+ root_value=execution_context.root_value,
74
+ variable_values=execution_context.variables,
75
+ operation_name=execution_context.operation_name,
76
+ context_value=execution_context.context,
77
+ **{} if IS_GQL_32 else gql_33_kwargs, # type: ignore[arg-type]
78
+ )
79
+ )
80
+ # graphql-core 3.2 doesn't handle some of the pre-execution errors.
81
+ # see `test_subscription_immediate_error`
82
+ except Exception as exc:
83
+ aiter_or_result = OriginalExecutionResult(
84
+ data=None, errors=[_coerce_error(exc)]
85
+ )
86
+
87
+ # Handle pre-execution errors.
88
+ if isinstance(aiter_or_result, OriginalExecutionResult):
89
+ yield await _handle_execution_result(
90
+ execution_context,
91
+ PreExecutionError(data=None, errors=aiter_or_result.errors),
92
+ extensions_runner,
93
+ process_errors,
94
+ )
95
+ else:
96
+ try:
97
+ async for result in aiter_or_result:
98
+ yield await _handle_execution_result(
99
+ execution_context,
100
+ result,
101
+ extensions_runner,
102
+ process_errors,
103
+ )
104
+ # graphql-core doesn't handle exceptions raised while executing.
105
+ except Exception as exc:
106
+ yield await _handle_execution_result(
107
+ execution_context,
108
+ OriginalExecutionResult(data=None, errors=[_coerce_error(exc)]),
109
+ extensions_runner,
110
+ process_errors,
111
+ )
112
+ # catch exceptions raised in `on_execute` hook.
113
+ except Exception as exc:
114
+ origin_result = OriginalExecutionResult(
115
+ data=None, errors=[_coerce_error(exc)]
116
+ )
117
+ yield await _handle_execution_result(
118
+ execution_context,
119
+ origin_result,
120
+ extensions_runner,
121
+ process_errors,
122
+ )
123
+
124
+
125
+ async def subscribe(
126
+ schema: GraphQLSchema,
127
+ execution_context: ExecutionContext,
128
+ extensions_runner: SchemaExtensionsRunner,
129
+ process_errors: ProcessErrors,
130
+ middleware_manager: MiddlewareManager,
131
+ execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
132
+ ) -> SubscriptionResult:
133
+ asyncgen = _subscribe(
134
+ schema,
135
+ execution_context,
136
+ extensions_runner,
137
+ process_errors,
138
+ middleware_manager,
139
+ execution_context_class,
140
+ )
141
+ # GrapQL-core might return an initial error result instead of an async iterator.
142
+ # This happens when "there was an immediate error" i.e resolver is not an async iterator.
143
+ # To overcome this while maintaining the extension contexts we do this trick.
144
+ first = await asyncgen.__anext__()
145
+ if isinstance(first, PreExecutionError):
146
+ await asyncgen.aclose()
147
+ return first
148
+
149
+ async def _wrapper() -> AsyncGenerator[ExecutionResult, None]:
150
+ yield first
151
+ async for result in asyncgen:
152
+ yield result
153
+
154
+ return _wrapper()
@@ -7,15 +7,13 @@ from contextlib import suppress
7
7
  from typing import (
8
8
  TYPE_CHECKING,
9
9
  Any,
10
- AsyncGenerator,
11
- AsyncIterator,
10
+ Awaitable,
12
11
  Callable,
13
12
  Dict,
14
13
  List,
15
14
  Optional,
16
15
  )
17
16
 
18
- from graphql import ExecutionResult as GraphQLExecutionResult
19
17
  from graphql import GraphQLError, GraphQLSyntaxError, parse
20
18
 
21
19
  from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
@@ -24,11 +22,14 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
24
22
  ConnectionInitMessage,
25
23
  ErrorMessage,
26
24
  NextMessage,
25
+ NextPayload,
27
26
  PingMessage,
28
27
  PongMessage,
29
28
  SubscribeMessage,
30
29
  SubscribeMessagePayload,
31
30
  )
31
+ from strawberry.types import ExecutionResult
32
+ from strawberry.types.execution import PreExecutionError
32
33
  from strawberry.types.graphql import OperationType
33
34
  from strawberry.types.unset import UNSET
34
35
  from strawberry.utils.debug import pretty_print_graphql_operation
@@ -38,10 +39,10 @@ if TYPE_CHECKING:
38
39
  from datetime import timedelta
39
40
 
40
41
  from strawberry.schema import BaseSchema
42
+ from strawberry.schema.subscribe import SubscriptionResult
41
43
  from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
42
44
  GraphQLTransportMessage,
43
45
  )
44
- from strawberry.types import ExecutionResult
45
46
 
46
47
 
47
48
  class BaseGraphQLTransportWSHandler(ABC):
@@ -243,10 +244,10 @@ class BaseGraphQLTransportWSHandler(ABC):
243
244
  if isinstance(context, dict):
244
245
  context["connection_params"] = self.connection_params
245
246
  root_value = await self.get_root_value()
246
-
247
+ result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]
247
248
  # Get an AsyncGenerator yielding the results
248
249
  if operation_type == OperationType.SUBSCRIPTION:
249
- result_source = await self.schema.subscribe(
250
+ result_source = self.schema.subscribe(
250
251
  query=message.payload.query,
251
252
  variable_values=message.payload.variables,
252
253
  operation_name=message.payload.operationName,
@@ -254,29 +255,16 @@ class BaseGraphQLTransportWSHandler(ABC):
254
255
  root_value=root_value,
255
256
  )
256
257
  else:
257
- # create AsyncGenerator returning a single result
258
- async def get_result_source() -> AsyncIterator[ExecutionResult]:
259
- yield await self.schema.execute( # type: ignore
260
- query=message.payload.query,
261
- variable_values=message.payload.variables,
262
- context_value=context,
263
- root_value=root_value,
264
- operation_name=message.payload.operationName,
265
- )
266
-
267
- result_source = get_result_source()
258
+ result_source = self.schema.execute(
259
+ query=message.payload.query,
260
+ variable_values=message.payload.variables,
261
+ context_value=context,
262
+ root_value=root_value,
263
+ operation_name=message.payload.operationName,
264
+ )
268
265
 
269
266
  operation = Operation(self, message.id, operation_type)
270
267
 
271
- # Handle initial validation errors
272
- if isinstance(result_source, GraphQLExecutionResult):
273
- assert operation_type == OperationType.SUBSCRIPTION
274
- assert result_source.errors
275
- payload = [err.formatted for err in result_source.errors]
276
- await self.send_message(ErrorMessage(id=message.id, payload=payload))
277
- self.schema.process_errors(result_source.errors)
278
- return
279
-
280
268
  # Create task to handle this subscription, reserve the operation ID
281
269
  operation.task = asyncio.create_task(
282
270
  self.operation_task(result_source, operation)
@@ -284,65 +272,37 @@ class BaseGraphQLTransportWSHandler(ABC):
284
272
  self.operations[message.id] = operation
285
273
 
286
274
  async def operation_task(
287
- self, result_source: AsyncGenerator, operation: Operation
275
+ self,
276
+ result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult],
277
+ operation: Operation,
288
278
  ) -> None:
289
279
  """The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
290
280
  # TODO: Handle errors in this method using self.handle_task_exception()
291
281
  try:
292
- await self.handle_async_results(result_source, operation)
293
- except BaseException: # pragma: no cover
294
- # cleanup in case of something really unexpected
295
- # wait for generator to be closed to ensure that any existing
296
- # 'finally' statement is called
297
- with suppress(RuntimeError):
298
- await result_source.aclose()
299
- if operation.id in self.operations:
300
- del self.operations[operation.id]
282
+ first_res_or_agen = await result_source
283
+ # that's an immediate error we should end the operation
284
+ # without a COMPLETE message
285
+ if isinstance(first_res_or_agen, PreExecutionError):
286
+ assert first_res_or_agen.errors
287
+ await operation.send_initial_errors(first_res_or_agen.errors)
288
+ # that's a mutation / query result
289
+ elif isinstance(first_res_or_agen, ExecutionResult):
290
+ await operation.send_next(first_res_or_agen)
291
+ await operation.send_message(CompleteMessage(id=operation.id))
292
+ else:
293
+ async for result in first_res_or_agen:
294
+ await operation.send_next(result)
295
+ await operation.send_message(CompleteMessage(id=operation.id))
296
+
297
+ except BaseException as e: # pragma: no cover
298
+ self.operations.pop(operation.id, None)
301
299
  raise
302
- else:
303
- await operation.send_message(CompleteMessage(id=operation.id))
304
300
  finally:
305
301
  # add this task to a list to be reaped later
306
302
  task = asyncio.current_task()
307
303
  assert task is not None
308
304
  self.completed_tasks.append(task)
309
305
 
310
- async def handle_async_results(
311
- self,
312
- result_source: AsyncGenerator,
313
- operation: Operation,
314
- ) -> None:
315
- try:
316
- async for result in result_source:
317
- if (
318
- result.errors
319
- and operation.operation_type != OperationType.SUBSCRIPTION
320
- ):
321
- error_payload = [err.formatted for err in result.errors]
322
- error_message = ErrorMessage(id=operation.id, payload=error_payload)
323
- await operation.send_message(error_message)
324
- # don't need to call schema.process_errors() here because
325
- # it was already done by schema.execute()
326
- return
327
- else:
328
- next_payload = {"data": result.data}
329
- if result.errors:
330
- self.schema.process_errors(result.errors)
331
- next_payload["errors"] = [
332
- err.formatted for err in result.errors
333
- ]
334
- next_message = NextMessage(id=operation.id, payload=next_payload)
335
- await operation.send_message(next_message)
336
- except Exception as error:
337
- # GraphQLErrors are handled by graphql-core and included in the
338
- # ExecutionResult
339
- error = GraphQLError(str(error), original_error=error)
340
- error_payload = [error.formatted]
341
- error_message = ErrorMessage(id=operation.id, payload=error_payload)
342
- await operation.send_message(error_message)
343
- self.schema.process_errors([error])
344
- return
345
-
346
306
  def forget_id(self, id: str) -> None:
347
307
  # de-register the operation id making it immediately available
348
308
  # for re-use
@@ -401,5 +361,21 @@ class Operation:
401
361
  self.handler.forget_id(self.id)
402
362
  await self.handler.send_message(message)
403
363
 
364
+ async def send_initial_errors(self, errors: list[GraphQLError]) -> None:
365
+ # Initial errors see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error
366
+ # "This can occur before execution starts,
367
+ # usually due to validation errors, or during the execution of the request"
368
+ await self.send_message(
369
+ ErrorMessage(id=self.id, payload=[err.formatted for err in errors])
370
+ )
371
+
372
+ async def send_next(self, execution_result: ExecutionResult) -> None:
373
+ next_payload: NextPayload = {"data": execution_result.data}
374
+ if execution_result.errors:
375
+ next_payload["errors"] = [err.formatted for err in execution_result.errors]
376
+ if execution_result.extensions:
377
+ next_payload["extensions"] = execution_result.extensions
378
+ await self.send_message(NextMessage(id=self.id, payload=next_payload))
379
+
404
380
 
405
381
  __all__ = ["BaseGraphQLTransportWSHandler", "Operation"]
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import asdict, dataclass
4
- from typing import TYPE_CHECKING, Any, Dict, List, Optional
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict
5
5
 
6
6
  from strawberry.types.unset import UNSET
7
7
 
@@ -68,12 +68,20 @@ class SubscribeMessage(GraphQLTransportMessage):
68
68
  type: str = "subscribe"
69
69
 
70
70
 
71
+ class NextPayload(TypedDict, total=False):
72
+ data: Any
73
+
74
+ # Optional list of formatted graphql.GraphQLError objects
75
+ errors: Optional[List[GraphQLFormattedError]]
76
+ extensions: Optional[Dict[str, Any]]
77
+
78
+
71
79
  @dataclass
72
80
  class NextMessage(GraphQLTransportMessage):
73
81
  """Direction: Server -> Client."""
74
82
 
75
83
  id: str
76
- payload: Dict[str, Any] # TODO: shape like FormattedExecutionResult
84
+ payload: NextPayload
77
85
  type: str = "next"
78
86
 
79
87
  def as_dict(self) -> dict: