strawberry-graphql 0.190.0.dev1687447182__py3-none-any.whl → 0.192.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. strawberry/annotation.py +24 -3
  2. strawberry/arguments.py +6 -2
  3. strawberry/channels/testing.py +22 -13
  4. strawberry/cli/__init__.py +4 -4
  5. strawberry/cli/commands/upgrade/__init__.py +75 -0
  6. strawberry/cli/commands/upgrade/_fake_progress.py +21 -0
  7. strawberry/cli/commands/upgrade/_run_codemod.py +74 -0
  8. strawberry/codemods/__init__.py +0 -0
  9. strawberry/codemods/annotated_unions.py +185 -0
  10. strawberry/exceptions/invalid_union_type.py +23 -3
  11. strawberry/exceptions/utils/source_finder.py +147 -11
  12. strawberry/extensions/field_extension.py +2 -5
  13. strawberry/fastapi/router.py +5 -4
  14. strawberry/federation/union.py +4 -5
  15. strawberry/field.py +116 -75
  16. strawberry/http/__init__.py +1 -3
  17. strawberry/permission.py +3 -166
  18. strawberry/relay/fields.py +2 -0
  19. strawberry/relay/types.py +14 -4
  20. strawberry/schema/schema.py +1 -1
  21. strawberry/schema/schema_converter.py +106 -38
  22. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +20 -8
  23. strawberry/subscriptions/protocols/graphql_transport_ws/types.py +1 -1
  24. strawberry/subscriptions/protocols/graphql_ws/handlers.py +4 -7
  25. strawberry/type.py +2 -2
  26. strawberry/types/type_resolver.py +7 -29
  27. strawberry/types/types.py +6 -0
  28. strawberry/union.py +46 -17
  29. strawberry/utils/typing.py +21 -0
  30. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/METADATA +1 -1
  31. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/RECORD +34 -30
  32. strawberry/exceptions/permission_fail_silently_requires_optional.py +0 -52
  33. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/LICENSE +0 -0
  34. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/WHEEL +0 -0
  35. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/entry_points.txt +0 -0
strawberry/permission.py CHANGED
@@ -1,184 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- import abc
4
- from inspect import iscoroutinefunction
5
- from typing import (
6
- TYPE_CHECKING,
7
- Any,
8
- Awaitable,
9
- Dict,
10
- List,
11
- Optional,
12
- Type,
13
- Union,
14
- )
15
-
16
- from strawberry.exceptions import StrawberryGraphQLError
17
- from strawberry.exceptions.permission_fail_silently_requires_optional import (
18
- PermissionFailSilentlyRequiresOptionalError,
19
- )
20
- from strawberry.extensions import FieldExtension
21
- from strawberry.schema_directive import Location, StrawberrySchemaDirective
22
- from strawberry.type import StrawberryList, StrawberryOptional
23
- from strawberry.utils.await_maybe import await_maybe
24
- from strawberry.utils.cached_property import cached_property
3
+ from typing import TYPE_CHECKING, Any, Awaitable, Optional, Union
25
4
 
26
5
  if TYPE_CHECKING:
27
- from graphql import GraphQLError, GraphQLErrorExtensions
28
-
29
- from strawberry.extensions.field_extension import (
30
- AsyncExtensionResolver,
31
- SyncExtensionResolver,
32
- )
33
- from strawberry.field import StrawberryField
34
- from strawberry.types import Info
6
+ from strawberry.types.info import Info
35
7
 
36
8
 
37
- class BasePermission(abc.ABC):
9
+ class BasePermission:
38
10
  """
39
11
  Base class for creating permissions
40
12
  """
41
13
 
42
14
  message: Optional[str] = None
43
15
 
44
- error_extensions: Optional[GraphQLErrorExtensions] = None
45
-
46
- error_class: Type[GraphQLError] = StrawberryGraphQLError
47
-
48
- _schema_directive: Optional[object] = None
49
-
50
- @abc.abstractmethod
51
16
  def has_permission(
52
17
  self, source: Any, info: Info, **kwargs: Any
53
18
  ) -> Union[bool, Awaitable[bool]]:
54
19
  raise NotImplementedError(
55
20
  "Permission classes should override has_permission method"
56
21
  )
57
-
58
- def handle_no_permission(self) -> None:
59
- """
60
- Default error raising for permissions.
61
- This can be overridden to customize the behavior.
62
- """
63
-
64
- # Instantiate error class
65
- error = self.error_class(self.message or "")
66
-
67
- if self.error_extensions:
68
- # Add our extensions to the error
69
- if not error.extensions:
70
- error.extensions = dict()
71
- error.extensions.update(self.error_extensions)
72
-
73
- raise error
74
-
75
- @property
76
- def schema_directive(self) -> object:
77
- if not self._schema_directive:
78
-
79
- class AutoDirective:
80
- __strawberry_directive__ = StrawberrySchemaDirective(
81
- self.__class__.__name__,
82
- self.__class__.__name__,
83
- [Location.FIELD_DEFINITION],
84
- [],
85
- )
86
-
87
- self._schema_directive = AutoDirective()
88
-
89
- return self._schema_directive
90
-
91
-
92
- class PermissionExtension(FieldExtension):
93
- """
94
- Handles permissions for a field
95
- Instantiate this as a field extension with all of the permissions you want to apply
96
-
97
- fail_silently: bool = False will return None or [] if the permission fails
98
- instead of raising an exception. This is only valid for optional or list fields.
99
-
100
- NOTE:
101
- Currently, this is automatically added to the field, when using
102
- field.permission_classes
103
- This is deprecated behavior, please manually add the extension to field.extensions
104
- """
105
-
106
- def __init__(
107
- self,
108
- permissions: List[BasePermission],
109
- use_directives: bool = True,
110
- fail_silently: bool = False,
111
- ):
112
- self.permissions = permissions
113
- self.fail_silently = fail_silently
114
- self.return_empty_list = False
115
- self.use_directives = use_directives
116
-
117
- def apply(self, field: StrawberryField) -> None:
118
- """
119
- Applies all of the permission directives to the schema
120
- and sets up silent permissions
121
- """
122
- if self.use_directives:
123
- field.directives.extend(
124
- p.schema_directive for p in self.permissions if p.schema_directive
125
- )
126
- # We can only fail silently if the field is optional or a list
127
- if self.fail_silently:
128
- if isinstance(field.type, StrawberryOptional):
129
- if isinstance(field.type.of_type, StrawberryList):
130
- self.return_empty_list = True
131
- elif isinstance(field.type, StrawberryList):
132
- self.return_empty_list = True
133
- else:
134
- errror = PermissionFailSilentlyRequiresOptionalError(field)
135
- a = errror.exception_source
136
- raise errror
137
-
138
- def _handle_no_permission(self, permission: BasePermission) -> Any:
139
- if self.fail_silently:
140
- return [] if self.return_empty_list else None
141
- return permission.handle_no_permission()
142
-
143
- def resolve(
144
- self,
145
- next_: SyncExtensionResolver,
146
- source: Any,
147
- info: Info,
148
- **kwargs: Dict[str, Any],
149
- ) -> Any:
150
- """
151
- Checks if the permission should be accepted and
152
- raises an exception if not
153
- """
154
- for permission in self.permissions:
155
- if not permission.has_permission(source, info, **kwargs):
156
- return self._handle_no_permission(permission)
157
- return next_(source, info, **kwargs)
158
-
159
- async def resolve_async(
160
- self,
161
- next_: AsyncExtensionResolver,
162
- source: Any,
163
- info: Info,
164
- **kwargs: Dict[str, Any],
165
- ) -> Any:
166
- for permission in self.permissions:
167
- has_permission = await await_maybe(
168
- permission.has_permission(source, info, **kwargs)
169
- )
170
-
171
- if not has_permission:
172
- return self._handle_no_permission(permission)
173
- return await next_(source, info, **kwargs)
174
-
175
- @cached_property
176
- def supports_sync(self) -> bool:
177
- """The Permission extension always supports async checking using await_maybe,
178
- but only supports sync checking if there are no async permissions"""
179
- async_permissions = [
180
- True
181
- for permission in self.permissions
182
- if iscoroutinefunction(permission.has_permission)
183
- ]
184
- return len(async_permissions) == 0
strawberry/relay/fields.py CHANGED
@@ -235,6 +235,8 @@ class ConnectionExtension(FieldExtension):
235
235
  # for subscription support, but we can't use it here. Maybe we can refactor
236
236
  # this in the future.
237
237
  resolver_type = field.base_resolver.signature.return_annotation
238
+ if isinstance(resolver_type, str):
239
+ resolver_type = ForwardRef(resolver_type)
238
240
  if isinstance(resolver_type, ForwardRef):
239
241
  resolver_type = eval_type(
240
242
  resolver_type,
strawberry/relay/types.py CHANGED
@@ -11,6 +11,7 @@ from typing import (
11
11
  AsyncIterator,
12
12
  Awaitable,
13
13
  ClassVar,
14
+ ForwardRef,
14
15
  Generic,
15
16
  Iterable,
16
17
  Iterator,
@@ -38,16 +39,16 @@ from strawberry.object_type import interface, type
38
39
  from strawberry.private import StrawberryPrivate
39
40
  from strawberry.relay.exceptions import NodeIDAnnotationError
40
41
  from strawberry.type import StrawberryContainer, get_object_definition
42
+ from strawberry.types.info import Info # noqa: TCH001
41
43
  from strawberry.types.types import StrawberryObjectDefinition
42
44
  from strawberry.utils.aio import aenumerate, aislice, resolve_awaitable
43
45
  from strawberry.utils.inspect import in_async_context
44
- from strawberry.utils.typing import eval_type
46
+ from strawberry.utils.typing import eval_type, is_classvar
45
47
 
46
48
  from .utils import from_base64, to_base64
47
49
 
48
50
  if TYPE_CHECKING:
49
51
  from strawberry.scalars import ID
50
- from strawberry.types.info import Info
51
52
  from strawberry.utils.await_maybe import AwaitableOrValue
52
53
 
53
54
  _T = TypeVar("_T")
@@ -407,7 +408,16 @@ class Node:
407
408
  base_namespace = sys.modules[base.__module__].__dict__
408
409
 
409
410
  for attr_name, attr in getattr(base, "__annotations__", {}).items():
410
- evaled = eval_type(attr, globalns=base_namespace)
411
+ # Some ClassVar might raise TypeError when being resolved
412
+ # on some python versions. This is fine to skip since
413
+ # we are not interested in ClassVars here
414
+ if is_classvar(base, attr):
415
+ continue
416
+
417
+ evaled = eval_type(
418
+ ForwardRef(attr) if isinstance(attr, str) else attr,
419
+ globalns=base_namespace,
420
+ )
411
421
 
412
422
  if get_origin(evaled) is Annotated and any(
413
423
  isinstance(a, NodeIDPrivate) for a in get_args(evaled)
@@ -845,7 +855,7 @@ class ListConnection(Connection[NodeType]):
845
855
  field_def = type_def.get_field("edges")
846
856
  assert field_def
847
857
 
848
- field = field_def.type
858
+ field = field_def.resolve_type(type_definition=type_def)
849
859
  while isinstance(field, StrawberryContainer):
850
860
  field = field.of_type
851
861
 
strawberry/schema/schema.py CHANGED
@@ -23,7 +23,7 @@ from graphql import (
23
23
  parse,
24
24
  validate_schema,
25
25
  )
26
- from graphql.subscription import subscribe
26
+ from graphql.execution import subscribe
27
27
  from graphql.type.directives import specified_directives
28
28
 
29
29
  from strawberry import relay
strawberry/schema/schema_converter.py CHANGED
@@ -8,6 +8,7 @@ from typing import (
8
8
  Any,
9
9
  Callable,
10
10
  Dict,
11
+ Generic,
11
12
  List,
12
13
  Optional,
13
14
  Tuple,
@@ -16,6 +17,7 @@ from typing import (
16
17
  Union,
17
18
  cast,
18
19
  )
20
+ from typing_extensions import Protocol
19
21
 
20
22
  from graphql import (
21
23
  GraphQLArgument,
@@ -53,6 +55,8 @@ from strawberry.schema.types.scalar import _make_scalar_type
53
55
  from strawberry.type import (
54
56
  StrawberryList,
55
57
  StrawberryOptional,
58
+ StrawberryType,
59
+ WithStrawberryObjectDefinition,
56
60
  has_object_definition,
57
61
  )
58
62
  from strawberry.types.info import Info
@@ -83,7 +87,56 @@ if TYPE_CHECKING:
83
87
  from strawberry.field import StrawberryField
84
88
  from strawberry.schema.config import StrawberryConfig
85
89
  from strawberry.schema_directive import StrawberrySchemaDirective
86
- from strawberry.type import StrawberryType
90
+
91
+
92
+ FieldType = TypeVar(
93
+ "FieldType", bound=Union[GraphQLField, GraphQLInputField], covariant=True
94
+ )
95
+
96
+
97
+ class FieldConverterProtocol(Generic[FieldType], Protocol):
98
+ def __call__( # pragma: no cover
99
+ self,
100
+ field: StrawberryField,
101
+ *,
102
+ override_type: Optional[
103
+ Union[StrawberryType, Type[WithStrawberryObjectDefinition]]
104
+ ] = None,
105
+ ) -> FieldType:
106
+ ...
107
+
108
+
109
+ def _get_thunk_mapping(
110
+ type_definition: StrawberryObjectDefinition,
111
+ name_converter: Callable[[StrawberryField], str],
112
+ field_converter: FieldConverterProtocol[FieldType],
113
+ ) -> Dict[str, FieldType]:
114
+ """Create a GraphQL core `ThunkMapping` mapping of field names to field types.
115
+
116
+ This method filters out remaining `strawberry.Private` annotated fields that
117
+ could not be filtered during the initialization of a `TypeDefinition` due to
118
+ postponed type-hint evaluation (PEP-563). Performing this filtering now (at
119
+ schema conversion time) ensures that all types to be included in the schema
120
+ should have already been resolved.
121
+
122
+ Raises:
123
+ TypeError: If the type of a field in ``fields`` is `UNRESOLVED`
124
+ """
125
+ thunk_mapping: Dict[str, FieldType] = {}
126
+
127
+ for field in type_definition.fields:
128
+ field_type = field.resolve_type(type_definition=type_definition)
129
+
130
+ if field_type is UNRESOLVED:
131
+ raise UnresolvedFieldTypeError(type_definition, field)
132
+
133
+ if not is_private(field_type):
134
+ thunk_mapping[name_converter(field)] = field_converter(
135
+ field,
136
+ override_type=field_type,
137
+ )
138
+
139
+ return thunk_mapping
87
140
 
88
141
 
89
142
  # graphql-core expects a resolver for an Enum type to return
@@ -248,11 +301,20 @@ class GraphQLCoreConverter:
248
301
  },
249
302
  )
250
303
 
251
- def from_field(self, field: StrawberryField) -> GraphQLField:
304
+ def from_field(
305
+ self,
306
+ field: StrawberryField,
307
+ *,
308
+ override_type: Optional[
309
+ Union[StrawberryType, Type[WithStrawberryObjectDefinition]]
310
+ ] = None,
311
+ ) -> GraphQLField:
252
312
  # self.from_resolver needs to be called before accessing field.type because
253
313
  # in there a field extension might want to change the type during its apply
254
314
  resolver = self.from_resolver(field)
255
- field_type = cast("GraphQLOutputType", self.from_maybe_optional(field.type))
315
+ field_type = cast(
316
+ "GraphQLOutputType", self.from_maybe_optional(override_type or field.type)
317
+ )
256
318
  subscribe = None
257
319
 
258
320
  if field.is_subscription:
@@ -276,8 +338,17 @@ class GraphQLCoreConverter:
276
338
  },
277
339
  )
278
340
 
279
- def from_input_field(self, field: StrawberryField) -> GraphQLInputField:
280
- field_type = cast("GraphQLInputType", self.from_maybe_optional(field.type))
341
+ def from_input_field(
342
+ self,
343
+ field: StrawberryField,
344
+ *,
345
+ override_type: Optional[
346
+ Union[StrawberryType, Type[WithStrawberryObjectDefinition]]
347
+ ] = None,
348
+ ) -> GraphQLInputField:
349
+ field_type = cast(
350
+ "GraphQLInputType", self.from_maybe_optional(override_type or field.type)
351
+ )
281
352
  default_value: object
282
353
 
283
354
  if field.default_value is UNSET or field.default_value is dataclasses.MISSING:
@@ -295,40 +366,10 @@ class GraphQLCoreConverter:
295
366
  },
296
367
  )
297
368
 
298
- FieldType = TypeVar("FieldType", GraphQLField, GraphQLInputField)
299
-
300
- @staticmethod
301
- def _get_thunk_mapping(
302
- type_definition: StrawberryObjectDefinition,
303
- name_converter: Callable[[StrawberryField], str],
304
- field_converter: Callable[[StrawberryField], FieldType],
305
- ) -> Dict[str, FieldType]:
306
- """Create a GraphQL core `ThunkMapping` mapping of field names to field types.
307
-
308
- This method filters out remaining `strawberry.Private` annotated fields that
309
- could not be filtered during the initialization of a `TypeDefinition` due to
310
- postponed type-hint evaluation (PEP-563). Performing this filtering now (at
311
- schema conversion time) ensures that all types to be included in the schema
312
- should have already been resolved.
313
-
314
- Raises:
315
- TypeError: If the type of a field in ``fields`` is `UNRESOLVED`
316
- """
317
- thunk_mapping = {}
318
-
319
- for field in type_definition.fields:
320
- if field.type is UNRESOLVED:
321
- raise UnresolvedFieldTypeError(type_definition, field)
322
-
323
- if not is_private(field.type):
324
- thunk_mapping[name_converter(field)] = field_converter(field)
325
-
326
- return thunk_mapping
327
-
328
369
  def get_graphql_fields(
329
370
  self, type_definition: StrawberryObjectDefinition
330
371
  ) -> Dict[str, GraphQLField]:
331
- return self._get_thunk_mapping(
372
+ return _get_thunk_mapping(
332
373
  type_definition=type_definition,
333
374
  name_converter=self.config.name_converter.from_field,
334
375
  field_converter=self.from_field,
@@ -337,7 +378,7 @@ class GraphQLCoreConverter:
337
378
  def get_graphql_input_fields(
338
379
  self, type_definition: StrawberryObjectDefinition
339
380
  ) -> Dict[str, GraphQLInputField]:
340
- return self._get_thunk_mapping(
381
+ return _get_thunk_mapping(
341
382
  type_definition=type_definition,
342
383
  name_converter=self.config.name_converter.from_field,
343
384
  field_converter=self.from_input_field,
@@ -522,6 +563,31 @@ class GraphQLCoreConverter:
522
563
 
523
564
  return args, kwargs
524
565
 
566
+ def _check_permissions(source: Any, info: Info, kwargs: Any):
567
+ """
568
+ Checks if the permission should be accepted and
569
+ raises an exception if not
570
+ """
571
+ for permission_class in field.permission_classes:
572
+ permission = permission_class()
573
+
574
+ if not permission.has_permission(source, info, **kwargs):
575
+ message = getattr(permission, "message", None)
576
+ raise PermissionError(message)
577
+
578
+ async def _check_permissions_async(source: Any, info: Info, kwargs: Any):
579
+ for permission_class in field.permission_classes:
580
+ permission = permission_class()
581
+ has_permission: bool
582
+
583
+ has_permission = await await_maybe(
584
+ permission.has_permission(source, info, **kwargs)
585
+ )
586
+
587
+ if not has_permission:
588
+ message = getattr(permission, "message", None)
589
+ raise PermissionError(message)
590
+
525
591
  def _strawberry_info_from_graphql(info: GraphQLResolveInfo) -> Info:
526
592
  return Info(
527
593
  _raw_info=info,
@@ -590,6 +656,7 @@ class GraphQLCoreConverter:
590
656
 
591
657
  def _resolver(_source: Any, info: GraphQLResolveInfo, **kwargs: Any):
592
658
  strawberry_info = _strawberry_info_from_graphql(info)
659
+ _check_permissions(_source, strawberry_info, kwargs)
593
660
 
594
661
  return _get_result_with_extensions(
595
662
  _source,
@@ -601,6 +668,7 @@ class GraphQLCoreConverter:
601
668
  _source: Any, info: GraphQLResolveInfo, **kwargs: Any
602
669
  ):
603
670
  strawberry_info = _strawberry_info_from_graphql(info)
671
+ await _check_permissions_async(_source, strawberry_info, kwargs)
604
672
 
605
673
  return await await_maybe(
606
674
  _get_result_with_extensions(
@@ -711,7 +779,7 @@ class GraphQLCoreConverter:
711
779
  # TypeVars, Annotations, LazyTypes, etc it can't perfectly detect issues at
712
780
  # that stage
713
781
  if not StrawberryUnion.is_valid_union_type(type_):
714
- raise InvalidUnionTypeError(union_name, type_)
782
+ raise InvalidUnionTypeError(union_name, type_, union_definition=union)
715
783
 
716
784
  # Don't reevaluate known types
717
785
  if union_name in self.type_map:
strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py CHANGED
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
8
8
 
9
9
  from graphql import ExecutionResult as GraphQLExecutionResult
10
10
  from graphql import GraphQLError, GraphQLSyntaxError, parse
11
- from graphql.error.graphql_error import format_error as format_graphql_error
12
11
 
13
12
  from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
14
13
  CompleteMessage,
@@ -245,12 +244,12 @@ class BaseGraphQLTransportWSHandler(ABC):
245
244
 
246
245
  result_source = get_result_source()
247
246
 
248
- operation = Operation(self, message.id)
247
+ operation = Operation(self, message.id, operation_type)
249
248
 
250
249
  # Handle initial validation errors
251
250
  if isinstance(result_source, GraphQLExecutionResult):
252
251
  assert result_source.errors
253
- payload = [format_graphql_error(result_source.errors[0])]
252
+ payload = [err.formatted for err in result_source.errors]
254
253
  await self.send_message(ErrorMessage(id=message.id, payload=payload))
255
254
  self.schema.process_errors(result_source.errors)
256
255
  return
@@ -295,14 +294,21 @@ class BaseGraphQLTransportWSHandler(ABC):
295
294
  ) -> None:
296
295
  try:
297
296
  async for result in result_source:
298
- if result.errors:
299
- error_payload = [format_graphql_error(err) for err in result.errors]
297
+ if (
298
+ result.errors
299
+ and operation.operation_type != OperationType.SUBSCRIPTION
300
+ ):
301
+ error_payload = [err.formatted for err in result.errors]
300
302
  error_message = ErrorMessage(id=operation.id, payload=error_payload)
301
303
  await operation.send_message(error_message)
302
304
  self.schema.process_errors(result.errors)
303
305
  return
304
306
  else:
305
307
  next_payload = {"data": result.data}
308
+ if result.errors:
309
+ next_payload["errors"] = [
310
+ err.formatted for err in result.errors
311
+ ]
306
312
  next_message = NextMessage(id=operation.id, payload=next_payload)
307
313
  await operation.send_message(next_message)
308
314
  except asyncio.CancelledError:
@@ -312,7 +318,7 @@ class BaseGraphQLTransportWSHandler(ABC):
312
318
  # GraphQLErrors are handled by graphql-core and included in the
313
319
  # ExecutionResult
314
320
  error = GraphQLError(str(error), original_error=error)
315
- error_payload = [format_graphql_error(error)]
321
+ error_payload = [error.formatted]
316
322
  error_message = ErrorMessage(id=operation.id, payload=error_payload)
317
323
  await operation.send_message(error_message)
318
324
  self.schema.process_errors([error])
@@ -358,11 +364,17 @@ class Operation:
358
364
  Helps enforce protocol state transition.
359
365
  """
360
366
 
361
- __slots__ = ["handler", "id", "completed", "task"]
367
+ __slots__ = ["handler", "id", "operation_type", "completed", "task"]
362
368
 
363
- def __init__(self, handler: BaseGraphQLTransportWSHandler, id: str):
369
+ def __init__(
370
+ self,
371
+ handler: BaseGraphQLTransportWSHandler,
372
+ id: str,
373
+ operation_type: OperationType,
374
+ ):
364
375
  self.handler = handler
365
376
  self.id = id
377
+ self.operation_type = operation_type
366
378
  self.completed = False
367
379
  self.task: Optional[asyncio.Task] = None
368
380
 
strawberry/subscriptions/protocols/graphql_transport_ws/types.py CHANGED
@@ -85,7 +85,7 @@ class NextMessage(GraphQLTransportMessage):
85
85
  """
86
86
 
87
87
  id: str
88
- payload: Dict[str, Any] # TODO: shape like ExecutionResult
88
+ payload: Dict[str, Any] # TODO: shape like FormattedExecutionResult
89
89
  type: str = "next"
90
90
 
91
91
  def as_dict(self) -> dict:
strawberry/subscriptions/protocols/graphql_ws/handlers.py CHANGED
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, cast
7
7
 
8
8
  from graphql import ExecutionResult as GraphQLExecutionResult
9
9
  from graphql import GraphQLError
10
- from graphql.error.graphql_error import format_error as format_graphql_error
11
10
 
12
11
  from strawberry.subscriptions.protocols.graphql_ws import (
13
12
  GQL_COMPLETE,
@@ -133,14 +132,14 @@ class BaseGraphQLWSHandler(ABC):
133
132
  root_value=root_value,
134
133
  )
135
134
  except GraphQLError as error:
136
- error_payload = format_graphql_error(error)
135
+ error_payload = error.formatted
137
136
  await self.send_message(GQL_ERROR, operation_id, error_payload)
138
137
  self.schema.process_errors([error])
139
138
  return
140
139
 
141
140
  if isinstance(result_source, GraphQLExecutionResult):
142
141
  assert result_source.errors
143
- error_payload = format_graphql_error(result_source.errors[0])
142
+ error_payload = result_source.errors[0].formatted
144
143
  await self.send_message(GQL_ERROR, operation_id, error_payload)
145
144
  self.schema.process_errors(result_source.errors)
146
145
  return
@@ -168,9 +167,7 @@ class BaseGraphQLWSHandler(ABC):
168
167
  async for result in result_source:
169
168
  payload = {"data": result.data}
170
169
  if result.errors:
171
- payload["errors"] = [
172
- format_graphql_error(err) for err in result.errors
173
- ]
170
+ payload["errors"] = [err.formatted for err in result.errors]
174
171
  await self.send_message(GQL_DATA, operation_id, payload)
175
172
  # log errors after send_message to prevent potential
176
173
  # slowdown of sending result
@@ -186,7 +183,7 @@ class BaseGraphQLWSHandler(ABC):
186
183
  await self.send_message(
187
184
  GQL_DATA,
188
185
  operation_id,
189
- {"data": None, "errors": [format_graphql_error(error)]},
186
+ {"data": None, "errors": [error.formatted]},
190
187
  )
191
188
  self.schema.process_errors([error])
192
189
 
strawberry/type.py CHANGED
@@ -13,7 +13,7 @@ from typing import (
13
13
  Union,
14
14
  overload,
15
15
  )
16
- from typing_extensions import Literal, Protocol
16
+ from typing_extensions import Literal, Protocol, Self
17
17
 
18
18
  from strawberry.utils.typing import is_concrete_generic
19
19
 
@@ -113,7 +113,7 @@ class StrawberryContainer(StrawberryType):
113
113
  type_var_map: Mapping[
114
114
  TypeVar, Union[StrawberryType, Type[WithStrawberryObjectDefinition]]
115
115
  ],
116
- ) -> StrawberryType:
116
+ ) -> Self:
117
117
  of_type_copy = self.of_type
118
118
 
119
119
  if has_object_definition(self.of_type):